# PHI3 / app.py
# Source: Hugging Face Space by johnnydang88 — commit fc72d25 ("Update app.py")
"""
Cardiology AI Assistant — Microsoft Phi-3-Mini-4k-Instruct
Hugging Face ZeroGPU Space
Includes: BERTScore F1, ROUGE-N, Semantic Similarity, Faithfulness, Answer Relevance, Context Recall
Same metric stack as the Llama-3 version — all fixes applied:
• SentenceTransformer forced to CPU (prevents stale CUDA zero-vector bug)
• ROUGE uses precision (overlap / answer_ngrams), not recall vs huge context
• Context capped at 60 sentences before embedding (prevents OOM)
• Per-metric try/except so one failure never kills the whole panel
"""
import os, gc, re, torch, warnings, pdfplumber
import numpy as np
import spaces
from collections import Counter
from typing import List, Dict
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from sentence_transformers import CrossEncoder, SentenceTransformer
import gradio as gr
# Suppress all library warnings to keep the Space logs readable.
warnings.filterwarnings("ignore")
# Guideline PDF shipped with the Space; parsed and indexed once at startup.
PDF_PATH = "./2024ESC-compressed.pdf"
# ══════════════════════════════════════════════════════════════════════════════
# PDF LOADER
# ══════════════════════════════════════════════════════════════════════════════
def load_pdf_with_tables(path):
    """
    Read a PDF into one LangChain Document per page.

    Each page's plain text is followed by any tables pdfplumber extracts,
    rendered as markdown-ish pipe-delimited rows under a **[TABLE]** marker,
    so the table content survives chunking and retrieval. Page numbers are
    stored 1-based in metadata["page"].
    """
    print(f"📂 Loading {path}...", flush=True)
    documents = []
    with pdfplumber.open(path) as pdf:
        page_count = len(pdf.pages)
        for idx, page in enumerate(pdf.pages):
            if idx % 10 == 0:  # periodic progress line for the startup log
                print(f" Page {idx}/{page_count}...", flush=True)
            table_block = ""
            for tbl in (page.extract_tables() or []):
                if not tbl:
                    continue
                table_block += "\n\n**[TABLE]**\n"
                for row in tbl:
                    # Falsy cells (None/empty) become "", newlines are flattened.
                    cells = [str(cell).replace("\n", " ") if cell else "" for cell in row]
                    table_block += "| " + " | ".join(cells) + " |\n"
            body = page.extract_text() or ""
            documents.append(
                Document(page_content=body + "\n" + table_block, metadata={"page": idx + 1})
            )
    return documents
# ══════════════════════════════════════════════════════════════════════════════
# MEDCPT EMBEDDINGS (CPU)
# ══════════════════════════════════════════════════════════════════════════════
class MedCPTEmbeddings(Embeddings):
    """
    LangChain Embeddings adapter around the NCBI MedCPT dual encoders.

    Runs entirely on CPU. Documents go through the article encoder,
    queries through the query encoder; both use the CLS-token vector.
    """

    _BATCH_SIZE = 16  # small batches keep peak CPU RAM modest

    def __init__(self):
        print("⚙️ Initializing MedCPT on CPU...", flush=True)
        self.article_tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
        self.article_model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")
        self.query_tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
        self.query_model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")

    def embed_documents(self, texts):
        """Embed corpus chunks with the article encoder, 16 at a time."""
        vectors = []
        for start in range(0, len(texts), self._BATCH_SIZE):
            chunk = texts[start:start + self._BATCH_SIZE]
            with torch.no_grad():
                tokens = self.article_tokenizer(
                    chunk, max_length=512, padding=True, truncation=True, return_tensors="pt"
                )
                hidden = self.article_model(**tokens).last_hidden_state
                # CLS (position 0) embedding per input, as plain Python lists.
                vectors.extend(hidden[:, 0, :].tolist())
        return vectors

    def embed_query(self, text):
        """Embed one query string with the query encoder (CLS vector)."""
        with torch.no_grad():
            tokens = self.query_tokenizer(
                [text], max_length=512, padding=True, truncation=True, return_tensors="pt"
            )
            hidden = self.query_model(**tokens).last_hidden_state
        return hidden[0, 0, :].tolist()

    def free_article_encoder(self):
        """Drop the article encoder after indexing to reclaim memory;
        only the query encoder is needed at serving time."""
        del self.article_model, self.article_tokenizer
        gc.collect()
