AhmedRabie01 commited on
Commit
2ef0ae4
·
verified ·
1 Parent(s): ee67128

Create loader.py

Browse files
Files changed (1) hide show
  1. Sentiment/ml/model/loader.py +173 -0
Sentiment/ml/model/loader.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import yaml
4
+ import torch
5
+ from transformers import AutoTokenizer
6
+ from Sentiment.ml.model.multitask_bert import MultiTaskBert
7
+
8
# Directory (relative to the working directory) expected to contain
# model.pt, meta.yaml and a tokenizer/ dump produced by training.
MODEL_DIR = "saved_models"

# Process-wide singleton cache populated by load_model(); subsequent calls
# return these same objects instead of reloading from disk.
_CACHE = {"model": None, "tokenizer": None, "meta": None, "device": None}
10
+
11
# ---------------------------------------------------------------------------
# torch.compile is not supported on Windows, and ModernBERT triggers it at
# import time. Swap in a pass-through replacement early so importing the
# model code does not fail at runtime.
# ---------------------------------------------------------------------------
if sys.platform.startswith("win") and hasattr(torch, "compile"):

    def _noop_compile(fn=None, *args, **kwargs):
        """Pass-through stand-in for both @torch.compile and @torch.compile(...)."""
        if fn is not None:
            # Bare decorator form: return the function unchanged.
            return fn
        # Parenthesized form: return a decorator that is the identity.
        return lambda f: f

    torch.compile = _noop_compile
24
+
25
+
26
def _remap_legacy_keys(state_dict):
    """Translate old checkpoint key names to the current module layout.

    Older checkpoints used the prefix "bert." for the encoder and head names
    without the "_head" suffix; current models use "encoder." and "*_head".
    Returns the input unchanged when no legacy prefix is present.
    """
    if not any(key.startswith("bert.") for key in state_dict):
        return state_dict
    remapped = {}
    for key, val in state_dict.items():
        if key.startswith("bert."):
            remapped["encoder." + key[len("bert."):]] = val
        elif key.startswith("sentiment"):
            remapped["sentiment_head" + key[len("sentiment"):]] = val
        elif key.startswith("intent"):
            remapped["intent_head" + key[len("intent"):]] = val
        elif key.startswith("topic"):
            remapped["topic_head" + key[len("topic"):]] = val
        else:
            remapped[key] = val
    return remapped


def _drop_incompatible(model, state_dict):
    """Filter out weights whose shape doesn't match the current architecture.

    Returns (filtered_state_dict, dropped_keys). Keys absent from the model
    are kept (load_state_dict with strict=False ignores them).
    """
    current_state = model.state_dict()
    filtered_state = {}
    dropped = []
    for key, val in state_dict.items():
        if key in current_state and current_state[key].shape != val.shape:
            dropped.append(key)
            continue
        filtered_state[key] = val
    return filtered_state, dropped


