Update model.py
Browse files
model.py
CHANGED
|
@@ -52,6 +52,7 @@ class ViTRecognitionModel(nn.Module):
|
|
| 52 |
|
| 53 |
def load_model(model_path, device='cpu'):
|
| 54 |
model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
|
|
|
|
| 55 |
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
| 56 |
model.to(device)
|
| 57 |
model.eval()
|
|
|
|
| 52 |
|
| 53 |
def load_model(model_path, device='cpu'):
|
| 54 |
model = ViTRecognitionModel(vocab_size=vocab_size, hidden_dim=768, max_length=20)
|
| 55 |
+
# Set weights_only=True to address the FutureWarning
|
| 56 |
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
| 57 |
model.to(device)
|
| 58 |
model.eval()
|