# ══════════════════════════════════════════════════════════════════════════════
# STARTUP
# ══════════════════════════════════════════════════════════════════════════════
# ── One-time startup: build the RAG index and load every model on CPU ─────────
print("📂 Loading PDF...", flush=True)
raw_docs = load_pdf_with_tables(PDF_PATH)
print("✂️ Splitting...", flush=True)
# 512-char chunks with 64-char overlap so context is preserved across chunk edges.
splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
chunks = splitter.split_documents(raw_docs)
print("🧠 Building MedCPT vector store (CPU)...", flush=True)
emb = MedCPTEmbeddings()
vectorstore = FAISS.from_documents(chunks, emb)
# The article encoder is only needed while building the index — free it now.
emb.free_article_encoder()
print("✅ Vector store ready.", flush=True)
print("⚖️ Loading CrossEncoder (CPU)...", flush=True)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
# Explicitly load on CPU — after ZeroGPU releases the GPU, auto-device detection
# can latch onto a stale CUDA context and silently return zero vectors.
print("📐 Loading metrics SentenceTransformer (CPU)...", flush=True)
metrics_st = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
print("✅ Metrics encoder ready.", flush=True)
print("🚀 Loading Phi-3-Mini in float16 (CPU)...", flush=True)
MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=False)
# Reuse EOS as the pad token; left padding for decoder-only generation.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=False,
)
model.eval()
print("✅ Phi-3 ready (CPU). GPU borrowed per request via ZeroGPU.", flush=True)
# ══════════════════════════════════════════════════════════════════════════════
# RERANKER
# ══════════════════════════════════════════════════════════════════════════════
def rerank_docs(query: str, docs):
    """Score every candidate document against the query with the CrossEncoder.

    Returns one relevance score per document, in input order.
    """
    pairs = [[query, doc.page_content] for doc in docs]
    return reranker.predict(pairs)
# ══════════════════════════════════════════════════════════════════════════════
# EVALUATION METRICS
# All reference-free — uses retrieved context + query as the reference signal.
# Identical implementation to the Llama-3 version for consistency.
# ══════════════════════════════════════════════════════════════════════════════
def _sent_tokenize(text: str) -> List[str]:
"""Lightweight sentence splitter — no NLTK required."""
sents = re.split(r'(?<=[.!?])\s+', text.strip())
return [s.strip() for s in sents if len(s.strip()) > 10]
def _encode(texts: List[str]) -> np.ndarray:
    """
    Embed *texts* with the metrics SentenceTransformer, pinned to CPU.

    The device is forced to "cpu" on purpose: once ZeroGPU hands the GPU
    back, SentenceTransformer's automatic device pick can bind to a dead
    CUDA context and quietly emit all-zero vectors. CPU is always correct.
    Embeddings are L2-normalised, so dot products are cosine similarities.
    """
    encode_kwargs = dict(
        normalize_embeddings=True,
        show_progress_bar=False,
        device="cpu",
        convert_to_numpy=True,
    )
    return metrics_st.encode(texts, **encode_kwargs)
