Spaces:
Paused
Paused
// Process-wide MPSLibraryManager instance. Stays null until getInstance() is
// called for the first time.
static std::unique_ptr<MPSLibraryManager> mps_library_manager = nullptr;

// Returns the shared manager, constructing it on the first call.
// The returned pointer is owned by the file-level unique_ptr above, so callers
// must not delete it. No locking is performed around the lazy construction.
MPSLibraryManager* MPSLibraryManager::getInstance() {
  if (mps_library_manager == nullptr) {
    mps_library_manager.reset(new MPSLibraryManager());
  }
  return mps_library_manager.get();
}
// Nothing to tear down by hand: the members clean up after themselves.
MPSLibraryManager::~MPSLibraryManager() = default;
// Starts with default-initialized members; libraries are added on demand.
MPSLibraryManager::MPSLibraryManager() = default;
// Reports whether a library is already registered under `name`.
bool MPSLibraryManager::hasLibrary(const std::string& name) {
  return _library_map.count(name) != 0;
}
// Returns the library cached under `library_url`. On the first request for a
// given URL the library is loaded through MPSLibrary::createFromUrl and stored
// in the map, which keeps ownership of the returned object.
MPSLibrary* MPSLibraryManager::getLibrary(const std::string& library_url) {
  auto cached = _library_map.find(library_url);
  if (cached == _library_map.end()) {
    std::unique_ptr<MPSLibrary> loaded(MPSLibrary::createFromUrl(library_url));
    cached = _library_map.emplace(library_url, std::move(loaded)).first;
  }
  return cached->second.get();
}
// Compiles `source` into a new Metal library and registers it under `name`.
//
// Returns the newly created library, which stays owned by the manager. If a
// library is already registered under `name`, logs a message and returns
// nullptr without touching the existing entry.
MPSLibrary* MPSLibraryManager::createLibraryFromSouce(const std::string& name,
                                                      const std::string& source) {
  if (_library_map.find(name) != _library_map.end()) {
    // The NSString is only needed for logging, so build it on this path only
    // and drain it right away. UTF-8 is requested explicitly: the deprecated
    // +stringWithCString: uses the default C-string encoding instead.
    @autoreleasepool {
      NSString* ns_name = [NSString stringWithUTF8String:name.c_str()];
      NSLog(@"Library %@ already exist.", ns_name);
    }
    return nullptr;
  }
  auto inserted = _library_map.emplace(
      name, std::unique_ptr<MPSLibrary>(MPSLibrary::createFromSource(source)));
  return inserted.first->second.get();
}
// Loads a precompiled Metal library from the file at `library_url`.
//
// Returns a heap-allocated MPSLibrary that the caller owns. The wrapped
// library object comes from a `new...` call, so it is held at +1 and released
// in ~MPSLibrary. On any failure the error is logged and the process exits
// with status 1.
MPSLibrary* MPSLibrary::createFromUrl(const std::string& library_url) {
  MPSLibrary* library = new MPSLibrary();
  @autoreleasepool {
    // Decode the path as UTF-8 explicitly; the deprecated +stringWithCString:
    // uses the default C-string encoding and can mangle non-ASCII paths.
    NSString* path_str = [NSString stringWithUTF8String:library_url.c_str()];
    if (path_str == nil) {
      // +stringWithUTF8String: yields nil for malformed input, and
      // +fileURLWithPath: must not be handed nil.
      NSLog(@"Library path is not valid UTF-8.");
      exit(1);
    }
    NSURL* metal_url = [NSURL fileURLWithPath:path_str];

    NSError* error = nil;
    library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithURL:metal_url
                                                                                 error:&error];
    if (library->_library == nil) {
      NSLog(@"Failed to find library, error %@.", error);
      exit(1);
    }
  }
  return library;
}
// Compiles Metal shader source text into a library.
//
// Returns a heap-allocated MPSLibrary that the caller owns. The wrapped
// library object comes from a `new...` call, so it is held at +1 and released
// in ~MPSLibrary. If no library is produced, the error is logged and the
// process exits with status 1.
MPSLibrary* MPSLibrary::createFromSource(const std::string& sources) {
  MPSLibrary* library = new MPSLibrary();
  @autoreleasepool {
    // Decode the source as UTF-8 explicitly; the deprecated
    // +stringWithCString: uses the default C-string encoding instead.
    NSString* code_str = [NSString stringWithUTF8String:sources.c_str()];

    NSError* error = nil;
    library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithSource:code_str
                                                                                  options:nil
                                                                                    error:&error];
    if (library->_library == nil) {
      NSLog(@"Failed to find library, error %@.", error);
      exit(1);
    }
  }
  return library;
}
// Releases every Metal object this wrapper owns.
MPSLibrary::~MPSLibrary() {
  // The cached pipeline states were obtained from a `new...` call in
  // getComputePipelineState, so each one is held at +1 and must be balanced
  // here. Messaging a nil entry is a harmless no-op.
  for (auto& entry : _pso_map) {
    [entry.second release];
  }
  _pso_map.clear();

  [_library release];
  _library = nil;
}
// Returns the compute pipeline state for the kernel `function_name` in this
// library, building and caching it on first use.
//
// The returned object stays owned by the cache (`_pso_map`); callers must not
// release it. If the function cannot be found, or the pipeline state cannot be
// created, the failure is logged and the process exits with status 1, matching
// the library-loading routines in this file.
MTLComputePipelineState_t MPSLibrary::getComputePipelineState(const std::string& function_name) {
  auto cached = _pso_map.find(function_name);
  if (cached != _pso_map.end()) {
    return cached->second;
  }

  MTLComputePipelineState_t pso = nil;
  @autoreleasepool {
    // Decode the name as UTF-8 explicitly; the deprecated +stringWithCString:
    // uses the default C-string encoding instead.
    NSString* function_name_str = [NSString stringWithUTF8String:function_name.c_str()];

    // Look up the kernel entry point. This call has no NSError out-parameter,
    // so the only useful diagnostic is the name that was requested.
    id<MTLFunction> func = [_library newFunctionWithName:function_name_str];
    if (func == nil) {
      NSLog(@"Failed to find function %@ in library.", function_name_str);
      exit(1);
    }

    // Build the pipeline state from the function.
    NSError* error = nil;
    pso = [at::mps::MPSDevice::getInstance()->device() newComputePipelineStateWithFunction:func
                                                                                     error:&error];
    // `func` came from a `new...` call (+1) and is only needed to build the
    // pipeline state, so balance it here rather than leaking it.
    [func release];

    if (pso == nil) {
      // Never cache a nil pipeline state: report the creation error instead.
      NSLog(@"Failed to created pipeline state object, error %@.", error);
      exit(1);
    }
    _pso_map.emplace(function_name, pso);
  }
  return pso;
}