# Provenance: models.py as uploaded by Vaibhavi53 on Hugging Face
# (commit d324d15, verified; 10.2 kB).
"""
Model loading and inference utilities (SAFE VERSION)
βœ” Handles torch failure (DLL issue)
βœ” CPU fallback
βœ” Streamlit-safe caching
βœ” Works even if BERT/Longformer fail
"""
import numpy as np
import joblib
import streamlit as st
import contextlib
# ── SAFE TORCH IMPORT ─────────────────────────────
torch = None
try:
import torch as _torch
torch = _torch
except Exception:
torch = None
# ── CONFIG IMPORTS ────────────────────────────────
from utils.config import (
BILINGUAL_LOOKUP_PATH, SVM_PATH, MODEL_B2_PATH, MODEL_C_PATH, MODEL_D_PATH,
CLINICALBERT_NAME, LONGFORMER_NAME,
NUM_LABELS_FULL, NUM_LABELS_RERANKER,
MAX_LENGTH_BERT, MAX_LENGTH_LONG,
)
from utils.preprocessing import clean_clinical_text
from utils.retriever import HierarchicalTFIDFRetriever
# ── DEVICE HANDLING ───────────────────────────────
def get_device():
    """Pick the device used for inference.

    Returns ``torch.device("cuda")`` when torch imported successfully and a
    CUDA device is available; otherwise the plain string ``"cpu"``. Callers in
    this module test ``device != "cpu"`` before moving models/tensors, so the
    mixed return types are tolerated there.
    """
    use_cuda = torch is not None and torch.cuda.is_available()
    if not use_cuda:
        return "cpu"
    return torch.device("cuda")
def get_gpu_info():
    """Summarise CUDA device 0 for display, or return ``None``.

    Returns ``None`` when torch is unavailable or no CUDA device is present.
    Otherwise it returns a dict with these keys:

    - ``name``: the device name.
    - ``allocated_gb``: ``torch.cuda.memory_allocated()``, in GiB.
    - ``reserved_gb``: ``torch.cuda.memory_reserved()``, in GiB.
    - ``total_gb``: device 0's total memory, in GiB.

    All GiB figures are rounded to two decimals.
    """
    if torch is None or not torch.cuda.is_available():
        return None
    bytes_per_gb = 1024 ** 3
    cuda = torch.cuda
    return {
        "name": cuda.get_device_name(0),
        "allocated_gb": round(cuda.memory_allocated() / bytes_per_gb, 2),
        "reserved_gb": round(cuda.memory_reserved() / bytes_per_gb, 2),
        "total_gb": round(
            cuda.get_device_properties(0).total_memory / bytes_per_gb, 2
        ),
    }
# ── LOOKUP ────────────────────────────────────────
@st.cache_resource(show_spinner="Loading ICD-10 lookup...")
def load_bilingual_lookup():
    """Load the ICD-10 code -> description lookup from ``BILINGUAL_LOOKUP_PATH``.

    The result is cached via ``st.cache_resource``. Elsewhere in this module it
    is used as a mapping from ICD code to a dict read with the keys
    ``"english"`` and ``"chinese"``.
    """
    # joblib.load unpickles the file, so the path must point at a trusted artifact.
    lookup_table = joblib.load(BILINGUAL_LOOKUP_PATH)
    return lookup_table
# ── LABEL ENCODER ────────────────────────────────
@st.cache_resource(show_spinner="Preparing labels...")
def load_label_encoder():
    """Build a LabelEncoder whose classes are the ICD codes of the lookup.

    ``LabelEncoder`` keeps its ``classes_`` sorted, so class index ``i`` is the
    ``i``-th smallest code. The transformer predictors in this module map model
    output indices through ``classes_``. That assumes the models were trained
    with this same ordering (NOTE(review): confirm against the training
    pipeline).
    """
    from sklearn.preprocessing import LabelEncoder

    codes = sorted(load_bilingual_lookup().keys())
    encoder = LabelEncoder()
    encoder.fit(codes)
    return encoder
