b2u commited on
Commit
ed47222
·
1 Parent(s): 5b6ee0c

implemented model saving

Browse files
Files changed (1) hide show
  1. model.py +16 -0
model.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
  from typing import List, Dict, Optional
4
  from pathlib import Path
5
  import json
 
6
 
7
  import torch
8
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@@ -254,6 +255,21 @@ class T5Model(LabelStudioMLBase):
254
 
255
  logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  # Switch back to eval mode
258
  model.eval()
259
 
 
3
  from typing import List, Dict, Optional
4
  from pathlib import Path
5
  import json
6
+ from datetime import datetime
7
 
8
  import torch
9
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
255
 
256
  logger.info(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {loss.item():.4f}")
257
 
258
+ # Save the model
259
+ try:
260
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
261
+ model_dir = Path(os.getenv('MODEL_DIR', '/data/models'))
262
+ model_dir.mkdir(parents=True, exist_ok=True)
263
+
264
+ save_path = model_dir / f"model_{timestamp}"
265
+ logger.info(f"Saving model to {save_path}")
266
+ model.save_pretrained(save_path)
267
+ logger.info(f"Model successfully saved to {save_path}")
268
+
269
+ except Exception as e:
270
+ logger.error(f"Failed to save model: {str(e)}")
271
+ raise
272
+
273
  # Switch back to eval mode
274
  model.eval()
275