Spaces:
Paused
Paused
implemented model saving
Browse files
model.py
CHANGED
|
@@ -3,6 +3,7 @@ import logging
|
|
| 3 |
from typing import List, Dict, Optional
|
| 4 |
from pathlib import Path
|
| 5 |
import json
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
@@ -254,6 +255,21 @@ class T5Model(LabelStudioMLBase):
|
|
| 254 |
|
| 255 |
logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")
|
| 256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
# Switch back to eval mode
|
| 258 |
model.eval()
|
| 259 |
|
|
|
|
| 3 |
from typing import List, Dict, Optional
|
| 4 |
from pathlib import Path
|
| 5 |
import json
|
| 6 |
+
from datetime import datetime
|
| 7 |
|
| 8 |
import torch
|
| 9 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
| 255 |
|
| 256 |
logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")
|
| 257 |
|
| 258 |
+
# Save the model
|
| 259 |
+
try:
|
| 260 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 261 |
+
model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
|
| 262 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 263 |
+
|
| 264 |
+
save_path = model_dir / f"model_{timestamp}"
|
| 265 |
+
logger.info(f"Saving model to {save_path}")
|
| 266 |
+
model.save_pretrained(save_path)
|
| 267 |
+
logger.info(f"Model successfully saved to {save_path}")
|
| 268 |
+
|
| 269 |
+
except Exception as e:
|
| 270 |
+
logger.error(f"Failed to save model: {str(e)}")
|
| 271 |
+
raise
|
| 272 |
+
|
| 273 |
# Switch back to eval mode
|
| 274 |
model.eval()
|
| 275 |
|