# ── RETRIEVER ─────────────────────────────────────
@st.cache_resource(show_spinner="Building TF-IDF retriever...")
def load_retriever():
    """Fit a HierarchicalTFIDFRetriever on the bilingual lookup (cached)."""
    icd_lookup = load_bilingual_lookup()
    tfidf_retriever = HierarchicalTFIDFRetriever()
    tfidf_retriever.fit(icd_lookup)
    return tfidf_retriever
# ── SVM (MODEL A) ─────────────────────────────────
@st.cache_resource(show_spinner="Loading SVM model...")
def load_model_a():
    """Load the TF-IDF + LinearSVC pipeline.

    Returns ``None`` when ``SVM_PATH`` does not exist or unpickling fails; the
    failure reason is printed rather than raised.
    """
    import os

    if os.path.exists(SVM_PATH):
        try:
            return joblib.load(SVM_PATH)
        except Exception as e:
            print("SVM LOAD ERROR:", e)
    return None
def _svm_column_codes(svm_pipeline, le, n_columns):
    """Map each ``decision_function`` column of the SVM to its ICD code.

    Column ``j`` of ``decision_function`` corresponds to the fitted estimator's
    ``classes_[j]``. If the training data did not cover every code in the
    lookup, those columns no longer line up with ``le.classes_``. So prefer the
    estimator's own ``classes_``, and fall back to the label encoder only when
    that attribute is missing or has the wrong length.
    """
    classes = getattr(svm_pipeline, "classes_", None)
    if classes is None or len(classes) != n_columns:
        return le.classes_
    classes = np.asarray(classes)
    if np.issubdtype(classes.dtype, np.integer):
        # Trained on integer-encoded targets: translate them via the encoder.
        return le.classes_[classes]
    return classes


def predict_svm(text, top_k=10):
    """Run SVM prediction and return results in the standard format.

    The pipeline's ``decision_function`` margins are passed through a softmax so
    they can be shown alongside the transformer models' probabilities. They are
    relative scores, not calibrated probabilities.

    Args:
        text: Clinical note text, passed to the pipeline as-is.
        top_k: Number of highest-scoring codes to return.

    Returns:
        A list of up to ``top_k`` dicts with ``rank`` (1-based), ``icd_code``,
        ``confidence``, ``english_description`` and ``chinese_description``.
        Returns ``None`` if the model is unavailable or prediction fails (the
        error is printed).
    """
    from scipy.special import softmax

    svm_pipeline = load_model_a()
    if svm_pipeline is None:
        return None
    le = load_label_encoder()
    lookup = load_bilingual_lookup()
    try:
        scores = np.asarray(svm_pipeline.decision_function([text])[0])
        probs = softmax(scores)
        column_codes = _svm_column_codes(svm_pipeline, le, len(probs))
        top_idx = np.argsort(probs)[::-1][:top_k]
        results = []
        for rank, idx in enumerate(top_idx, 1):
            # Plain str keeps results JSON-friendly (numpy str_ is a str subclass).
            icd_code = str(column_codes[idx])
            info = lookup.get(icd_code, {})
            results.append({
                "rank": rank,
                "icd_code": icd_code,
                "confidence": float(probs[idx]),
                "english_description": info.get("english", "Unknown"),
                "chinese_description": info.get("chinese", ""),
            })
        return results
    except Exception as e:
        print("SVM PREDICT ERROR:", e)
        return None