def _ngrams(tokens: List[str], n: int) -> Counter:
return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))
def rouge_n(hypothesis: str, reference: str, n: int = 1) -> float:
"""
ROUGE-N precision: fraction of answer n-grams that appear in the context.
Using precision (not recall) because the context is ~6,000+ tokens — recall
of a ~60-token answer against that pool is always ~4% even for correct answers.
"""
hyp_tokens = hypothesis.lower().split()
ref_tokens = reference.lower().split()
hyp_ng = _ngrams(hyp_tokens, n)
ref_ng = _ngrams(ref_tokens, n)
overlap = sum((hyp_ng & ref_ng).values())
denom = sum(hyp_ng.values()) # precision: denominator = answer n-grams
return round(overlap / denom, 4) if denom > 0 else 0.0
def bertscore_f1(answer: str, context_sents: List[str]) -> float:
    """
    Sentence-embedding approximation of BERTScore F1.

    Precision = mean over answer sentences of their best cosine match in
    the context; Recall = the symmetric quantity over context sentences;
    F1 is their harmonic mean (with an epsilon against division by zero).
    *context_sents* is pre-tokenised and capped upstream so we never embed
    100+ sentences here.
    """
    answer_sents = _sent_tokenize(answer)
    if not (answer_sents and context_sents):
        return 0.0
    try:
        sim = _encode(answer_sents) @ _encode(context_sents).T
        precision = float(sim.max(axis=1).mean())
        recall = float(sim.max(axis=0).mean())
        f1 = (2 * precision * recall) / (precision + recall + 1e-9)
        return round(max(f1, 0.0), 4)
    except Exception as exc:
        # One failing metric must never take down the whole panel.
        print(f"⚠️ bertscore_f1 error: {exc}", flush=True)
        return 0.0
def semantic_similarity(answer: str, query: str) -> float:
    """Cosine similarity of the (normalised) answer and query embeddings,
    clamped at zero."""
    try:
        vecs = _encode([answer, query])
        cosine = float(np.dot(vecs[0], vecs[1]))
        return round(max(cosine, 0.0), 4)
    except Exception as exc:
        # Fail soft so one broken metric never blanks the whole panel.
        print(f"⚠️ semantic_similarity error: {exc}", flush=True)
        return 0.0
def faithfulness(answer: str, context_sents: List[str], threshold: float = 0.35) -> float:
    """
    Share of answer sentences whose best cosine match against any context
    sentence reaches *threshold*. 0.35 (rather than 0.40) keeps paraphrased
    but still grounded sentences in the "faithful" bucket.
    """
    answer_sents = _sent_tokenize(answer)
    if not (answer_sents and context_sents):
        return 0.0
    try:
        sim = _encode(answer_sents) @ _encode(context_sents).T
        best_match = sim.max(axis=1)  # best context match per answer sentence
        supported = int((best_match >= threshold).sum())
        return round(supported / len(answer_sents), 4)
    except Exception as exc:
        print(f"⚠️ faithfulness error: {exc}", flush=True)
        return 0.0
def answer_relevance(answer: str, query: str) -> float:
    """Alias for semantic similarity: measures whether the answer actually
    addresses the question that was asked."""
    return semantic_similarity(answer, query)
def context_recall(answer: str, context_sents: List[str], threshold: float = 0.35) -> float:
    """
    Share of context sentences reflected somewhere in the answer — a
    label-free analogue of RAGAS Context Recall (no ground truth needed).
    """
    answer_sents = _sent_tokenize(answer)
    if not (answer_sents and context_sents):
        return 0.0
    try:
        sim = _encode(answer_sents) @ _encode(context_sents).T
        best_per_ctx = sim.max(axis=0)  # best answer match per context sentence
        covered = int((best_per_ctx >= threshold).sum())
        return round(covered / len(context_sents), 4)
    except Exception as exc:
        print(f"⚠️ context_recall error: {exc}", flush=True)
        return 0.0
