| import torch |
| from model import GLiNER |
|
|
|
|
def save_model(current_model, path):
    """Write a model's weights and config to ``path`` as one checkpoint.

    The saved object is a dict with two entries:
    ``"model_weights"`` holds ``current_model.state_dict()`` and
    ``"config"`` holds ``current_model.config``. It is serialized with
    ``torch.save``.
    """
    cfg = current_model.config
    weights = current_model.state_dict()
    torch.save({"model_weights": weights, "config": cfg}, path)
|
|
|
|
def load_model(path, model_name=None, device=None):
    """Rebuild a ``GLiNER`` model from a checkpoint written by ``save_model``.

    The checkpoint is expected to be a dict with a ``"config"`` entry (used
    to construct the model) and a ``"model_weights"`` entry (a state dict).

    Args:
        path: Checkpoint location, passed straight to ``torch.load``.
        model_name: If not ``None``, assigned to ``config.model_name`` before
            the model is constructed.
        device: If not ``None``, the model is moved there with ``.to(device)``.
            Otherwise it is returned as constructed; the stored tensors are
            mapped to CPU while loading.

    Returns:
        The ``GLiNER`` instance with the saved weights loaded.
    """
    import inspect

    load_kwargs = {"map_location": torch.device("cpu")}
    # The checkpoint pickles an arbitrary ``config`` object next to the
    # tensors. Newer PyTorch releases default ``torch.load`` to
    # ``weights_only=True``, which refuses to unpickle such objects, so opt
    # out explicitly. Older versions lack the parameter entirely, so only
    # pass it when the installed torch accepts it.
    # NOTE(security): full unpickling can execute arbitrary code; only load
    # checkpoints that come from a trusted source.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    dict_load = torch.load(path, **load_kwargs)
    config = dict_load["config"]

    if model_name is not None:
        config.model_name = model_name

    loaded_model = GLiNER(config)
    loaded_model.load_state_dict(dict_load["model_weights"])
    return loaded_model.to(device) if device is not None else loaded_model
|
|