File size: 10,878 Bytes
76f9669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
// SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

#include <stdio.h>
#include <stdlib.h>
#include <cstring>
#include "loader.h"

// Upper bound, in characters, for the paths handled in this file; also the
// length of dxcore_queryregistry_info::ValueName.
#define DXCORE_MAX_PATH 260

#if defined(_WIN32)
#include <windows.h>
#define _getAddr GetProcAddress
#define _Handle HMODULE

// System directory component looked for in the registry-provided driver store
// path, and the CUDA driver DLL name appended to the resolved directory, for
// 64-bit and 32-bit processes. Each length is derived from its literal so the
// two can never disagree.
static const char sysrootName64[] = "System32";
static const size_t sysrootName64_length = sizeof(sysrootName64) - 1;
static const char libcudaName64[] = "\\nvcuda64.dll";
static const size_t libcudaName64_length = sizeof(libcudaName64) - 1;
static const char sysrootNameX86[] = "SysWOW64";
static const size_t sysrootNameX86_length = sizeof(sysrootNameX86) - 1;
static const char libcudaNameX86[] = "\\nvcuda32.dll";
static const size_t libcudaNameX86_length = sizeof(libcudaNameX86) - 1;

// Selection for the current process, set by getCUDALibraryPath().
static size_t sysrootName_length = 0;
static const char* sysrootName = NULL;

#else
#include <dlfcn.h>
#include <unistd.h>
#define _getAddr dlsym
#define _Handle void*
static const char libcudaNameLinux[] = "/libcuda.so.1.1";
static const size_t libcudaNameLinux_length = sizeof(libcudaNameLinux) - 1;
#endif

// CUDA driver library file name (starting with a path separator) selected by
// getCUDALibraryPath() before any path is built.
static size_t libcudaName_length = 0;
static const char* libcudaName = NULL;

// Forward declarations so the entry-point typedefs below can name their
// parameter structs, which are defined further down in this file.
struct dxcore_enumAdapters2;
struct dxcore_queryAdapterInfo;

// Signatures used to call the D3DKMTEnumAdapters2 / D3DKMTQueryAdapterInfo
// symbols resolved at runtime by getCUDALibraryPath(). Every caller in this
// file treats a return value of 0 as success.
// NOTE(review): the entry points write their results into *pParams even though
// these typedefs declare the pointee const; presumably the real return type is
// an NTSTATUS — confirm.
typedef int (*pfnDxcoreEnumAdapters2)(const dxcore_enumAdapters2 *pParams);
typedef int (*pfnDxcoreQueryAdapterInfo)(const dxcore_queryAdapterInfo *pParams);

// Dynamically loaded thunk library (gdi32.dll on Windows, libdxcore.so
// elsewhere) together with the two entry points looked up from it.
struct dxcore_lib {
    _Handle hDxcoreLib;
    pfnDxcoreEnumAdapters2 pDxcoreEnumAdapters2;
    pfnDxcoreQueryAdapterInfo pDxcoreQueryAdapterInfo;
};

// Locally unique adapter identifier, embedded in dxcore_adapterInfo.
// Field order and sizes are part of the layout the entry point fills in; do
// not reorder.
struct dxcore_luid
{
    unsigned int lowPart;
    int highPart;
};

// One adapter entry written by D3DKMTEnumAdapters2 into the array
// dxcore_enumAdapters2::pAdapters. Only hAdapter is read in this file (it is
// the handle passed to every D3DKMTQueryAdapterInfo call); the other fields
// exist so the entry has the size and layout the entry point writes.
struct dxcore_adapterInfo
{
    unsigned int              hAdapter;
    struct dxcore_luid        AdapterLuid;
    unsigned int              NumOfSources;
    unsigned int              bPresentMoveRegionsPreferred;
};

// Parameter block for D3DKMTEnumAdapters2. dxcore_enum_adapters() first calls
// it with pAdapters == NULL to obtain NumAdapters, then again with an array of
// NumAdapters entries to receive the adapter descriptions.
struct dxcore_enumAdapters2
{
    unsigned int                   NumAdapters;
    struct dxcore_adapterInfo     *pAdapters;
};

// Query types passed in dxcore_queryAdapterInfo::Type. The numeric values are
// fixed by the entry point and must not change.
enum dxcore_kmtqueryAdapterInfoType
{
    DXCORE_QUERYDRIVERVERSION = 13, // output: unsigned int driver version (see dxcore_query_adapter_wddm_version)
    DXCORE_QUERYREGISTRY = 48,      // output: dxcore_queryregistry_info (see dxcore_query_adapter_driverstore_path)
};

