import os
import sys

import torch
import yaml
from transformers import AutoTokenizer

from Sentiment.ml.model.multitask_bert import MultiTaskBert

MODEL_DIR = "saved_models"
_CACHE = {"model": None, "tokenizer": None, "meta": None, "device": None}

# ---------------------------------------------------------------------------
# ModernBERT uses torch.compile during import; torch.compile is not supported
# on Windows. Patch it to a no-op early to avoid runtime import failures.
# ---------------------------------------------------------------------------
if sys.platform.startswith("win") and hasattr(torch, "compile"):

    def _noop_compile(fn=None, *args, **kwargs):
        # Handles both @torch.compile and @torch.compile(...)
        if fn is None:
            def decorator(f):
                return f

            return decorator
        return fn

    torch.compile = _noop_compile


def load_model():
    if _CACHE["model"] is not None:
        return _CACHE["model"], _CACHE["tokenizer"], _CACHE["meta"], _CACHE["device"]

    # Device selection: SENTIMENT_FORCE_CPU / SENTIMENT_DEVICE env vars win;
    # otherwise prefer CUDA when available.
    force_cpu = os.getenv("SENTIMENT_FORCE_CPU", "").lower() in {"1", "true", "yes"}
    requested = os.getenv("SENTIMENT_DEVICE", "").lower()
    if force_cpu or requested == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = os.path.join(MODEL_DIR, "model.pt")
    meta_path = os.path.join(MODEL_DIR, "meta.yaml")
    tokenizer_dir = os.path.join(MODEL_DIR, "tokenizer")

    if not os.path.exists(model_path):
        raise RuntimeError("model.pt not found - train the model first")
    if not os.path.exists(meta_path):
        raise RuntimeError("meta.yaml not found")
    if not os.path.isdir(tokenizer_dir):
        raise RuntimeError("tokenizer folder not found")

    with open(meta_path, "r", encoding="utf-8") as f:
        meta = yaml.safe_load(f)
    meta = _normalize_meta(meta)

    # Build on CPU first to avoid GPU OOM spikes during the state_dict load.
    model = MultiTaskBert(
        meta["model_name"],
        len(meta["tasks"]["sentiment"]["labels"]),
        len(meta["tasks"]["intent"]["labels"]),
        len(meta["tasks"]["topic"]["labels"]),
        init_from_pretrained=False,  # VERY IMPORTANT: weights come from model.pt
    )

    state_dict = torch.load(model_path, map_location="cpu")

    # Backward-compat: older checkpoints used the prefix "bert." and head
    # names without the "_head" suffix.
    if any(key.startswith("bert.") for key in state_dict):
        remapped = {}
        for key, val in state_dict.items():
            if key.startswith("bert."):
                remapped["encoder." + key[len("bert."):]] = val
            elif key.startswith("sentiment"):
                remapped["sentiment_head" + key[len("sentiment"):]] = val
            elif key.startswith("intent"):
                remapped["intent_head" + key[len("intent"):]] = val
            elif key.startswith("topic"):
                remapped["topic_head" + key[len("topic"):]] = val
            else:
                remapped[key] = val
        state_dict = remapped

    # Drop any weights whose shape doesn't match the current architecture.
    current_state = model.state_dict()
    filtered_state = {}
    dropped = []
    for key, val in state_dict.items():
        if key in current_state and current_state[key].shape != val.shape:
            dropped.append(key)
            continue
        filtered_state[key] = val
    if dropped:
        # If head weights were dropped, those heads keep their random init.
        print(
            f"[load_model] dropped incompatible keys: {', '.join(dropped[:5])}"
            f"{' ...' if len(dropped) > 5 else ''}"
        )

    model.load_state_dict(filtered_state, strict=False)

    if device.type == "cuda":
        try:
            model.to(device)
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                print("[load_model] CUDA OOM; falling back to CPU")
                device = torch.device("cpu")
                model.to(device)
                torch.cuda.empty_cache()
            else:
                raise

    model.eval()

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    except Exception:
        # Fall back to the hub tokenizer if the local dump is incomplete.
        tokenizer = AutoTokenizer.from_pretrained(
            meta["model_name"],
            trust_remote_code=True,
        )

    _CACHE.update({"model": model, "tokenizer": tokenizer, "meta": meta, "device": device})
    return _CACHE["model"], _CACHE["tokenizer"], _CACHE["meta"], _CACHE["device"]


def _labels_to_mapping(labels):
    if isinstance(labels, dict):
        # Coerce keys to ints where possible (YAML may load them as strings).
        mapping = {}
        for k, v in labels.items():
            try:
                mapping[int(k)] = v
            except (TypeError, ValueError):
                mapping[k] = v
        return mapping
    if isinstance(labels, list):
        return {i: v for i, v in enumerate(labels)}
    raise ValueError("labels must be a list or a dict")


def _normalize_meta(meta):
    """
    Supports two formats:
      1) tasks:  { sentiment: {labels: {0: ...}}, ... } with max_len
      2) labels: { sentiment: [...], ... } with max_length
    """
    if meta is None:
        raise ValueError("meta.yaml is empty")

    if "tasks" in meta:
        norm_tasks = {}
        for task, cfg in meta["tasks"].items():
            # cfg is usually {"labels": ...}; tolerate a bare list/dict of labels.
            labels = cfg.get("labels", cfg) if isinstance(cfg, dict) else cfg
            norm_tasks[task] = {"labels": _labels_to_mapping(labels)}
        meta["tasks"] = norm_tasks
    elif "labels" in meta:
        norm_tasks = {}
        for task, labels in meta["labels"].items():
            norm_tasks[task] = {"labels": _labels_to_mapping(labels)}
        meta["tasks"] = norm_tasks
    else:
        raise ValueError("meta.yaml must contain 'tasks' or 'labels'")

    if "max_len" not in meta:
        meta["max_len"] = meta.get("max_length", 256)
    return meta
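

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the shipped API): one way to drive
# the cached loader for a single prediction. It assumes MultiTaskBert's
# forward pass returns one logits tensor per head in constructor order
# (sentiment, intent, topic); adjust if the real signature differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model, tokenizer, meta, device = load_model()
    encoded = tokenizer(
        "The delivery was late, but support resolved it quickly.",
        truncation=True,
        max_length=meta["max_len"],
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        # Assumed head order matches the constructor: sentiment, intent, topic.
        sentiment_logits, intent_logits, topic_logits = model(**encoded)
    for task, logits in [
        ("sentiment", sentiment_logits),
        ("intent", intent_logits),
        ("topic", topic_logits),
    ]:
        idx = int(logits.argmax(dim=-1).item())
        print(f"{task}: {meta['tasks'][task]['labels'].get(idx, idx)}")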