| """ |
| Model loading and inference utilities (SAFE VERSION) |
| |
| - Handles torch failure (DLL issue) |
| - CPU fallback |
| - Streamlit-safe caching |
| - Works even if BERT/Longformer fail |
| """ |
|
|
| import numpy as np |
| import joblib |
| import streamlit as st |
| import contextlib |
|
|
| |
# torch is treated as an optional dependency. When the import fails for any
# reason (the module docstring cites DLL issues), the module-level name
# ``torch`` is bound to None so the rest of this file can test
# ``torch is None`` instead of crashing at import time.
try:
    import torch as _torch
except Exception:
    # Deliberately broad: any failure while importing torch disables it.
    torch = None
else:
    torch = _torch
|
|
|
|
| |
| from utils.config import ( |
| BILINGUAL_LOOKUP_PATH, SVM_PATH, MODEL_B2_PATH, MODEL_C_PATH, MODEL_D_PATH, |
| CLINICALBERT_NAME, LONGFORMER_NAME, |
| NUM_LABELS_FULL, NUM_LABELS_RERANKER, |
| MAX_LENGTH_BERT, MAX_LENGTH_LONG, |
| ) |
|
|
| from utils.preprocessing import clean_clinical_text |
| from utils.retriever import HierarchicalTFIDFRetriever |
|
|
|
|
| |
def get_device():
    """Pick the device that torch models should run on.

    Returns:
        ``torch.device("cuda")`` when torch was imported and reports CUDA as
        available; otherwise the plain string ``"cpu"``.
    """
    cuda_ready = torch is not None and torch.cuda.is_available()
    if not cuda_ready:
        # Kept as a plain string: other functions in this module decide
        # whether to move tensors by comparing the device against "cpu".
        return "cpu"
    return torch.device("cuda")
|
|
|
|
def get_gpu_info():
    """Report name and memory figures for CUDA device 0.

    Returns:
        ``None`` when torch is unavailable or CUDA is not available.
        Otherwise a dict with ``"name"``, ``"allocated_gb"``,
        ``"reserved_gb"`` and ``"total_gb"``; the memory values are byte
        counts divided by 1024**3 and rounded to two decimals.
    """
    if torch is None or not torch.cuda.is_available():
        return None

    def to_gb(num_bytes):
        # Bytes -> units of 1024**3 bytes, two decimal places.
        return round(num_bytes / 1024**3, 2)

    return {
        "name": torch.cuda.get_device_name(0),
        "allocated_gb": to_gb(torch.cuda.memory_allocated()),
        "reserved_gb": to_gb(torch.cuda.memory_reserved()),
        "total_gb": to_gb(torch.cuda.get_device_properties(0).total_memory),
    }
|
|
|
|
| |
@st.cache_resource(show_spinner="Loading ICD-10 lookup...")
def load_bilingual_lookup():
    """Load the object stored at ``BILINGUAL_LOOKUP_PATH`` with joblib.

    Other functions in this module use the result as a mapping from ICD code
    to a dict that may hold ``"english"`` and ``"chinese"`` entries.

    NOTE: ``joblib.load`` unpickles the file, so the path must point at a
    trusted artifact.
    """
    lookup = joblib.load(BILINGUAL_LOOKUP_PATH)
    return lookup
|
|
|
|
| |
@st.cache_resource(show_spinner="Preparing labels...")
def load_label_encoder():
    """Build a scikit-learn ``LabelEncoder`` fitted on the lookup's keys.

    Returns:
        The fitted encoder. Its ``classes_`` array is what the prediction
        functions in this module index to turn a class index into a code.
    """
    from sklearn.preprocessing import LabelEncoder

    codes = sorted(load_bilingual_lookup().keys())
    encoder = LabelEncoder()
    encoder.fit(codes)
    return encoder