// Registry value selector for dxcore_queryregistry_info::QueryType; only the
// driver store path is queried in this file.
enum dxcore_queryregistry_type {
    DXCORE_QUERYREGISTRY_DRIVERSTOREPATH = 2,
};

// Result codes reported in dxcore_queryregistry_info::Status.
enum dxcore_queryregistry_status {
    DXCORE_QUERYREGISTRY_STATUS_SUCCESS = 0,
    DXCORE_QUERYREGISTRY_STATUS_BUFFER_OVERFLOW = 1,
    DXCORE_QUERYREGISTRY_STATUS_FAIL = 2,
};

// In/out buffer for a DXCORE_QUERYREGISTRY query. The value is returned
// starting at the Output member and may run past the end of the struct: callers
// allocate sizeof(struct) + OutputValueSize bytes and read from &Output.
// NOTE(review): wchar_t is 2 bytes on Windows but typically 4 on Linux; confirm
// that libdxcore.so expects this layout on the non-Windows build.
struct dxcore_queryregistry_info {
    enum dxcore_queryregistry_type        QueryType;
    unsigned int                          QueryFlags;
    wchar_t                               ValueName[DXCORE_MAX_PATH];
    unsigned int                          ValueType;
    unsigned int                          PhysicalAdapterIndex;
    unsigned int                          OutputValueSize;  // size in bytes of the value at Output
    enum dxcore_queryregistry_status      Status;
    union {
        unsigned long long                    OutputQword;
        wchar_t                               Output;       // first character of a variable-length wide string
    };
};

// Parameter block for D3DKMTQueryAdapterInfo: which adapter, which query, and
// the caller's type-specific output buffer with its size in bytes.
struct dxcore_queryAdapterInfo
{
    unsigned int                           hAdapter;
    enum dxcore_kmtqueryAdapterInfoType    Type;
    void                                   *pPrivateDriverData;
    unsigned int                           PrivateDriverDataSize;
};

static int dxcore_query_adapter_info_helper(struct dxcore_lib* pLib,
                                            unsigned int hAdapter,
                                            enum dxcore_kmtqueryAdapterInfoType type,
                                            void* pPrivateDriverDate,
                                            unsigned int privateDriverDataSize)
{
    struct dxcore_queryAdapterInfo queryAdapterInfo = {};

    queryAdapterInfo.hAdapter = hAdapter;
    queryAdapterInfo.Type = type;
    queryAdapterInfo.pPrivateDriverData = pPrivateDriverDate;
    queryAdapterInfo.PrivateDriverDataSize = privateDriverDataSize;

    return pLib->pDxcoreQueryAdapterInfo(&queryAdapterInfo);
}

// Stores the adapter's driver version (a DXCORE_QUERYDRIVERVERSION query) in
// *version. Returns 0 on success, otherwise the entry point's non-zero result.
static int dxcore_query_adapter_wddm_version(struct dxcore_lib* pLib, unsigned int hAdapter, unsigned int* version)
{
    void* const pOut = version;
    const unsigned int outSize = (unsigned int)sizeof *version;

    return dxcore_query_adapter_info_helper(pLib, hAdapter, DXCORE_QUERYDRIVERVERSION, pOut, outSize);
}