# ── MODEL LOADERS ─────────────────────────────────
def _load_peft_classifier(adapter_path, base_name, num_labels, error_message):
    """Load a PEFT adapter on top of a Hugging Face sequence classifier.

    Args:
        adapter_path: Directory the tokenizer and PEFT adapter are loaded from.
        base_name: Base model name/path passed to ``from_pretrained``.
        num_labels: Size of the classification head of the base model.
        error_message: Prefix printed (followed by the exception) on failure.

    Returns:
        ``(model, tokenizer, device)`` with the model in eval mode. Returns
        ``(None, None, "cpu")`` when torch is unavailable or any loading step
        raises.
    """
    if torch is None:
        return None, None, "cpu"
    try:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        from peft import PeftModel

        device = get_device()
        tokenizer = AutoTokenizer.from_pretrained(adapter_path)
        base = AutoModelForSequenceClassification.from_pretrained(
            base_name, num_labels=num_labels
        )
        model = PeftModel.from_pretrained(base, adapter_path)
        if device != "cpu":
            model = model.to(device)
        model.eval()
        return model, tokenizer, device
    except Exception as e:
        print(error_message, e)
        return None, None, "cpu"


# NOTE: st.cache_resource caches whatever the loader returns, including the
# (None, None, "cpu") failure tuple, so a failed load is not retried until the
# Streamlit cache is cleared.
@st.cache_resource
def load_model_b2():
    """Load model B2: CLINICALBERT_NAME base + adapter at MODEL_B2_PATH (full label set)."""
    return _load_peft_classifier(
        MODEL_B2_PATH, CLINICALBERT_NAME, NUM_LABELS_FULL, "BERT LOAD ERROR:"
    )


@st.cache_resource
def load_model_c():
    """Load model C: LONGFORMER_NAME base + adapter at MODEL_C_PATH (full label set)."""
    return _load_peft_classifier(
        MODEL_C_PATH, LONGFORMER_NAME, NUM_LABELS_FULL, "LONGFORMER LOAD ERROR:"
    )


@st.cache_resource
def load_model_d():
    """Load model D (reranker): CLINICALBERT_NAME base + adapter at MODEL_D_PATH."""
    return _load_peft_classifier(
        MODEL_D_PATH, CLINICALBERT_NAME, NUM_LABELS_RERANKER, "RERANKER LOAD ERROR:"
    )
# ── CORE INFERENCE ────────────────────────────────
def predict_single_label(model, tokenizer, device, text, max_length, top_k=10):
    """Score *text* with a single-label classifier and return the top-k classes.

    Args:
        model: Classifier whose forward pass returns an object with ``.logits``.
            ``None`` yields an empty result.
        tokenizer: Hugging Face tokenizer paired with *model*.
        device: ``"cpu"`` or the device returned by ``get_device``.
        text: Input text, truncated to *max_length* tokens.
        max_length: Maximum token length passed to the tokenizer.
        top_k: Number of classes to return.

    Returns:
        A list of ``(class_index, probability)`` tuples sorted by descending
        softmax probability. Returns ``[]`` when torch or the model is
        unavailable.
    """
    if torch is None or model is None:
        return []
    # A single sequence needs no padding. Padding to max_length (which may be
    # large on the Longformer path) only adds positions that the attention mask
    # hides, and they still cost compute on every request.
    enc = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]
    if device != "cpu":
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    probs = torch.softmax(logits, dim=-1).cpu().numpy().flatten()
    top_idx = np.argsort(probs)[::-1][:top_k]
    return [(int(i), float(probs[i])) for i in top_idx]
# ── MODEL PREDICTIONS ─────────────────────────────
def predict_b2(text, top_k=10):
    """Rank ICD-10 codes for *text* with model B2 (ClinicalBERT classifier).

    Returns a list of up to ``top_k`` result dicts (``rank``, ``icd_code``,
    ``confidence``, ``english_description``, ``chinese_description``). Returns
    ``None`` when the model could not be loaded.
    """
    model, tokenizer, device = load_model_b2()
    if model is None:
        return None
    le = load_label_encoder()
    lookup = load_bilingual_lookup()
    scored = predict_single_label(model, tokenizer, device, text, MAX_LENGTH_BERT, top_k)
    ranked = []
    for position, (class_idx, probability) in enumerate(scored, start=1):
        code = le.classes_[class_idx]
        entry = lookup.get(code, {})
        ranked.append({
            "rank": position,
            "icd_code": code,
            "confidence": probability,
            "english_description": entry.get("english", "Unknown"),
            "chinese_description": entry.get("chinese", ""),
        })
    return ranked