|
|
|
|
| |
@st.cache_resource(show_spinner="Building TF-IDF retriever...")
def load_retriever():
    """Create a ``HierarchicalTFIDFRetriever`` and fit it on the lookup.

    Returns:
        The fitted retriever instance.
    """
    code_lookup = load_bilingual_lookup()
    tfidf_retriever = HierarchicalTFIDFRetriever()
    tfidf_retriever.fit(code_lookup)
    return tfidf_retriever
|
|
|
|
| |
@st.cache_resource(show_spinner="Loading SVM model...")
def load_model_a():
    """Load the TF-IDF + LinearSVC pipeline from ``SVM_PATH``.

    Returns:
        The object loaded by joblib, or ``None`` when the file does not exist
        or loading raises (the error is printed, not propagated).
    """
    from os.path import exists

    if exists(SVM_PATH):
        try:
            pipeline = joblib.load(SVM_PATH)
        except Exception as err:
            # Best-effort load: report and fall through to the None return.
            print("SVM LOAD ERROR:", err)
        else:
            return pipeline
    return None
|
|
|
|
def predict_svm(text, top_k=10):
    """Score ``text`` with the SVM pipeline and return the top-ranked codes.

    Args:
        text: Input text passed to the pipeline as a single-item batch.
        top_k: Maximum number of results to return.

    Returns:
        ``None`` when the pipeline is unavailable or prediction raises (the
        error is printed). Otherwise a list of dicts with the keys ``rank``
        (1-based), ``icd_code``, ``confidence``, ``english_description`` and
        ``chinese_description``, ordered by descending confidence.
    """
    from scipy.special import softmax

    pipeline = load_model_a()
    if pipeline is None:
        return None

    encoder = load_label_encoder()
    descriptions = load_bilingual_lookup()

    def as_row(position, class_idx, confidence):
        # Map a class index to its code and bilingual descriptions.
        code = encoder.classes_[class_idx]
        entry = descriptions.get(code, {})
        return {
            "rank": position,
            "icd_code": code,
            "confidence": confidence,
            "english_description": entry.get("english", "Unknown"),
            "chinese_description": entry.get("chinese", ""),
        }

    try:
        margins = pipeline.decision_function([text])[0]
        # Softmax over the decision scores gives a normalized ranking score;
        # it is not a calibrated probability.
        confidences = softmax(margins)
        best = np.argsort(confidences)[::-1][:top_k]
        return [
            as_row(position, class_idx, float(confidences[class_idx]))
            for position, class_idx in enumerate(best, 1)
        ]
    except Exception as err:
        print("SVM PREDICT ERROR:", err)
        return None
|
|
|
|
| |
@st.cache_resource
def load_model_b2():
    """Load the ClinicalBERT classifier with the PEFT weights at ``MODEL_B2_PATH``.

    Returns:
        ``(model, tokenizer, device)`` on success, with the model in eval
        mode and moved to ``device`` when that is not ``"cpu"``.
        ``(None, None, "cpu")`` when torch is unavailable or anything in the
        loading sequence raises (the error is printed).

    NOTE(review): the function is wrapped in ``st.cache_resource``, so the
    failure tuple is a cached return value too — confirm that is intended.
    """
    unavailable = (None, None, "cpu")
    if torch is None:
        return unavailable

    try:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from peft import PeftModel

        device = get_device()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_B2_PATH)
        backbone = AutoModelForSequenceClassification.from_pretrained(
            CLINICALBERT_NAME,
            num_labels=NUM_LABELS_FULL,
        )
        model = PeftModel.from_pretrained(backbone, MODEL_B2_PATH)
        if device != "cpu":
            model = model.to(device)
        model.eval()
    except Exception as err:
        print("BERT LOAD ERROR:", err)
        return unavailable
    else:
        return model, tokenizer, device
