""" Model loading and inference utilities (SAFE VERSION) ✔ Handles torch failure (DLL issue) ✔ CPU fallback ✔ Streamlit-safe caching ✔ Works even if BERT/Longformer fail """ import numpy as np import joblib import streamlit as st import contextlib # ── SAFE TORCH IMPORT ───────────────────────────── torch = None try: import torch as _torch torch = _torch except Exception: torch = None # ── CONFIG IMPORTS ──────────────────────────────── from utils.config import ( BILINGUAL_LOOKUP_PATH, SVM_PATH, MODEL_B2_PATH, MODEL_C_PATH, MODEL_D_PATH, CLINICALBERT_NAME, LONGFORMER_NAME, NUM_LABELS_FULL, NUM_LABELS_RERANKER, MAX_LENGTH_BERT, MAX_LENGTH_LONG, ) from utils.preprocessing import clean_clinical_text from utils.retriever import HierarchicalTFIDFRetriever # ── DEVICE HANDLING ─────────────────────────────── def get_device(): if torch is not None and torch.cuda.is_available(): return torch.device("cuda") return "cpu" def get_gpu_info(): if torch is None: return None if torch.cuda.is_available(): return { "name": torch.cuda.get_device_name(0), "allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2), "reserved_gb": round(torch.cuda.memory_reserved() / 1024**3, 2), "total_gb": round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2), } return None # ── LOOKUP ──────────────────────────────────────── @st.cache_resource(show_spinner="Loading ICD-10 lookup...") def load_bilingual_lookup(): return joblib.load(BILINGUAL_LOOKUP_PATH) # ── LABEL ENCODER ──────────────────────────────── @st.cache_resource(show_spinner="Preparing labels...") def load_label_encoder(): from sklearn.preprocessing import LabelEncoder lookup = load_bilingual_lookup() le = LabelEncoder() le.fit(sorted(lookup.keys())) return le # ── RETRIEVER ───────────────────────────────────── @st.cache_resource(show_spinner="Building TF-IDF retriever...") def load_retriever(): lookup = load_bilingual_lookup() retriever = HierarchicalTFIDFRetriever() retriever.fit(lookup) return retriever # ── SVM 
# ── SVM (MODEL A) ─────────────────────────────────
@st.cache_resource(show_spinner="Loading SVM model...")
def load_model_a():
    """Load the TF-IDF + LinearSVC pipeline.

    Returns:
        The object stored at ``SVM_PATH``, or ``None`` when the file does not
        exist or loading it raises (the error is printed).
    """
    import os

    if not os.path.exists(SVM_PATH):
        return None
    try:
        return joblib.load(SVM_PATH)
    except Exception as e:
        print("SVM LOAD ERROR:", e)
        return None


def predict_svm(text, top_k=10):
    """Run SVM prediction and return results in the standard format.

    Args:
        text: Input text passed to the pipeline as a one-element batch.
        top_k: Maximum number of ranked results to return.

    Returns:
        A list of result dicts (``rank``, ``icd_code``, ``confidence``,
        ``english_description``, ``chinese_description``) ordered by
        descending confidence, or ``None`` when the model is unavailable or
        prediction fails (the error is printed).
    """
    from scipy.special import softmax

    svm_pipeline = load_model_a()
    if svm_pipeline is None:
        return None

    le = load_label_encoder()
    lookup = load_bilingual_lookup()

    try:
        scores = svm_pipeline.decision_function([text])[0]
        # Decision scores are not probabilities; softmax only rescales them
        # into a ranking-preserving 0..1 "confidence".
        probs = softmax(scores)
        top_idx = np.argsort(probs)[::-1][:top_k]
        # NOTE(review): column i of the scores is mapped to le.classes_[i];
        # this assumes the pipeline's class order matches the label encoder
        # built from the lookup keys — confirm against training.
        ranked = [(idx, float(probs[idx])) for idx in top_idx]
        return _format_predictions(ranked, le, lookup)
    except Exception as e:
        print("SVM PREDICT ERROR:", e)
        return None


# ── MODEL LOADERS ─────────────────────────────────
def _load_peft_classifier(base_name, adapter_path, num_labels, error_tag):
    """Load a base sequence classifier and attach the PEFT adapter to it.

    Args:
        base_name: Name/path given to ``AutoModelForSequenceClassification``.
        adapter_path: Directory holding both the tokenizer and the adapter.
        num_labels: Size of the classification head.
        error_tag: Prefix for the printed ``"<tag> LOAD ERROR:"`` message.

    Returns:
        ``(model, tokenizer, device)`` on success, with the model in eval
        mode; ``(None, None, "cpu")`` when torch is unavailable or any step
        of loading raises.
    """
    if torch is None:
        return None, None, "cpu"

    try:
        # Heavy optional dependencies: imported lazily so a missing or broken
        # install degrades to "model unavailable" instead of an import error.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from peft import PeftModel

        device = get_device()
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
        base = AutoModelForSequenceClassification.from_pretrained(
            base_name, num_labels=num_labels
        )
        model = PeftModel.from_pretrained(base, adapter_path)

        # get_device() yields the string "cpu" when no GPU is usable; only
        # move the model when an accelerator was actually selected.
        if device != "cpu":
            model = model.to(device)

        model.eval()
        return model, tokenizer, device
    except Exception as e:
        print(f"{error_tag} LOAD ERROR:", e)
        return None, None, "cpu"


@st.cache_resource
def load_model_b2():
    """Load model B2: ``CLINICALBERT_NAME`` plus the adapter at ``MODEL_B2_PATH``.

    Returns:
        ``(model, tokenizer, device)``, or ``(None, None, "cpu")`` on failure.
    """
    return _load_peft_classifier(
        CLINICALBERT_NAME, MODEL_B2_PATH, NUM_LABELS_FULL, "BERT"
    )


