# medi-llm/app/utils/inference_utils.py
import os
import sys
import torch
import yaml
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer
from torchvision import transforms
from huggingface_hub import hf_hub_download
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(ROOT_DIR))
from src.multimodal_model import MediLLMModel
from app.utils.gradcam_utils import register_hooks, generate_gradcam
# --------------------
# Runtime / Hub config
# --------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "Preetham22/medi-llm-weights")
_raw_rev = os.getenv("HF_WEIGHTS_REV", None)
HF_WEIGHTS_REV = _raw_rev if (_raw_rev and _raw_rev.strip()) else None  # optional (commit/tag/branch); can be None
# Map modes -> filenames in the HF model repo
FILENAMES = {
"text": "medi_llm_state_dict_text.pth",
"image": "medi_llm_state_dict_image.pth",
"multimodal": "medi_llm_state_dict_multimodal.pth",
}
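# Example: the repo/revision can be pinned at deploy time via environment
# variables before this module is imported (hypothetical tag name):
#   os.environ["HF_WEIGHTS_REV"] = "v1.0"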
def have_internet():
    """Best-effort check that huggingface.co is reachable (3 s timeout)."""
    try:
import socket
socket.create_connection(("huggingface.co", 443), timeout=3).close()
return True
except Exception:
return False
def resolve_weights_path(mode: str) -> str:
"""Download (or reuse cached) weights for the given mode from HF Hub."""
if mode not in FILENAMES:
raise ValueError(f"Unknown mode '{mode}'. Expected one of {list(FILENAMES)}.")
filename = FILENAMES[mode]
    # 1) Prefer a file already present in the Space repo
local_path = ROOT_DIR / filename
if local_path.exists():
return str(local_path)
# 2) If no local file and no internet, bail early
if not have_internet():
raise RuntimeError(
f"❌ Internet is disabled and weights are not present locally.\n"
f" Upload '{filename}' to this Space or enable Network access."
)
# 3) Otherwise, download from Hub
try:
return hf_hub_download(
repo_id=HF_MODEL_REPO,
filename=filename,
revision=HF_WEIGHTS_REV, # can be None -> default branch
repo_type="model", # change to "dataset" if needed
local_dir=str(ROOT_DIR), # Keep a copy in repo dir
            local_dir_use_symlinks=False,  # avoid symlinks (ignored by newer huggingface_hub versions)
token=None, # For public repo
)
except Exception as e:
raise RuntimeError(
f"Failed to fetch weights '{filename}' from repo '{HF_MODEL_REPO}'. "
f"Either enable Network access for this Space or commit the file locally. "
f"Original error: {e}"
)
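# Usage sketch: resolves to a local copy when present, otherwise downloads, e.g.
#   resolve_weights_path("multimodal")
#   # -> "<ROOT_DIR>/medi_llm_state_dict_multimodal.pth"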
# ----------------------
# Labels / preprocessing
# ----------------------
inv_map = {0: "low", 1: "medium", 2: "high"}
# Tokenizer and image transform
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
image_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor()
])
# ----------------------
# Model load
# ----------------------
def _safe_torch_load(path: str, map_location: torch.device):
    """
    Prefer weights_only=True (safer unpickling on newer PyTorch), falling back
    if the installed version does not support the argument.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        # Older PyTorch: torch.load() has no weights_only parameter
        return torch.load(path, map_location=map_location)
def load_model(mode: str, config_path: str = str(ROOT_DIR / "config" / "config.yaml")):
"""
Load MediLLMModel for the given mode and populate weights from HF Hub.
Expects config/config.yaml with keys per mode (dropout, hidden_dim).
"""
with open(config_path, "r") as f:
cfg_all = yaml.safe_load(f)
if mode not in cfg_all:
raise KeyError(f"Mode '{mode}' not found in {config_path}. Keys: {list(cfg_all.keys())}")
config = cfg_all[mode]
# Build model
model = MediLLMModel(
mode=mode,
dropout=config["dropout"],
hidden_dim=config["hidden_dim"]
)
# Download weights & load
weights_path = resolve_weights_path(mode)
state = _safe_torch_load(weights_path, map_location=DEVICE)
# Sometimes checkpoints save as {'state_dict': ...}
if isinstance(state, dict) and "state_dict" in state:
state = state["state_dict"]
    try:
        model.load_state_dict(state)  # strict by default
    except RuntimeError as e:
        # Fall back to non-strict loading for minor mismatches (e.g. buffer names)
        model.load_state_dict(state, strict=False)
        print(f"⚠️ Loaded with strict=False due to: {e}")
model.to(DEVICE)
model.eval()
return model
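# Usage sketch (assumes config/config.yaml has a "text" section with
# "dropout" and "hidden_dim" keys, as documented above):
#   model = load_model("text")   # returned on DEVICE, in eval() mode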
# -----------------------
# Attention rollout utils
# -----------------------
def attention_rollout(attentions, last_k=4, residual_alpha=0.5):
    """
    attentions: tuple/list of per-layer attentions, each of shape (B, H, S, S)
    last_k: only roll back through the last k layers (keeps contrast)
    residual_alpha: how much identity to add before normalizing (preserves token self-info)
    returns: [B, S, S] rollout matrix, or None if input is invalid
    """
if attentions is None:
return None
if isinstance(attentions, (list, tuple)) and len(attentions) == 0:
return None
first = attentions[0]
if first is None or first.ndim != 4:
return None # expect [B, H, S, S]
B, H, S, _ = first.shape
    eye = torch.eye(S, device=first.device, dtype=first.dtype).unsqueeze(0).expand(B, S, S)  # [B, S, S]
L = len(attentions)
if last_k is None:
last_k = L
if last_k <= 0:
# No layers selected -> return identity (no propagation)
return eye.clone()
start = max(0, L - last_k)
A = None
for layer in range(start, L):
a = attentions[layer]
if a is None or a.ndim != 4 or a.shape[0] != B or a.shape[-1] != S:
# Skip malformed layer
continue
a = a.mean(dim=1) # [B, S, S] (avg heads)
a = a + float(residual_alpha) * eye
a = a / (a.sum(dim=-1, keepdim=True) + 1e-12) # row-normalize
A = a if A is None else torch.bmm(A, a)
    # If we never multiplied (e.g. every layer was skipped), fall back to identity
return A if A is not None else eye.clone() # [B,S,S]
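# Shape sanity-check sketch with dummy tensors (not real model attentions):
#   attns = tuple(torch.rand(2, 12, 16, 16).softmax(dim=-1) for _ in range(6))  # 6 layers, B=2, H=12, S=16
#   roll = attention_rollout(attns, last_k=4)   # -> [2, 16, 16]
#   cls_row = roll[0, 0]                        # CLS-to-token scores, batch item 0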
def merge_wordpieces(tokens, scores):
merged_tokens, merged_scores = [], []
cur_tok, cur_scores = "", []
for t, s in zip(tokens, scores):
if t.startswith("##"):
cur_tok += t[2:]
cur_scores.append(s)
else:
if cur_tok:
merged_tokens.append(cur_tok)
merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
cur_tok, cur_scores = t, [s]
if cur_tok:
merged_tokens.append(cur_tok)
merged_scores.append(sum(cur_scores) / max(1, len(cur_scores)))
return merged_tokens, merged_scores
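# Example (BERT-style wordpieces; a merged word's score is the mean of its pieces):
#   merge_wordpieces(["tachy", "##card", "##ia", "noted"], [0.4, 0.3, 0.2, 0.1])
#   # -> (["tachycardia", "noted"], [~0.3, 0.1])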
def _normalize_for_display_wordlevel(attn_scores, normalize_mode="visual", temperature=0.30):
"""
Convert raw *word-level* token scores into:
- probabilistic mode: probabilities that sum to 1.0 (100%), with labels like "0.237 | 23.7% (contrib)"
- visual mode: min-max + gamma scaling (contrast, not sum-to-100), with labels like "0.68 | visual score"
Returns:
attn_final: np.ndarray of floats in [0, 1] for color scale
labels: list[str] per token (tooltip text; first number stays up front for your color_map bucketing)
"""
attn_array = np.array(attn_scores, dtype=float)
if normalize_mode == "probabilistic":
        # ---- percentage view that sums to 100% ----
attn_array = np.maximum(attn_array, 0.0)
if attn_array.max() > 0:
attn_array = attn_array / (attn_array.max() + 1e-12) # scale to [0, 1] for stability
# sharpen (lower temp => peakier)
attn_array = np.power(attn_array + 1e-12, 1.0 / max(1e-6, float(temperature)))
prob = attn_array / (attn_array.sum() + 1e-12)
percent = prob * 100.0
# keep prob (0..1) for color scale; label with % contrib
labels = [f"{prob[i]:.3f} | {percent[i]:.1f}% (contrib)" for i in range(len(prob))]
return prob, labels
else:
        # ---- visual: min-max + gamma (contrast, not sum-to-100) ----
if attn_array.max() > attn_array.min():
attn_array0 = (attn_array - attn_array.min()) / (attn_array.max() - attn_array.min() + 1e-8)
attn_array0 = np.clip(np.power(attn_array0, 0.75), 0.1, 1.0)
else:
attn_array0 = np.zeros_like(attn_array)
labels = [f"{attn_array0[i]:.2f} | visual score" for i in range(len(attn_array0))]
return attn_array0, labels
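# Example (dummy scores): the probabilistic view returns values summing to 1.0,
# and temperature=0.30 sharpens the dominant token:
#   probs, labels = _normalize_for_display_wordlevel([0.1, 0.5, 0.2], "probabilistic")
#   # probs ≈ [0.004, 0.951, 0.045]; labels like "0.951 | 95.1% (contrib)"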
# ------------------
# Prediction
# ------------------
def predict(
model,
mode,
emr_text=None,
image=None,
normalize_mode="visual",
need_token_vis=False,
use_rollout=False
):
"""
normalize_mode: "visual" (min-max + gamma boost) or "probabilistic" (softmax)
need_token_vis: request/compute token-level attentions (Doctor mode + text/multimodal)
use_rollout: use attention rollout across layers
"""
input_ids = attention_mask = img_tensor = None
cam_image = None
highlighted_tokens = None
top5 = []
if mode in ["text", "multimodal"] and emr_text:
text_tokens = tokenizer(
emr_text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=128,
)
input_ids = text_tokens["input_ids"].to(DEVICE)
attention_mask = text_tokens["attention_mask"].to(DEVICE)
    if mode in ["image", "multimodal"] and image is not None:
img_tensor = image_transform(image).unsqueeze(0).to(DEVICE)
    # Only register hooks for Grad-CAM when needed
if mode in ["image", "multimodal"]:
activations, gradients, fwd_handle, bwd_handle = register_hooks(model)
model.zero_grad()
# === Forward ===
# Only enable attentions when planning to visualize them
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
image=img_tensor,
output_attentions=bool(need_token_vis and (mode in ["text", "multimodal"])),
return_raw_attentions=bool(use_rollout and need_token_vis)
)
logits = outputs["logits"]
if logits.numel() == 0:
raise ValueError("Model returned empty logits. Check input format.")
probs = torch.softmax(logits, dim=1)
pred = torch.argmax(probs, dim=1).item()
confidence = probs.squeeze()[pred].item()
# === Grad-CAM ===
if mode in ["image", "multimodal"]:
        # Backpropagate the predicted-class score to populate the Grad-CAM hooks
logits[0, pred].backward(retain_graph=True)
cam_image = generate_gradcam(image, activations, gradients)
fwd_handle.remove()
bwd_handle.remove()
# === Token-level attention ===
if need_token_vis and (mode in ["text", "multimodal"]):
token_attn_scores = None
if use_rollout and outputs.get("raw_attentions") is not None:
            # Partial rollout over the last 4 layers; roll[b, 0, :] is CLS-to-all-tokens for batch item b
            roll = attention_rollout(outputs["raw_attentions"], last_k=4, residual_alpha=0.5)  # [B, S, S]
if roll is not None:
# roll: [B, S, S]; pick CLS row (index 0)
cls_to_tokens = roll[0, 0].detach().cpu().numpy().tolist() # CLS row
token_attn_scores = cls_to_tokens
elif outputs.get("token_attentions") is not None:
token_attn_scores = outputs["token_attentions"].squeeze().tolist()
if token_attn_scores is not None:
            # Filter out specials/pad + align to wordpieces
ids = input_ids[0].tolist()
amask = attention_mask[0].tolist() if attention_mask is not None else [1] * len(ids)
wp_all = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
special_ids = set(tokenizer.all_special_ids)
keep_idx = [i for i, (tid, m) in enumerate(zip(ids, amask)) if (tid not in special_ids) and (m == 1)]
wp_tokens = [wp_all[i] for i in keep_idx]
wp_scores = [token_attn_scores[i] if i < len(token_attn_scores) else 0.0 for i in keep_idx]
# Merge wordpieces into words
word_tokens, attn_scores = merge_wordpieces(wp_tokens, wp_scores)
# Build Top-5 (probabilistic normalization for ranking)
_probs_for_rank, _ = _normalize_for_display_wordlevel(
attn_scores, normalize_mode="probabilistic", temperature=0.30
)
pairs = list(zip(word_tokens, _probs_for_rank))
pairs.sort(key=lambda x: x[1], reverse=True)
top5 = [(tok, float(p * 100.0)) for tok, p in pairs[:5]]
# Final display (probabilistic or visual)
attn_final, labels = _normalize_for_display_wordlevel(
attn_scores,
normalize_mode=normalize_mode,
temperature=0.30,
)
highlighted_tokens = [(tok, labels[i]) for i, tok in enumerate(word_tokens)]
print("🧪 Normalization Mode Received:", normalize_mode)
if highlighted_tokens:
print("🟣 Highlighted tokens sample:", highlighted_tokens[:5])
else:
print("🟣 No highlighted tokens (no text or attentions unavailable).")
return inv_map[pred], cam_image, highlighted_tokens, confidence, probs.tolist(), top5
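# Minimal smoke-test sketch. Assumes network access (tokenizer + weights) and a
# "text" section in config/config.yaml; the EMR string below is made up.
if __name__ == "__main__":
    demo_model = load_model("text")
    label, _, tokens, conf, _, top5 = predict(
        demo_model,
        "text",
        emr_text="65yo male, chest pain radiating to left arm, elevated troponin.",
        need_token_vis=True,
    )
    print(f"Prediction: {label} (confidence {conf:.2f})")
    print("Top-5 tokens:", top5)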