def compute_all_metrics(query: str, answer: str, context: str) -> Dict[str, float]:
    """
    Run the full metric panel for one Q/A turn.

    The context is sentence-tokenised once and capped at 60 sentences
    (top-ranked chunks come first), and all embedding-based metrics share
    that capped list. ROUGE runs on the raw context string instead — it is
    pure token overlap with no similarity matrices to size-limit.
    """
    all_ctx = _sent_tokenize(context)
    capped_ctx = all_ctx[:60]
    print(f"📐 Metrics: answer={len(_sent_tokenize(answer))} sents, "
          f"ctx={len(capped_ctx)}/{len(all_ctx)} sents", flush=True)
    panel: Dict[str, float] = {}
    panel["BERTScore F1"] = bertscore_f1(answer, capped_ctx)
    panel["ROUGE-1"] = rouge_n(answer, context, n=1)
    panel["ROUGE-2"] = rouge_n(answer, context, n=2)
    panel["Semantic Similarity"] = semantic_similarity(answer, query)
    panel["Faithfulness"] = faithfulness(answer, capped_ctx)
    panel["Answer Relevance"] = answer_relevance(answer, query)
    panel["Context Recall"] = context_recall(answer, capped_ctx)
    return panel
# ── Display helpers ───────────────────────────────────────────────────────────
# ── Display helpers ───────────────────────────────────────────────────────────
# One-line explanations shown in the "Notes" column of the metrics table;
# keys must match the keys produced by compute_all_metrics().
_METRIC_DESCRIPTIONS = {
    "BERTScore F1": "Sentence-level semantic overlap F1 between answer and top context sentences.",
    "ROUGE-1": "Fraction of answer unigrams found in retrieved context (precision).",
    "ROUGE-2": "Fraction of answer bigrams found in retrieved context (precision).",
    "Semantic Similarity": "Cosine similarity between answer and question embeddings.",
    "Faithfulness": "Fraction of answer sentences semantically supported by the retrieved context.",
    "Answer Relevance": "How directly the answer addresses the original question.",
    "Context Recall": "Fraction of top context sentences reflected in the answer.",
}
_THRESHOLDS = {
# (warn_below, ok_below, good_above)
"BERTScore F1": (0.50, 0.65, 0.80),
"ROUGE-1": (0.15, 0.30, 0.45),
"ROUGE-2": (0.05, 0.15, 0.25),
"Semantic Similarity": (0.40, 0.60, 0.75),
"Faithfulness": (0.50, 0.70, 0.85),
"Answer Relevance": (0.40, 0.60, 0.75),
"Context Recall": (0.15, 0.30, 0.50),
}
def _colour(name: str, value: float) -> str:
warn, ok, good = _THRESHOLDS.get(name, (0.3, 0.6, 0.8))
if value >= good: return "🟢"
if value >= ok: return "🟡"
return "🔴"
def _bar(value: float, width: int = 20) -> str:
filled = int(round(value * width))
return "█" * filled + "░" * (width - filled)
def format_metrics_markdown(metrics: Dict[str, float]) -> str:
    """Render the metric dict as a markdown table with progress bars,
    traffic-light status icons, and per-metric descriptions."""
    out = [
        "## 📊 Evaluation Metrics\n",
        "> Metrics are **reference-free** and computed against the retrieved context "
        "and original query — no labelled ground truth required.\n",
        "| Metric | Score | Bar | Status | Notes |",
        "|--------|------:|-----|--------|-------|",
    ]
    for metric, score in metrics.items():
        out.append(
            f"| **{metric}** | {score:.2%} | `{_bar(score)}` | "
            f"{_colour(metric, score)} | {_METRIC_DESCRIPTIONS.get(metric, '')} |"
        )
    out.append("\n**Colour key:** 🟢 Good · 🟡 Acceptable · 🔴 Needs attention")
    return "\n".join(out)