static int dxcore_query_adapter_driverstore_path(struct dxcore_lib* pLib, unsigned int hAdapter, char** ppDriverStorePath)
{
    struct dxcore_queryregistry_info params = {};
    struct dxcore_queryregistry_info* pValue = NULL;
    wchar_t* pOutput;
    size_t outputSizeInBytes;
    size_t outputSize;

    // 1. Fetch output size
    params.QueryType = DXCORE_QUERYREGISTRY_DRIVERSTOREPATH;

    if (dxcore_query_adapter_info_helper(pLib,
                                         hAdapter,
                                         DXCORE_QUERYREGISTRY,
                                         (void*)&params,
                                         (unsigned int)sizeof(struct dxcore_queryregistry_info)))
    {
        return (-1);
    }

    if (params.OutputValueSize > DXCORE_MAX_PATH * sizeof(wchar_t)) {
        return (-1);
    }

    outputSizeInBytes = (size_t)params.OutputValueSize;
    outputSize = outputSizeInBytes / sizeof(wchar_t);

    // 2. Retrieve output
    pValue = (struct dxcore_queryregistry_info*)calloc(sizeof(struct dxcore_queryregistry_info) + outputSizeInBytes + sizeof(wchar_t), 1);
    if (!pValue) {
        return (-1);
    }

    pValue->QueryType = DXCORE_QUERYREGISTRY_DRIVERSTOREPATH;
    pValue->OutputValueSize = (unsigned int)outputSizeInBytes;

    if (dxcore_query_adapter_info_helper(pLib,
                                         hAdapter,
                                         DXCORE_QUERYREGISTRY,
                                         (void*)pValue,
                                         (unsigned int)(sizeof(struct dxcore_queryregistry_info) + outputSizeInBytes)))
    {
        free(pValue);
        return (-1);
    }
    pOutput = (wchar_t*)(&pValue->Output);

    // Make sure no matter what happened the wchar_t string is null terminated
    pOutput[outputSize] = L'\0';

    // Convert the output into a regular c string
    *ppDriverStorePath = (char*)calloc(outputSize + 1, sizeof(char));
    if (!*ppDriverStorePath) {
        free(pValue);
        return (-1);
    }
    wcstombs(*ppDriverStorePath, pOutput, outputSize);

    free(pValue);

    return 0;
}

// Builds the absolute path of the CUDA driver library (libcudaName) inside the
// adapter's driver store directory.
//
// On Windows the registry-provided path starts with a symbolic system root.
// Everything up to and including the first occurrence of sysrootName (e.g.
// "System32") is replaced with the directory returned by GetSystemDirectoryW().
// On other platforms the path is used unchanged. In both cases libcudaName,
// which begins with a path separator, is appended.
//
// Returns a calloc()-allocated, NUL-terminated string of at most
// DXCORE_MAX_PATH characters, which the caller must free(). Returns NULL on
// failure, including when the result would not fit; a truncated path is never
// returned.
static char* replaceSystemPath(const char* path)
{
    char* replacedPath = NULL;
    size_t used = 0;
    size_t tailLength = 0;

    if (!path || !libcudaName) {
        return NULL;
    }

    replacedPath = (char*)calloc(DXCORE_MAX_PATH + 1, sizeof(char));
    if (!replacedPath) {
        return NULL;
    }

#if defined(_WIN32)
    {
        wchar_t systemPath[DXCORE_MAX_PATH + 1] = {};
        const char* sysrootPath = NULL;

        // 0 means failure. A value >= the buffer size is the size the buffer
        // would have needed, i.e. the directory name did not fit.
        UINT systemPathLength = GetSystemDirectoryW(systemPath, DXCORE_MAX_PATH);
        if (systemPathLength == 0 || systemPathLength >= DXCORE_MAX_PATH ||
            wcstombs(replacedPath, systemPath, DXCORE_MAX_PATH) == (size_t)-1) {
            free(replacedPath);
            return NULL;
        }

        // Replace the \SystemRoot\<sysrootName> part of the registry-obtained
        // path with the actual system directory from above. Without a match
        // there is nothing valid to splice onto.
        sysrootPath = sysrootName ? strstr(path, sysrootName) : NULL;
        if (!sysrootPath) {
            free(replacedPath);
            return NULL;
        }
        path = sysrootPath + sysrootName_length;
    }
#endif

    // replacedPath holds at most DXCORE_MAX_PATH characters here (its last byte
    // is still the calloc() zero), so the subtractions below cannot wrap.
    used = strlen(replacedPath);
    tailLength = strlen(path);
    if (tailLength > DXCORE_MAX_PATH - used ||
        libcudaName_length > DXCORE_MAX_PATH - used - tailLength) {
        free(replacedPath);
        return NULL;
    }

    memcpy(replacedPath + used, path, tailLength);
    used += tailLength;

    // Append the driver library name. The terminator is already in place from calloc().
    memcpy(replacedPath + used, libcudaName, libcudaName_length);

    return replacedPath;
}