def predict_longformer(text, top_k=10):
    """Rank ICD-10 codes for *text* with model C (Longformer classifier).

    Inputs are truncated to ``MAX_LENGTH_LONG`` tokens. Returns a list of up to
    ``top_k`` result dicts (``rank``, ``icd_code``, ``confidence``,
    ``english_description``, ``chinese_description``). Returns ``None`` when the
    model could not be loaded.
    """
    model, tokenizer, device = load_model_c()
    if model is None:
        return None
    le = load_label_encoder()
    lookup = load_bilingual_lookup()
    scored = predict_single_label(model, tokenizer, device, text, MAX_LENGTH_LONG, top_k)
    ranked = []
    for position, (class_idx, probability) in enumerate(scored, start=1):
        code = le.classes_[class_idx]
        entry = lookup.get(code, {})
        ranked.append({
            "rank": position,
            "icd_code": code,
            "confidence": probability,
            "english_description": entry.get("english", "Unknown"),
            "chinese_description": entry.get("chinese", ""),
        })
    return ranked
def _rerank_scores(model, tokenizer, device, text, codes, lookup, batch_size):
    """Score each (text, English description) pair; returns one float per code, in order."""
    scores = []
    batch_size = max(1, int(batch_size))
    for start in range(0, len(codes), batch_size):
        chunk = codes[start:start + batch_size]
        descs = [lookup.get(code, {}).get("english", "") for code in chunk]
        # Every pair is padded to the same max_length, exactly as when pairs
        # were encoded one at a time, so batching does not change the inputs.
        enc = tokenizer(
            [text] * len(chunk),
            descs,
            max_length=MAX_LENGTH_BERT,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # NOTE(review): token_type_ids (segment ids for the pair) are not
        # passed, so the model sees all-zero segments. Confirm this matches how
        # the reranker was trained.
        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]
        if device != "cpu":
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        if logits.shape[-1] != 1:
            raise ValueError(
                "reranker expects a single relevance logit per pair, got "
                f"{logits.shape[-1]}"
            )
        scores.extend(torch.sigmoid(logits[:, 0]).cpu().tolist())
    return scores


def predict_reranker(text, top_k=10, num_candidates=100, batch_size=32):
    """Retrieve ICD-10 candidates with TF-IDF, then rerank them with model D.

    Args:
        text: Clinical note text.
        top_k: Number of reranked codes to return.
        num_candidates: How many TF-IDF candidates to score (default 100, the
            previously hard-coded value).
        batch_size: Number of (text, description) pairs per forward pass.

    Returns:
        Up to ``top_k`` result dicts (``rank``, ``icd_code``, ``confidence``,
        ``english_description``, ``chinese_description``) sorted by descending
        sigmoid score. Returns ``None`` when the reranker could not be loaded.

    Raises:
        ValueError: if the reranker head does not output exactly one logit.
    """
    # Check the model first so retrieval is skipped when it is unavailable.
    model, tokenizer, device = load_model_d()
    if model is None:
        return None
    retriever = load_retriever()
    lookup = load_bilingual_lookup()
    # NOTE(review): assumes retrieve() yields (code, score) pairs; only codes are used.
    codes = [code for code, _ in retriever.retrieve(text, top_k=num_candidates)]
    scores = _rerank_scores(model, tokenizer, device, text, codes, lookup, batch_size)
    # Stable sort: ties keep their retrieval order.
    results = sorted(zip(codes, scores), key=lambda x: x[1], reverse=True)
    final = []
    for rank, (code, score) in enumerate(results[:top_k], 1):
        info = lookup.get(code, {})
        final.append({
            "rank": rank,
            "icd_code": code,
            "confidence": score,
            "english_description": info.get("english", "Unknown"),
            "chinese_description": info.get("chinese", ""),
        })
    return final