def load_model():
    """Load (and cache) the multi-task model, tokenizer, meta and device.

    Returns:
        Tuple of (model, tokenizer, meta, device). The result is cached in
        the module-level _CACHE, so subsequent calls are cheap and return
        the same objects.

    Raises:
        RuntimeError: if model.pt, meta.yaml or the tokenizer folder is
            missing under MODEL_DIR.
    """
    if _CACHE["model"] is not None:
        return _CACHE["model"], _CACHE["tokenizer"], _CACHE["meta"], _CACHE["device"]

    # Device selection: env vars can force CPU, otherwise prefer CUDA.
    force_cpu = os.getenv("SENTIMENT_FORCE_CPU", "").lower() in {"1", "true", "yes"}
    requested = os.getenv("SENTIMENT_DEVICE", "").lower()
    if force_cpu or requested == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = os.path.join(MODEL_DIR, "model.pt")
    meta_path = os.path.join(MODEL_DIR, "meta.yaml")
    tokenizer_dir = os.path.join(MODEL_DIR, "tokenizer")

    if not os.path.exists(model_path):
        raise RuntimeError("model.pt not found - train the model first")

    if not os.path.exists(meta_path):
        raise RuntimeError("meta.yaml not found")

    if not os.path.isdir(tokenizer_dir):
        raise RuntimeError("tokenizer folder not found")

    with open(meta_path, "r", encoding="utf-8") as f:
        meta = yaml.safe_load(f)
    meta = _normalize_meta(meta)

    # Build on CPU first to avoid GPU OOM spikes during state_dict load
    model = MultiTaskBert(
        meta["model_name"],
        len(meta["tasks"]["sentiment"]["labels"]),
        len(meta["tasks"]["intent"]["labels"]),
        len(meta["tasks"]["topic"]["labels"]),
        init_from_pretrained=False,  # VERY IMPORTANT
    )

    # NOTE(review): torch.load unpickles arbitrary objects by default; only
    # load checkpoints from trusted sources (consider weights_only=True if
    # the checkpoint format allows it — confirm before changing).
    state_dict = torch.load(model_path, map_location="cpu")

    # Backward-compat remap, then drop shape-incompatible weights.
    state_dict = _remap_legacy_keys(state_dict)
    filtered_state, dropped = _drop_incompatible(model, state_dict)

    if dropped:
        # If heads were dropped, the model will reinit them randomly.
        print(
            f"[load_model] dropped incompatible keys: {', '.join(dropped[:5])}"
            f"{' ...' if len(dropped) > 5 else ''}"
        )

    model.load_state_dict(filtered_state, strict=False)

    if device.type == "cuda":
        try:
            model.to(device)
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            print("[load_model] CUDA OOM; falling back to CPU")
            device = torch.device("cpu")
            model.to(device)
            # BUGFIX: the old `hasattr(torch, "cuda")` guard was always true
            # (the torch.cuda module always exists). We only reach this line
            # because CUDA raised OOM, so freeing the cache is safe here.
            torch.cuda.empty_cache()

    model.eval()

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    except Exception:
        # Fallback to hub tokenizer (local dump is incomplete)
        tokenizer = AutoTokenizer.from_pretrained(
            meta["model_name"],
            trust_remote_code=True,
        )

    _CACHE.update({"model": model, "tokenizer": tokenizer, "meta": meta, "device": device})
    return _CACHE["model"], _CACHE["tokenizer"], _CACHE["meta"], _CACHE["device"]
126
+
127
+
128
+ def _labels_to_mapping(labels):
129
+ if isinstance(labels, dict):
130
+ # Ensure keys are ints when possible
131
+ mapping = {}
132
+ for k, v in labels.items():
133
+ try:
134
+ mapping[int(k)] = v
135
+ except Exception:
136
+ mapping[k] = v
137
+ return mapping
138
+ if isinstance(labels, list):
139
+ return {i: v for i, v in enumerate(labels)}
140
+ raise ValueError("labels must be list or dict")
141
+
142
+
143
def _normalize_meta(meta):
    """
    Normalize meta.yaml content into the canonical shape.

    Two input formats are accepted:
      1) tasks: { sentiment: {labels: {0:..}}, ... }, max_len
      2) labels: { sentiment: [..], ... }, max_length

    The result always has meta["tasks"][task]["labels"] as an index->name
    mapping and a meta["max_len"] entry (default 256).
    """
    if meta is None:
        raise ValueError("meta.yaml is empty")

    if "tasks" in meta:
        # Per-task config may nest labels under "labels" or be the labels
        # themselves; cfg.get(...) handles both.
        meta["tasks"] = {
            task: {"labels": _labels_to_mapping(cfg.get("labels", cfg))}
            for task, cfg in meta["tasks"].items()
        }
    elif "labels" in meta:
        meta["tasks"] = {
            task: {"labels": _labels_to_mapping(labels)}
            for task, labels in meta["labels"].items()
        }
    else:
        raise ValueError("meta.yaml must contain 'tasks' or 'labels'")

    # Accept "max_length" as an alias for "max_len"; default to 256.
    meta.setdefault("max_len", meta.get("max_length", 256))

    return meta