# ══════════════════════════════════════════════════════════════════════════════
# GPU FUNCTION
# ══════════════════════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def llm_generate(messages: list) -> str:
    """
    Generate a chat completion on a borrowed ZeroGPU device.

    The model lives on CPU between requests and is moved to CUDA only for
    the duration of this call. The GPU section is wrapped in try/finally:
    if tokenization or generation raises, the model is still moved back to
    CPU and the CUDA cache emptied, so a failed request cannot leave stale
    CUDA tensors behind to break the next one (the original code skipped
    the cleanup on any exception).

    Args:
        messages: chat messages in the tokenizer's chat-template format.

    Returns:
        The decoded assistant response with prompt tokens stripped.
    """
    print("🔥 GPU acquired, running generation...", flush=True)
    model.to("cuda")
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=300,
                do_sample=False,  # deterministic, guideline-faithful decoding
                repetition_penalty=1.05,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(
            output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
    finally:
        # Always park the model back on CPU and release CUDA memory,
        # even when generation fails mid-way.
        model.to("cpu")
        torch.cuda.empty_cache()
    print("✅ Generation complete.", flush=True)
    return response
# ══════════════════════════════════════════════════════════════════════════════
# RAG PIPELINE
# ══════════════════════════════════════════════════════════════════════════════
def process_query_stream(query: str):
    """
    Full RAG pipeline as a generator.

    Every yield is an (answer_markdown, metrics_markdown) pair driving the
    two Gradio output panes; intermediate yields show status banners while
    retrieval, reranking, generation, and metric computation run.
    """
    # Step 1 — dense retrieval, then de-duplicate hits by page number.
    yield (
        "⏳ **Status:** 🔍 Retrieving relevant documents from VectorDB...\n\n---\n",
        ""
    )
    candidates = vectorstore.similarity_search(query, k=60)
    seen_pages = set()
    deduped = []
    for doc in candidates:
        page = doc.metadata.get("page")
        if page not in seen_pages:
            seen_pages.add(page)
            deduped.append(doc)
    # Step 2 — CrossEncoder rerank; keep the five best pages as context.
    yield (
        "⏳ **Status:** 📊 Reranking with CrossEncoder (CPU)...\n\n---\n",
        ""
    )
    rank_scores = rerank_docs(query, deduped)
    ranked = sorted(zip(deduped, rank_scores), key=lambda pair: pair[1], reverse=True)
    top_docs = [doc for doc, _ in ranked[:5]]
    context = "\n\n".join(doc.page_content for doc in top_docs)
    source_pages = ", ".join(str(doc.metadata.get("page")) for doc in top_docs)
    # Step 3 — grounded generation on the borrowed ZeroGPU device.
    yield (
        "⏳ **Status:** 🧠 Synthesizing with Phi-3 (ZeroGPU H200)...\n\n---\n",
        ""
    )
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict clinical assistant for the ESC Guidelines. Answer based ONLY on the context. "
                "CRITICAL RULES:\n"
                "1. Define scoring systems strictly.\n"
                "2. State clearly if a treatment is Class III.\n"
                "3. Do not hallucinate acronym meanings.\n"
                "4. If the prompt is not medical, output: 'This model only answers to medical questions.'"
            ),
        },
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{query}"},
    ]
    answer = llm_generate(messages)
    answer_md = f"### ⚕️ Answer\n\n{answer}\n\n📄 **Source Pages:** {source_pages}\n"
    # Step 4 — reference-free evaluation metrics (shown under the answer).
    yield (
        answer_md,
        "⏳ **Status:** 📐 Computing evaluation metrics (CPU)...\n"
    )
    yield (answer_md, format_metrics_markdown(compute_all_metrics(query, answer, context)))
# ══════════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════════
def gradio_wrapper(query):
    """Validate the textbox input, then delegate to the streaming RAG
    pipeline; blank/whitespace-only input yields a single warning pair."""
    if query and query.strip():
        yield from process_query_stream(query)
    else:
        yield "⚠️ Please enter a valid question.", ""