|
|
|
|
@st.cache_resource
def load_model_c():
    """Load the Longformer classifier with the PEFT weights at ``MODEL_C_PATH``.

    Returns:
        ``(model, tokenizer, device)`` on success, with the model in eval
        mode and moved to ``device`` when that is not ``"cpu"``.
        ``(None, None, "cpu")`` when torch is unavailable or anything in the
        loading sequence raises (the error is printed).

    NOTE(review): the function is wrapped in ``st.cache_resource``, so the
    failure tuple is a cached return value too — confirm that is intended.
    """
    unavailable = (None, None, "cpu")
    if torch is None:
        return unavailable

    try:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from peft import PeftModel

        device = get_device()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_C_PATH)
        backbone = AutoModelForSequenceClassification.from_pretrained(
            LONGFORMER_NAME,
            num_labels=NUM_LABELS_FULL,
        )
        model = PeftModel.from_pretrained(backbone, MODEL_C_PATH)
        if device != "cpu":
            model = model.to(device)
        model.eval()
    except Exception as err:
        print("LONGFORMER LOAD ERROR:", err)
        return unavailable
    else:
        return model, tokenizer, device
|
|
|
|
@st.cache_resource
def load_model_d():
    """Load the reranker: ClinicalBERT base plus the PEFT weights at ``MODEL_D_PATH``.

    The base classifier is created with ``NUM_LABELS_RERANKER`` output labels.

    Returns:
        ``(model, tokenizer, device)`` on success, with the model in eval
        mode and moved to ``device`` when that is not ``"cpu"``.
        ``(None, None, "cpu")`` when torch is unavailable or anything in the
        loading sequence raises (the error is printed).

    NOTE(review): the function is wrapped in ``st.cache_resource``, so the
    failure tuple is a cached return value too — confirm that is intended.
    """
    unavailable = (None, None, "cpu")
    if torch is None:
        return unavailable

    try:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from peft import PeftModel

        device = get_device()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_D_PATH)
        backbone = AutoModelForSequenceClassification.from_pretrained(
            CLINICALBERT_NAME,
            num_labels=NUM_LABELS_RERANKER,
        )
        model = PeftModel.from_pretrained(backbone, MODEL_D_PATH)
        if device != "cpu":
            model = model.to(device)
        model.eval()
    except Exception as err:
        print("RERANKER LOAD ERROR:", err)
        return unavailable
    else:
        return model, tokenizer, device
|
|
|
|
| |
def predict_single_label(model, tokenizer, device, text, max_length, top_k=10):
    """Run one forward pass and return the highest-probability class indices.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``logits`` attribute.
        tokenizer: Callable producing ``"input_ids"`` and
            ``"attention_mask"`` tensors for ``text``.
        device: Target device; inputs are moved unless it equals ``"cpu"``.
        text: Text to classify.
        max_length: Length the input is padded or truncated to.
        top_k: Maximum number of classes to return.

    Returns:
        A list of ``(class_index, probability)`` tuples in descending
        probability order, or ``[]`` when torch or the model is unavailable.
    """
    if model is None or torch is None:
        return []

    encoded = tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids, mask = encoded["input_ids"], encoded["attention_mask"]
    if device != "cpu":
        ids, mask = ids.to(device), mask.to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask).logits

    probabilities = torch.softmax(logits, dim=-1).cpu().numpy().flatten()
    ranked = np.argsort(probabilities)[::-1][:top_k]
    return [(int(label), float(probabilities[label])) for label in ranked]
|
|
|
|
| |
def predict_b2(text, top_k=10):
    """Classify ``text`` with the model from ``load_model_b2``.

    Args:
        text: Text to classify; tokenized up to ``MAX_LENGTH_BERT``.
        top_k: Maximum number of results to return.

    Returns:
        ``None`` when the model is unavailable. Otherwise a list of dicts
        with the keys ``rank`` (1-based), ``icd_code``, ``confidence``,
        ``english_description`` and ``chinese_description``.
    """
    model, tokenizer, device = load_model_b2()
    if model is None:
        return None

    encoder = load_label_encoder()
    descriptions = load_bilingual_lookup()
    ranked = predict_single_label(
        model, tokenizer, device, text, MAX_LENGTH_BERT, top_k
    )

    predictions = []
    for position, (class_idx, confidence) in enumerate(ranked, 1):
        # Class index -> code via the label encoder, then its descriptions.
        code = encoder.classes_[class_idx]
        predictions.append({
            "rank": position,
            "icd_code": code,
            "confidence": confidence,
            "english_description": descriptions.get(code, {}).get("english", "Unknown"),
            "chinese_description": descriptions.get(code, {}).get("chinese", ""),
        })
    return predictions