// Decides whether one adapter can provide the CUDA driver library and, if so,
// copies the library's full path into libPath.
//
// The adapter is accepted when three conditions hold:
//   - its driver version (DXCORE_QUERYDRIVERVERSION) is at least 2500, the
//     KMT_DRIVERVERSION value of WDDM 2.5;
//   - its driver store path can be read from the registry;
//   - the library file built from that path exists.
// Returns 0 on acceptance, after writing to libPath a NUL-terminated path that
// occupies at most DXCORE_MAX_PATH bytes including the terminator. Returns 1
// otherwise and leaves libPath untouched.
static int dxcore_check_adapter(struct dxcore_lib *pLib, char *libPath, struct dxcore_adapterInfo *pAdapterInfo)
{
    // Oldest driver version accepted.
    static const unsigned int minWddmVersion = 2500;

    unsigned int wddmVersion = 0;
    char* driverStorePath = NULL;
    char* replacedPath = NULL;
    size_t pathLength = 0;
    bool fileExists = false;

    if (dxcore_query_adapter_wddm_version(pLib, pAdapterInfo->hAdapter, &wddmVersion)) {
        return 1;
    }

    if (wddmVersion < minWddmVersion) {
        return 1;
    }

    if (dxcore_query_adapter_driverstore_path(pLib, pAdapterInfo->hAdapter, &driverStorePath)) {
        return 1;
    }

    // Replace with valid path. The registry path is not needed afterwards.
    replacedPath = replaceSystemPath(driverStorePath);
    free(driverStorePath);
    if (!replacedPath) {
        return 1;
    }

    // Does file exist?
#if defined(_WIN32)
    // Call the ANSI entry point explicitly: replacedPath is a narrow string,
    // while the GetFileAttributes macro expands to GetFileAttributesW under UNICODE.
    fileExists = (GetFileAttributesA(replacedPath) != INVALID_FILE_ATTRIBUTES);
#else
    fileExists = (access(replacedPath, F_OK) == 0);
#endif

    // libPath is written with at most DXCORE_MAX_PATH bytes. A path whose
    // terminator would not fit is rejected rather than copied unterminated.
    pathLength = strlen(replacedPath);
    if (!fileExists || pathLength >= DXCORE_MAX_PATH) {
        free(replacedPath);
        return 1;
    }

    memcpy(libPath, replacedPath, pathLength + 1);
    free(replacedPath);

    return 0;
}

static int dxcore_enum_adapters(struct dxcore_lib *pLib, char *libPath)
{
    struct dxcore_enumAdapters2 params = {0};
    unsigned int adapterIndex = 0;

    if (pLib->pDxcoreEnumAdapters2(&params)) {
        return 1;
    }
    params.pAdapters = (dxcore_adapterInfo*)calloc(params.NumAdapters, sizeof(struct dxcore_adapterInfo));
    if (pLib->pDxcoreEnumAdapters2(&params)) {
        free(params.pAdapters);
        return 1;
    }

    for (adapterIndex = 0; adapterIndex < params.NumAdapters; adapterIndex++) {
        if (!dxcore_check_adapter(pLib, libPath, &params.pAdapters[adapterIndex])) {
            free(params.pAdapters);
            return 0;
        }
    }

    free(params.pAdapters);
    return 1;
}

int getCUDALibraryPath(char *libPath, bool isBit64)
{
    struct dxcore_lib lib = {0};

    if (!libPath) {
        return 1;
    }

    // Configure paths based on app's bit configuration
#if defined(_WIN32)
    if (isBit64) {
        sysrootName_length = sysrootName64_length;
        sysrootName = sysrootName64;
        libcudaName_length = libcudaName64_length;
        libcudaName = libcudaName64;
    }
    else {
        sysrootName_length = sysrootNameX86_length;
        sysrootName = sysrootNameX86;
        libcudaName_length = libcudaNameX86_length;
        libcudaName = libcudaNameX86;
    }
#else
    libcudaName_length = libcudaNameLinux_length;
    libcudaName = libcudaNameLinux;
#endif

#if defined(_WIN32)
    lib.hDxcoreLib = LoadLibraryExW(L"gdi32.dll", NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);
#else
    lib.hDxcoreLib = dlopen("libdxcore.so", RTLD_LAZY);
#endif
    if (!lib.hDxcoreLib) {
        return 1;
    }

    lib.pDxcoreEnumAdapters2 = (pfnDxcoreEnumAdapters2)_getAddr(lib.hDxcoreLib, "D3DKMTEnumAdapters2");
    if (!lib.pDxcoreEnumAdapters2) {
        return 1;
    }
    lib.pDxcoreQueryAdapterInfo = (pfnDxcoreQueryAdapterInfo)_getAddr(lib.hDxcoreLib, "D3DKMTQueryAdapterInfo");
    if (!lib.pDxcoreQueryAdapterInfo) {
        return 1;
    }

    if (dxcore_enum_adapters(&lib, libPath)) {
        return 1;
    }
    return 0;
}