# Teal/blue soft theme; primary button darkens one shade on hover.
phi_theme = gr.themes.Soft(
    primary_hue="teal",
    secondary_hue="blue",
    neutral_hue="slate",
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
)
# Declarative Gradio UI: header, question input, examples, answer pane,
# metrics pane, and an explanatory accordion; wired to gradio_wrapper.
with gr.Blocks(theme=phi_theme, title="Cardiology AI Assistant") as demo:
    # ── Header ───────────────────────────────────────────────────────────────
    gr.Markdown("# ⚕️ Cardiology AI Assistant (ESC 2024)")
    gr.Markdown("### ⚡ Powered by Microsoft Phi-3-Mini · ZeroGPU H200")
    gr.Markdown(
        "Ask questions based on the **2024 ESC Medical Guidelines**. "
        "Uses RAG with MedCPT embeddings, Cross-Encoder reranking, Phi-3 generation, "
        "and **live evaluation metrics**."
    )
    # ── Input: question textbox (wide) + submit button (narrow) ─────────────
    with gr.Row():
        with gr.Column(scale=4):
            input_text = gr.Textbox(
                label="Your Clinical Question",
                placeholder="e.g., What are the class I recommendations for anticoagulation in AF?",
                lines=3,
            )
        with gr.Column(scale=1, min_width=160):
            submit_btn = gr.Button("🔍 Analyze Guidelines", variant="primary", size="lg")
    # ── Clickable example questions that populate the textbox ───────────────
    gr.Examples(
        examples=[
            "What are the class I recommendations for anticoagulation in AF?",
            "Summarize the treatment algorithm for chronic heart failure.",
            "What is the target LDL-C for very high-risk patients?",
        ],
        inputs=input_text,
        label="Example Questions",
    )
    gr.Markdown("---")
    # ── Answer output (full width) ────────────────────────────────────────
    answer_output = gr.Markdown(
        label="Assistant Response",
        value="*Your answer will appear here after submission.*",
    )
    gr.Markdown("---")
    # ── Metrics output (full width, below answer) ─────────────────────────
    metrics_output = gr.Markdown(
        label="Evaluation Metrics",
        value="*Metrics will appear here once the answer is generated.*",
    )
    gr.Markdown("---")
    # ── Collapsed legend explaining how each metric is computed ───────────
    with gr.Accordion("ℹ️ About the Evaluation Metrics", open=False):
        gr.Markdown("""
### How each metric is computed
| Metric | Method | Interpretation |
|--------|--------|---------------|
| **BERTScore F1** | Sentence-level cosine-sim F1 between answer sentences and top-60 context sentences using `all-MiniLM-L6-v2` (forced CPU) | Measures how semantically similar the answer is to the source context |
| **ROUGE-1** | **Precision**: fraction of answer unigrams that appear in the retrieved context | Are the words the model used actually in the retrieved passages? |
| **ROUGE-2** | **Precision**: fraction of answer bigrams that appear in the retrieved context | Are the phrases the model used actually in the retrieved passages? |
| **Semantic Similarity** | Cosine similarity of full answer ↔ question embeddings | Does the answer embed in the same semantic space as the question? |
| **Faithfulness** | Fraction of answer sentences with cosine-sim ≥ 0.35 to any context sentence | Are answer claims grounded in retrieved text? |
| **Answer Relevance** | Cosine similarity of answer ↔ question embeddings | How directly does the answer respond to the question? |
| **Context Recall** | Fraction of top-60 context sentences with cosine-sim ≥ 0.35 to any answer sentence | How much of the retrieved evidence is used in the answer? |
> **Why precision for ROUGE?** The retrieved context is ~6,000 tokens; a correct ~60-token answer
> has only ~4% unigram *recall* against that pool — even if every word came from the context.
> Precision asks the right question: *"Did the model use words that actually appear in the retrieved passages?"*
> **All metrics are reference-free** — they use the retrieved context and original query as the
> reference signal, so no annotated ground-truth is needed.
""")
    # ── Wire the button to the streaming generator (two markdown outputs) ──
    submit_btn.click(
        fn=gradio_wrapper,
        inputs=input_text,
        outputs=[answer_output, metrics_output],
    )
# queue() enables generator streaming; bind to all interfaces on the Space port.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)