| | import ctypes |
| | import numpy as np |
| | import os |
| |
|
class GPUNanoF1:
    """ctypes wrapper around the compiled ``f1_kernel.so`` shared library.

    The library is loaded once in ``__init__``. ``compute`` converts its inputs
    to float32 buffers and passes raw pointers to ``launch_f1_kernel``.
    """

    def __init__(self):
        """Load ``f1_kernel.so`` from the directory containing this module.

        Raises:
            FileNotFoundError: if the shared library has not been built yet.
                This is a subclass of ``Exception``, so existing
                ``except Exception`` handlers still catch it.
            OSError: if ``ctypes`` cannot load the library.
        """
        # Use an absolute directory. dlopen() treats a name as a filesystem path
        # only when it contains a slash. If __file__ were relative with no
        # directory part, a bare 'f1_kernel.so' would be looked up on the system
        # library path instead of next to this module, even though
        # os.path.exists() below would succeed.
        here = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(here, 'f1_kernel.so')
        if not os.path.exists(path):
            raise FileNotFoundError(
                "Le Kernel F-1 n'est pas compilé. Lancez 'sh compile.sh' d'abord."
            )

        self.lib = ctypes.CDLL(path)
        # Declare the C signature (float*, float*, float*, int) so ctypes
        # type-checks the arguments passed by compute().
        # NOTE(review): restype is left at the ctypes default (int) and the
        # return value is ignored. Confirm against the kernel's declaration.
        self.lib.launch_f1_kernel.argtypes = [
            ctypes.POINTER(ctypes.c_float),
            ctypes.POINTER(ctypes.c_float),
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_int,
        ]

    def compute(self, A, B):
        """Run ``launch_f1_kernel`` on ``A`` and ``B`` and return its output.

        The kernel receives a single integer ``size`` (``A.shape[0]``), and the
        output buffer is allocated as ``(size, size)``. The kernel can therefore
        only infer buffer extents from ``size``, so both inputs must be
        ``(size, size)``. Anything smaller would let it read out of bounds.
        This assumes the kernel treats all three buffers as row-major
        ``size x size`` matrices (TODO: confirm against the kernel source).

        Args:
            A: Square 2-D array-like. It is copied into a C-contiguous float32
                array.
            B: Array-like with the same shape as ``A``, copied likewise.

        Returns:
            A new ``(size, size)`` float32 array written by the kernel.

        Raises:
            ValueError: if ``A`` is not a square 2-D array, or ``B`` does not
                have the same shape as ``A``.
        """
        # Force a C-ordered (row-major) float32 copy. ``astype`` keeps the
        # source layout, so a transposed or Fortran-ordered input would reach
        # the kernel as column-major data through the raw pointer. The copy
        # (np.array copies by default) also preserves the original behaviour
        # of never handing the caller's own buffers to native code.
        A = np.array(A, dtype=np.float32, order='C')
        B = np.array(B, dtype=np.float32, order='C')
        if A.ndim != 2 or A.shape[0] != A.shape[1]:
            raise ValueError(f"A must be a square 2-D array, got shape {A.shape}")
        if B.shape != A.shape:
            raise ValueError(
                f"B must have the same shape as A {A.shape}, got {B.shape}"
            )

        size = A.shape[0]
        C = np.zeros((size, size), dtype=np.float32)
        if size == 0:
            # Nothing to compute, so do not launch the kernel on an empty problem.
            return C

        self.lib.launch_f1_kernel(
            A.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            B.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            C.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            size,
        )
        return C