Nihal2000 committed on
Commit
3589789
·
verified ·
1 Parent(s): 90e868f

Update src/model_manager.py

Browse files
Files changed (1) hide show
  1. src/model_manager.py +32 -12
src/model_manager.py CHANGED
@@ -31,52 +31,71 @@ class AutomotiveSLMConfig:
31
 
32
  class ModelManager:
33
  def __init__(self, models_path: str):
 
 
34
  self.models_path = models_path
35
  self.cache = {}
36
  os.makedirs(self.models_path, exist_ok=True)
37
 
38
  def get_available_models(self) -> List[str]:
 
 
39
  files = []
40
  for f in os.listdir(self.models_path):
 
 
 
41
  ext = os.path.splitext(f)[1].lower()
42
  if ext in [".pt", ".pth", ".onnx"]:
43
  files.append(f)
44
  return sorted(files)
45
 
46
  def _load_config(self, checkpoint_path: str) -> AutomotiveSLMConfig:
47
- # Prefer assets/config.json if present
48
- assets_root = os.path.dirname(self.models_path)
 
 
49
  cfg_path = os.path.join(assets_root, "config.json")
50
- if os.path.exists(cfg_path):
51
  with open(cfg_path, "r") as f:
52
  cfg = json.load(f)
53
  return AutomotiveSLMConfig(**cfg)
54
- # else try checkpoint
55
- ckpt = torch.load(checkpoint_path, map_location="cpu")
56
- if isinstance(ckpt, dict) and "config" in ckpt:
57
- return AutomotiveSLMConfig(**ckpt["config"])
 
 
 
58
  return AutomotiveSLMConfig()
59
 
60
  def load_model(self, model_filename: str) -> Tuple[Any, Any, AutomotiveSLMConfig]:
 
 
 
61
  if model_filename in self.cache:
62
  return self.cache[model_filename]
 
63
  model_path = os.path.join(self.models_path, model_filename)
 
 
64
 
65
- # tokenizer (GPT-2 per your training)
66
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
67
  if tokenizer.pad_token is None:
68
  tokenizer.pad_token = tokenizer.eos_token
69
 
70
  ext = os.path.splitext(model_filename)[1].lower()
 
 
71
  if ext in [".pt", ".pth"]:
72
- config = self._load_config(model_path)
73
  from src.model_architecture import AutomotiveSLM
74
- checkpoint = torch.load(model_path, map_location="cpu")
75
  model = AutomotiveSLM(config)
76
- model.load_state_dict(checkpoint["model_state_dict"])
 
77
  model.eval()
78
  elif ext == ".onnx":
79
- config = self._load_config(model_path)
80
  providers = ["CPUExecutionProvider"]
81
  so = ort.SessionOptions()
82
  so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -86,3 +105,4 @@ class ModelManager:
86
 
87
  self.cache[model_filename] = (model, tokenizer, config)
88
  return model, tokenizer, config
 
 
31
 
32
  class ModelManager:
33
  def __init__(self, models_path: str):
34
+ if not isinstance(models_path, str) or not models_path:
35
+ raise ValueError(f"models_path must be a non-empty string, got: {models_path!r}")
36
  self.models_path = models_path
37
  self.cache = {}
38
  os.makedirs(self.models_path, exist_ok=True)
39
 
40
  def get_available_models(self) -> List[str]:
41
+ if not os.path.isdir(self.models_path):
42
+ return []
43
  files = []
44
  for f in os.listdir(self.models_path):
45
+ path = os.path.join(self.models_path, f)
46
+ if not os.path.isfile(path):
47
+ continue
48
  ext = os.path.splitext(f)[1].lower()
49
  if ext in [".pt", ".pth", ".onnx"]:
50
  files.append(f)
51
  return sorted(files)
52
 
53
  def _load_config(self, checkpoint_path: str) -> AutomotiveSLMConfig:
54
+ # Derive assets root safely
55
+ if not isinstance(checkpoint_path, str):
56
+ raise ValueError(f"checkpoint_path must be a string, got: {checkpoint_path!r}")
57
+ assets_root = os.path.dirname(self.models_path) # assets
58
  cfg_path = os.path.join(assets_root, "config.json")
59
+ if isinstance(cfg_path, str) and os.path.exists(cfg_path):
60
  with open(cfg_path, "r") as f:
61
  cfg = json.load(f)
62
  return AutomotiveSLMConfig(**cfg)
63
+ # Fall back to reading from checkpoint if it’s a torch file
64
+ ext = os.path.splitext(checkpoint_path)[1].lower()
65
+ if ext in [".pt", ".pth"] and os.path.exists(checkpoint_path):
66
+ ckpt = torch.load(checkpoint_path, map_location="cpu")
67
+ if isinstance(ckpt, dict) and "config" in ckpt:
68
+ return AutomotiveSLMConfig(**ckpt["config"])
69
+ # Final fallback
70
  return AutomotiveSLMConfig()
71
 
72
  def load_model(self, model_filename: str) -> Tuple[Any, Any, AutomotiveSLMConfig]:
73
+ if not isinstance(model_filename, str) or not model_filename:
74
+ raise ValueError(f"model_filename must be a non-empty string, got: {model_filename!r}")
75
+
76
  if model_filename in self.cache:
77
  return self.cache[model_filename]
78
+
79
  model_path = os.path.join(self.models_path, model_filename)
80
+ if not os.path.isfile(model_path):
81
+ raise FileNotFoundError(f"Model file not found: {model_path}")
82
 
83
+ # tokenizer
84
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
85
  if tokenizer.pad_token is None:
86
  tokenizer.pad_token = tokenizer.eos_token
87
 
88
  ext = os.path.splitext(model_filename)[1].lower()
89
+ config = self._load_config(model_path)
90
+
91
  if ext in [".pt", ".pth"]:
 
92
  from src.model_architecture import AutomotiveSLM
93
+ checkpoint = torch.load(model_path, map_location="cpu")
94
  model = AutomotiveSLM(config)
95
+ state = checkpoint.get("model_state_dict", checkpoint)
96
+ model.load_state_dict(state, strict=True)
97
  model.eval()
98
  elif ext == ".onnx":
 
99
  providers = ["CPUExecutionProvider"]
100
  so = ort.SessionOptions()
101
  so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
 
105
 
106
  self.cache[model_filename] = (model, tokenizer, config)
107
  return model, tokenizer, config
108
+