@st.cache_resource
def load_model_c():
    """Load model C: ``LONGFORMER_NAME`` plus the adapter at ``MODEL_C_PATH``.

    Returns:
        ``(model, tokenizer, device)``, or ``(None, None, "cpu")`` on failure.
    """
    return _load_peft_classifier(
        LONGFORMER_NAME, MODEL_C_PATH, NUM_LABELS_FULL, "LONGFORMER"
    )


@st.cache_resource
def load_model_d():
    """Load model D: ``CLINICALBERT_NAME`` plus the adapter at ``MODEL_D_PATH``.

    Uses ``NUM_LABELS_RERANKER`` for the head size.

    Returns:
        ``(model, tokenizer, device)``, or ``(None, None, "cpu")`` on failure.
    """
    return _load_peft_classifier(
        CLINICALBERT_NAME, MODEL_D_PATH, NUM_LABELS_RERANKER, "RERANKER"
    )


# ── CORE INFERENCE ────────────────────────────────
def predict_single_label(model, tokenizer, device, text, max_length, top_k=10):
    """Classify one text and return its highest-probability class indices.

    Args:
        model: Sequence classifier whose output exposes ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        device: Value returned by the loader (``"cpu"`` or a torch device).
        text: Input text.
        max_length: Length the input is padded/truncated to.
        top_k: Maximum number of classes to return.

    Returns:
        A list of ``(class_index, probability)`` tuples in descending
        probability order; an empty list when torch or the model is missing.
    """
    if torch is None or model is None:
        return []

    enc = tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    if device != "cpu":
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy().flatten()

    top_idx = np.argsort(probs)[::-1][:top_k]
    return [(int(i), float(probs[i])) for i in top_idx]


# ── RESULT FORMATTING ─────────────────────────────
def _result_entry(rank, icd_code, confidence, lookup):
    """Build one standard result dict, filling descriptions from ``lookup``."""
    info = lookup.get(icd_code, {})
    return {
        "rank": rank,
        "icd_code": icd_code,
        "confidence": confidence,
        "english_description": info.get("english", "Unknown"),
        "chinese_description": info.get("chinese", ""),
    }


def _format_predictions(indexed_scores, le, lookup):
    """Convert ``(class_index, score)`` pairs into ranked result dicts.

    Ranks start at 1 and follow the order of ``indexed_scores``; each index
    is translated to a code through ``le.classes_``.
    """
    return [
        _result_entry(rank, le.classes_[idx], score, lookup)
        for rank, (idx, score) in enumerate(indexed_scores, 1)
    ]


def _predict_full_label(load_fn, max_length, error_tag, text, top_k):
    """Shared body of the full-label transformer predictors.

    Returns ``None`` when the model is unavailable or inference raises (the
    error is printed as ``"<tag> PREDICT ERROR:"``), mirroring ``predict_svm``.
    """
    model, tokenizer, device = load_fn()
    if model is None:
        return None

    le = load_label_encoder()
    lookup = load_bilingual_lookup()

    try:
        scored = predict_single_label(
            model, tokenizer, device, text, max_length, top_k
        )
        return _format_predictions(scored, le, lookup)
    except Exception as e:
        print(f"{error_tag} PREDICT ERROR:", e)
        return None


# ── MODEL PREDICTIONS ─────────────────────────────
def predict_b2(text, top_k=10):
    """Predict with model B2, truncating input to ``MAX_LENGTH_BERT``.

    Returns:
        A ranked list of result dicts, or ``None`` when the model is
        unavailable or inference fails.
    """
    return _predict_full_label(load_model_b2, MAX_LENGTH_BERT, "BERT", text, top_k)


def predict_longformer(text, top_k=10):
    """Predict with model C, truncating input to ``MAX_LENGTH_LONG``.

    Returns:
        A ranked list of result dicts, or ``None`` when the model is
        unavailable or inference fails.
    """
    return _predict_full_label(
        load_model_c, MAX_LENGTH_LONG, "LONGFORMER", text, top_k
    )


def predict_reranker(text, top_k=10, num_candidates=100):
    """Retrieve candidate codes, score each against ``text``, return the best.

    Args:
        text: Input text.
        top_k: Maximum number of ranked results to return.
        num_candidates: How many candidates to request from the retriever
            before re-scoring (defaults to the previous fixed value of 100).

    Returns:
        A list of result dicts ordered by descending score, or ``None`` when
        the reranker model is unavailable or scoring fails (the error is
        printed).
    """
    # Check the model first: there is no point building the TF-IDF retriever
    # when the reranker itself cannot run.
    model, tokenizer, device = load_model_d()
    if model is None:
        return None

    retriever = load_retriever()
    lookup = load_bilingual_lookup()

    try:
        candidates = retriever.retrieve(text, top_k=num_candidates)
        move_to_device = device != "cpu"

        scored = []
        for code, _ in candidates:
            desc = lookup.get(code, {}).get("english", "")
            # The text and the candidate description are encoded as a pair.
            enc = tokenizer(
                text,
                desc,
                max_length=MAX_LENGTH_BERT,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]

            if move_to_device:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)

            with torch.no_grad():
                logits = model(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits
                # .item() only works on a one-element tensor, so this relies
                # on the reranker head emitting a single logit per pair.
                score = torch.sigmoid(logits).item()

            scored.append((code, score))

        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [
            _result_entry(rank, code, score, lookup)
            for rank, (code, score) in enumerate(scored[:top_k], 1)
        ]
    except Exception as e:
        print("RERANKER PREDICT ERROR:", e)
        return None