|
|
|
|
def predict_longformer(text, top_k=10):
    """Classify ``text`` with the model from ``load_model_c``.

    Args:
        text: Text to classify; tokenized up to ``MAX_LENGTH_LONG``.
        top_k: Maximum number of results to return.

    Returns:
        ``None`` when the model is unavailable. Otherwise a list of dicts
        with the keys ``rank`` (1-based), ``icd_code``, ``confidence``,
        ``english_description`` and ``chinese_description``.
    """
    model, tokenizer, device = load_model_c()
    if model is None:
        return None

    encoder = load_label_encoder()
    descriptions = load_bilingual_lookup()
    ranked = predict_single_label(
        model, tokenizer, device, text, MAX_LENGTH_LONG, top_k
    )

    predictions = []
    for position, (class_idx, confidence) in enumerate(ranked, 1):
        # Class index -> code via the label encoder, then its descriptions.
        code = encoder.classes_[class_idx]
        predictions.append({
            "rank": position,
            "icd_code": code,
            "confidence": confidence,
            "english_description": descriptions.get(code, {}).get("english", "Unknown"),
            "chinese_description": descriptions.get(code, {}).get("chinese", ""),
        })
    return predictions
|
|
|
|
def predict_reranker(text, top_k=10, *, num_candidates=100):
    """Retrieve candidate codes with TF-IDF, then re-score them with the reranker.

    Each candidate is scored by encoding ``text`` paired with the candidate's
    English description, running the reranker model, and applying a sigmoid
    to its output.

    Args:
        text: Text to code.
        top_k: Maximum number of re-scored candidates to return.
        num_candidates: How many candidates to request from the retriever
            before reranking. Keyword-only; defaults to 100, the value that
            was previously hard-coded.

    Returns:
        ``None`` when the reranker model is unavailable. Otherwise a list of
        dicts with the keys ``rank`` (1-based), ``icd_code``, ``confidence``
        (the sigmoid score), ``english_description`` and
        ``chinese_description``, ordered by descending score.
    """
    retriever = load_retriever()
    model, tokenizer, device = load_model_d()

    if model is None:
        return None

    lookup = load_bilingual_lookup()
    candidates = retriever.retrieve(text, top_k=num_candidates)

    scored = []
    # A single no_grad context covers the whole scoring loop; nothing in it
    # needs gradients, so there is no reason to re-enter it per candidate.
    with torch.no_grad():
        # The second element of each retrieved pair is not used here
        # (presumably the retriever's own score — verify against retriever).
        for code, _ in candidates:
            desc = lookup.get(code, {}).get("english", "")

            enc = tokenizer(
                text, desc,
                max_length=MAX_LENGTH_BERT,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]

            if device != "cpu":
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

            # NOTE(review): .item() requires a one-element tensor, so this
            # assumes the reranker head emits a single logit — confirm
            # NUM_LABELS_RERANKER is 1.
            scored.append((code, torch.sigmoid(logits).item()))

    # Stable sort: candidates with equal scores keep their retrieval order.
    scored.sort(key=lambda pair: pair[1], reverse=True)

    final = []
    for rank, (code, score) in enumerate(scored[:top_k], 1):
        info = lookup.get(code, {})
        final.append({
            "rank": rank,
            "icd_code": code,
            "confidence": score,
            "english_description": info.get("english", "Unknown"),
            "chinese_description": info.get("chinese", ""),
        })
    return final