"""
Cardiology AI Assistant — Microsoft Phi-3-Mini-4k-Instruct
Hugging Face ZeroGPU Space

Includes: BERTScore F1, ROUGE-N, Semantic Similarity, Faithfulness, Answer Relevance, Context Recall

Same metric stack as the Llama-3 version — all fixes applied:
  • SentenceTransformer forced to CPU (prevents stale CUDA zero-vector bug)
  • ROUGE uses precision (overlap / answer_ngrams), not recall vs huge context
  • Context capped at 60 sentences before embedding (prevents OOM)
  • Per-metric try/except so one failure never kills the whole panel
"""
| import os, gc, re, torch, warnings, pdfplumber | |
| import numpy as np | |
| import spaces | |
| from collections import Counter | |
| from typing import List, Dict | |
| from langchain_core.documents import Document | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_core.embeddings import Embeddings | |
| from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM | |
| from sentence_transformers import CrossEncoder, SentenceTransformer | |
| import gradio as gr | |
| warnings.filterwarnings("ignore") | |
# Path to the bundled 2024 ESC guidelines PDF that seeds the vector store.
PDF_PATH = "./2024ESC-compressed.pdf"
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # PDF LOADER | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def load_pdf_with_tables(path):
    """Parse a PDF into one LangChain Document per page.

    Extracted tables are rendered as pipe-separated rows appended after the
    page text (under a ``**[TABLE]**`` marker) so tabular guideline data
    survives chunking. Page numbers (1-based) go into ``metadata["page"]``.
    """
    print(f"📂 Loading {path}...", flush=True)
    documents = []
    with pdfplumber.open(path) as pdf:
        page_count = len(pdf.pages)
        for index, page in enumerate(pdf.pages):
            if index % 10 == 0:  # progress heartbeat every 10 pages
                print(f"  Page {index}/{page_count}...", flush=True)
            rendered_tables = ""
            for table in (page.extract_tables() or []):
                if not table:
                    continue
                rendered_tables += "\n\n**[TABLE]**\n"
                for row in table:
                    # Falsy cells (None/empty) become "", newlines flattened to spaces.
                    cells = [str(cell).replace("\n", " ") if cell else "" for cell in row]
                    rendered_tables += "| " + " | ".join(cells) + " |\n"
            body = page.extract_text() or ""
            documents.append(
                Document(page_content=body + "\n" + rendered_tables, metadata={"page": index + 1})
            )
    return documents
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # MEDCPT EMBEDDINGS (CPU) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class MedCPTEmbeddings(Embeddings):
    """LangChain Embeddings adapter around the NCBI MedCPT dual encoders (CPU).

    Documents are embedded with the Article encoder and queries with the
    Query encoder; each embedding is the [CLS] token of the last hidden state.
    """

    def __init__(self, batch_size: int = 16):
        """Load both MedCPT encoders on CPU.

        Args:
            batch_size: Documents encoded per forward pass in
                ``embed_documents``. Default 16 (the previous hard-coded
                value) keeps peak CPU memory bounded.
        """
        print("⚙️ Initializing MedCPT on CPU...", flush=True)
        self.batch_size = batch_size
        self.article_tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
        self.article_model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")
        self.query_tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
        self.query_model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")

    def embed_documents(self, texts):
        """Embed *texts* with the Article encoder, in ``batch_size`` batches."""
        all_embeddings = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start: start + self.batch_size]
            with torch.no_grad():
                enc = self.article_tokenizer(
                    batch, max_length=512, padding=True, truncation=True, return_tensors="pt"
                )
                out = self.article_model(**enc)
                # [CLS] embedding per document.
                all_embeddings.extend(out.last_hidden_state[:, 0, :].tolist())
        return all_embeddings

    def embed_query(self, text):
        """Embed a single query string with the Query encoder."""
        with torch.no_grad():
            enc = self.query_tokenizer(
                [text], max_length=512, padding=True, truncation=True, return_tensors="pt"
            )
            out = self.query_model(**enc)
            return out.last_hidden_state[:, 0, :][0].tolist()

    def free_article_encoder(self):
        """Release the Article encoder once the vector store is built (saves RAM)."""
        del self.article_model, self.article_tokenizer
        gc.collect()
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # STARTUP | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
# ── Startup: build the RAG index and load every model eagerly at import time ──
print("📂 Loading PDF...", flush=True)
raw_docs = load_pdf_with_tables(PDF_PATH)
print("✂️ Splitting...", flush=True)
# 512-char chunks with 64-char overlap to keep sentences/table rows intact.
splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
chunks = splitter.split_documents(raw_docs)
print("🧠 Building MedCPT vector store (CPU)...", flush=True)
emb = MedCPTEmbeddings()
vectorstore = FAISS.from_documents(chunks, emb)
# The article encoder is only needed at index-build time; free it to reclaim RAM.
emb.free_article_encoder()
print("✅ Vector store ready.", flush=True)
print("⚖️ Loading CrossEncoder (CPU)...", flush=True)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
# Explicitly load on CPU — after ZeroGPU releases the GPU, auto-device detection
# can latch onto a stale CUDA context and silently return zero vectors.
print("📐 Loading metrics SentenceTransformer (CPU)...", flush=True)
metrics_st = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
print("✅ Metrics encoder ready.", flush=True)
print("🚀 Loading Phi-3-Mini in float16 (CPU)...", flush=True)
MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=False)
# Phi-3 ships without a pad token: reuse EOS, and left-pad for decoder-only generation.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # fp16 weights halve RAM; moved to GPU per request
    trust_remote_code=False,
)
model.eval()
print("✅ Phi-3 ready (CPU). GPU borrowed per request via ZeroGPU.", flush=True)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # RERANKER | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def rerank_docs(query: str, docs):
    """Score each document's relevance to *query* with the CrossEncoder reranker."""
    pairs = [[query, doc.page_content] for doc in docs]
    return reranker.predict(pairs)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # EVALUATION METRICS | |
| # All reference-free — uses retrieved context + query as the reference signal. | |
| # Identical implementation to the Llama-3 version for consistency. | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| def _sent_tokenize(text: str) -> List[str]: | |
| """Lightweight sentence splitter — no NLTK required.""" | |
| sents = re.split(r'(?<=[.!?])\s+', text.strip()) | |
| return [s.strip() for s in sents if len(s.strip()) > 10] | |
def _encode(texts: List[str]) -> np.ndarray:
    """Return L2-normalised embeddings for *texts* as a numpy array, on CPU.

    The device is pinned explicitly: after ZeroGPU releases the GPU,
    SentenceTransformer's auto-detection can latch onto a stale CUDA context
    and silently return zero vectors. Forcing CPU guarantees correct output.
    """
    encode_options = dict(
        device="cpu",
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=False,
    )
    return metrics_st.encode(texts, **encode_options)
| def _ngrams(tokens: List[str], n: int) -> Counter: | |
| return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)) | |
def rouge_n(hypothesis: str, reference: str, n: int = 1) -> float:
    """ROUGE-N *precision*: the share of answer n-grams found in the context.

    Precision (not recall) is deliberate: the retrieved context is thousands
    of tokens, so recall of a short answer against it is near-zero even when
    the answer is entirely correct.
    """
    def grams(text: str) -> Counter:
        words = text.lower().split()
        return Counter(tuple(words[i:i + n]) for i in range(len(words) - n + 1))

    answer_grams = grams(hypothesis)
    total = sum(answer_grams.values())  # precision denominator: answer n-grams
    if total == 0:
        return 0.0
    matched = sum((answer_grams & grams(reference)).values())
    return round(matched / total, 4)
def bertscore_f1(answer: str, context_sents: List[str]) -> float:
    """Approximate BERTScore F1 via sentence-level embeddings.

    Precision: each answer sentence's best cosine match among context sentences.
    Recall: each context sentence's best cosine match among answer sentences.
    F1 is their harmonic mean. Context sentences arrive pre-tokenised and
    capped upstream to bound encoding cost. Returns 0.0 on empty input or
    encoder failure.
    """
    answer_sents = _sent_tokenize(answer)
    if not answer_sents or not context_sents:
        return 0.0
    try:
        similarity = _encode(answer_sents) @ _encode(context_sents).T
        precision = float(similarity.max(axis=1).mean())
        recall = float(similarity.max(axis=0).mean())
        f1 = 2 * precision * recall / (precision + recall + 1e-9)
        return round(max(f1, 0.0), 4)
    except Exception as exc:
        print(f"⚠️ bertscore_f1 error: {exc}", flush=True)
        return 0.0
def semantic_similarity(answer: str, query: str) -> float:
    """Cosine similarity between answer and query embeddings, clamped at 0."""
    try:
        answer_vec, query_vec = _encode([answer, query])
        return round(max(float(answer_vec @ query_vec), 0.0), 4)
    except Exception as exc:
        print(f"⚠️ semantic_similarity error: {exc}", flush=True)
        return 0.0
def faithfulness(answer: str, context_sents: List[str], threshold: float = 0.35) -> float:
    """Share of answer sentences grounded in the retrieved context.

    A sentence counts as grounded when its best cosine similarity against any
    context sentence reaches *threshold* (0.35 rather than 0.40 so that
    paraphrased but still supported claims are counted).
    """
    answer_sents = _sent_tokenize(answer)
    if not answer_sents or not context_sents:
        return 0.0
    try:
        similarity = _encode(answer_sents) @ _encode(context_sents).T
        best_per_answer = similarity.max(axis=1)
        grounded = int((best_per_answer >= threshold).sum())
        return round(grounded / len(answer_sents), 4)
    except Exception as exc:
        print(f"⚠️ faithfulness error: {exc}", flush=True)
        return 0.0
def answer_relevance(answer: str, query: str) -> float:
    """Proxy for answer relevance — does the answer address what was asked?

    Implemented as the cosine similarity between answer and query embeddings
    (delegates to ``semantic_similarity``).
    """
    return semantic_similarity(answer, query)
def context_recall(answer: str, context_sents: List[str], threshold: float = 0.35) -> float:
    """Share of context sentences that the answer reflects.

    Mirrors RAGAS Context Recall, but reference-free: a context sentence is
    "recalled" when some answer sentence is within cosine *threshold* of it.
    """
    answer_sents = _sent_tokenize(answer)
    if not answer_sents or not context_sents:
        return 0.0
    try:
        similarity = _encode(answer_sents) @ _encode(context_sents).T
        best_per_context = similarity.max(axis=0)
        recalled = int((best_per_context >= threshold).sum())
        return round(recalled / len(context_sents), 4)
    except Exception as exc:
        print(f"⚠️ context_recall error: {exc}", flush=True)
        return 0.0
def compute_all_metrics(query: str, answer: str, context: str) -> Dict[str, float]:
    """Run the full reference-free metric panel for one question/answer turn.

    The context is sentence-tokenised once and capped at 60 sentences (the
    top-ranked chunks come first) for all embedding-based metrics. ROUGE
    operates on the raw context string — pure token overlap, no matrices.
    """
    all_context_sents = _sent_tokenize(context)
    capped_sents = all_context_sents[:60]
    print(f"📐 Metrics: answer={len(_sent_tokenize(answer))} sents, "
          f"ctx={len(capped_sents)}/{len(all_context_sents)} sents", flush=True)
    panel: Dict[str, float] = {}
    panel["BERTScore F1"] = bertscore_f1(answer, capped_sents)
    panel["ROUGE-1"] = rouge_n(answer, context, n=1)
    panel["ROUGE-2"] = rouge_n(answer, context, n=2)
    panel["Semantic Similarity"] = semantic_similarity(answer, query)
    panel["Faithfulness"] = faithfulness(answer, capped_sents)
    panel["Answer Relevance"] = answer_relevance(answer, query)
    panel["Context Recall"] = context_recall(answer, capped_sents)
    return panel
# ── Display helpers ───────────────────────────────────────────────────────────
# One-line explanation shown in the "Notes" column of the results table.
_METRIC_DESCRIPTIONS = {
    "BERTScore F1": "Sentence-level semantic overlap F1 between answer and top context sentences.",
    "ROUGE-1": "Fraction of answer unigrams found in retrieved context (precision).",
    "ROUGE-2": "Fraction of answer bigrams found in retrieved context (precision).",
    "Semantic Similarity": "Cosine similarity between answer and question embeddings.",
    "Faithfulness": "Fraction of answer sentences semantically supported by the retrieved context.",
    "Answer Relevance": "How directly the answer addresses the original question.",
    "Context Recall": "Fraction of top context sentences reflected in the answer.",
}
# Colour bands used by _colour: value ≥ good → 🟢, value ≥ ok → 🟡, else 🔴.
# NOTE(review): the first ("warn") element is not currently used by _colour.
_THRESHOLDS = {
    # (warn_below, ok_below, good_above)
    "BERTScore F1": (0.50, 0.65, 0.80),
    "ROUGE-1": (0.15, 0.30, 0.45),
    "ROUGE-2": (0.05, 0.15, 0.25),
    "Semantic Similarity": (0.40, 0.60, 0.75),
    "Faithfulness": (0.50, 0.70, 0.85),
    "Answer Relevance": (0.40, 0.60, 0.75),
    "Context Recall": (0.15, 0.30, 0.50),
}
def _colour(name: str, value: float) -> str:
    """Map a metric value to a traffic-light icon via per-metric thresholds.

    🟢 when value ≥ good, 🟡 when value ≥ ok, 🔴 otherwise. The first
    ("warn") element of each threshold tuple is deliberately ignored here —
    underscore-named to make the unused binding explicit.
    """
    _warn, ok, good = _THRESHOLDS.get(name, (0.3, 0.6, 0.8))
    if value >= good:
        return "🟢"
    if value >= ok:
        return "🟡"
    return "🔴"
| def _bar(value: float, width: int = 20) -> str: | |
| filled = int(round(value * width)) | |
| return "█" * filled + "░" * (width - filled) | |
def format_metrics_markdown(metrics: Dict[str, float]) -> str:
    """Render the metric panel as a markdown table with bars and status icons."""
    header = [
        "## 📊 Evaluation Metrics\n",
        "> Metrics are **reference-free** and computed against the retrieved context "
        "and original query — no labelled ground truth required.\n",
        "| Metric | Score | Bar | Status | Notes |",
        "|--------|------:|-----|--------|-------|",
    ]
    rows = [
        f"| **{name}** | {value:.2%} | `{_bar(value)}` | {_colour(name, value)} "
        f"| {_METRIC_DESCRIPTIONS.get(name, '')} |"
        for name, value in metrics.items()
    ]
    footer = ["\n**Colour key:** 🟢 Good · 🟡 Acceptable · 🔴 Needs attention"]
    return "\n".join(header + rows + footer)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # GPU FUNCTION | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
@spaces.GPU
def llm_generate(messages: list) -> str:
    """Generate a chat completion with Phi-3 on a ZeroGPU-borrowed GPU.

    The ``@spaces.GPU`` decorator is required on ZeroGPU Spaces: it attaches a
    GPU for the duration of this call. Without it, no CUDA device is allocated
    to the process and ``model.to("cuda")`` fails. (The file imports ``spaces``
    and is documented as a ZeroGPU Space, but the decorator was missing.)

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The decoded assistant response, with prompt tokens stripped.
    """
    print("🔥 GPU acquired, running generation...", flush=True)
    model.to("cuda")
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=300,
                do_sample=False,  # greedy decoding: deterministic clinical answers
                repetition_penalty=1.05,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Slice off the prompt so only newly generated tokens are decoded.
        response = tokenizer.decode(
            output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
    finally:
        # Always hand the weights back to CPU so the next ZeroGPU borrow is clean,
        # even if generation raises.
        model.to("cpu")
        torch.cuda.empty_cache()
    print("✅ Generation complete.", flush=True)
    return response
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # RAG PIPELINE | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def process_query_stream(query: str):
    """Run the full RAG pipeline, yielding (answer_md, metrics_md) tuples.

    A generator so the Gradio UI can stream progress: each yield updates the
    answer and metrics panes. Stages: retrieve → dedupe by page → rerank →
    generate → evaluate.
    """
    # ── Step 1: retrieval ────────────────────────────────────────────────────
    yield (
        "⏳ **Status:** 🔍 Retrieving relevant documents from VectorDB...\n\n---\n",
        ""
    )
    # Over-retrieve (k=60) so the reranker has a wide candidate pool.
    retrieved = vectorstore.similarity_search(query, k=60)
    # Keep only the first chunk per page so the reranker sees distinct pages.
    unique, seen = [], set()
    for doc in retrieved:
        pg = doc.metadata.get("page")
        if pg not in seen:
            unique.append(doc)
            seen.add(pg)
    # ── Step 2: rerank ───────────────────────────────────────────────────────
    yield (
        "⏳ **Status:** 📊 Reranking with CrossEncoder (CPU)...\n\n---\n",
        ""
    )
    scores = rerank_docs(query, unique)
    scored = sorted(zip(unique, scores), key=lambda x: x[1], reverse=True)
    # Top 5 reranked chunks become the LLM context.
    top_docs = [d for d, _ in scored[:5]]
    context = "\n\n".join(d.page_content for d in top_docs)
    source_pages = ", ".join(str(d.metadata.get("page")) for d in top_docs)
    # ── Step 3: generate ─────────────────────────────────────────────────────
    yield (
        "⏳ **Status:** 🧠 Synthesizing with Phi-3 (ZeroGPU H200)...\n\n---\n",
        ""
    )
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict clinical assistant for the ESC Guidelines. Answer based ONLY on the context. "
                "CRITICAL RULES:\n"
                "1. Define scoring systems strictly.\n"
                "2. State clearly if a treatment is Class III.\n"
                "3. Do not hallucinate acronym meanings.\n"
                "4. If the prompt is not medical, output: 'This model only answers to medical questions.'"
            ),
        },
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{query}"},
    ]
    answer = llm_generate(messages)
    answer_md = f"### ⚕️ Answer\n\n{answer}\n\n📄 **Source Pages:** {source_pages}\n"
    # ── Step 4: metrics ──────────────────────────────────────────────────────
    # Show the answer immediately; metrics follow in a final yield.
    yield (
        answer_md,
        "⏳ **Status:** 📐 Computing evaluation metrics (CPU)...\n"
    )
    metrics = compute_all_metrics(query, answer, context)
    metrics_md = format_metrics_markdown(metrics)
    yield (answer_md, metrics_md)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # GRADIO UI | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def gradio_wrapper(query):
    """Validate the user input, then stream the RAG pipeline's updates to the UI."""
    if query and query.strip():
        yield from process_query_stream(query)
    else:
        yield "⚠️ Please enter a valid question.", ""
# Teal/blue "Soft" theme; darken the primary button one shade on hover.
phi_theme = gr.themes.Soft(
    primary_hue="teal",
    secondary_hue="blue",
    neutral_hue="slate",
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
)
# Build the Gradio UI: header, input row, examples, stacked answer/metrics panes.
with gr.Blocks(theme=phi_theme, title="Cardiology AI Assistant") as demo:
    # ── Header ───────────────────────────────────────────────────────────────
    gr.Markdown("# ⚕️ Cardiology AI Assistant (ESC 2024)")
    gr.Markdown("### ⚡ Powered by Microsoft Phi-3-Mini · ZeroGPU H200")
    gr.Markdown(
        "Ask questions based on the **2024 ESC Medical Guidelines**. "
        "Uses RAG with MedCPT embeddings, Cross-Encoder reranking, Phi-3 generation, "
        "and **live evaluation metrics**."
    )
    # ── Input ────────────────────────────────────────────────────────────────
    with gr.Row():
        with gr.Column(scale=4):
            input_text = gr.Textbox(
                label="Your Clinical Question",
                placeholder="e.g., What are the class I recommendations for anticoagulation in AF?",
                lines=3,
            )
        with gr.Column(scale=1, min_width=160):
            submit_btn = gr.Button("🔍 Analyze Guidelines", variant="primary", size="lg")
    # ── Examples ─────────────────────────────────────────────────────────────
    gr.Examples(
        examples=[
            "What are the class I recommendations for anticoagulation in AF?",
            "Summarize the treatment algorithm for chronic heart failure.",
            "What is the target LDL-C for very high-risk patients?",
        ],
        inputs=input_text,
        label="Example Questions",
    )
    gr.Markdown("---")
    # ── Answer output (full width) ────────────────────────────────────────────
    answer_output = gr.Markdown(
        label="Assistant Response",
        value="*Your answer will appear here after submission.*",
    )
    gr.Markdown("---")
    # ── Metrics output (full width, below answer) ─────────────────────────────
    metrics_output = gr.Markdown(
        label="Evaluation Metrics",
        value="*Metrics will appear here once the answer is generated.*",
    )
    gr.Markdown("---")
    # ── Metric legend ─────────────────────────────────────────────────────────
    with gr.Accordion("ℹ️ About the Evaluation Metrics", open=False):
        gr.Markdown("""
### How each metric is computed
| Metric | Method | Interpretation |
|--------|--------|---------------|
| **BERTScore F1** | Sentence-level cosine-sim F1 between answer sentences and top-60 context sentences using `all-MiniLM-L6-v2` (forced CPU) | Measures how semantically similar the answer is to the source context |
| **ROUGE-1** | **Precision**: fraction of answer unigrams that appear in the retrieved context | Are the words the model used actually in the retrieved passages? |
| **ROUGE-2** | **Precision**: fraction of answer bigrams that appear in the retrieved context | Are the phrases the model used actually in the retrieved passages? |
| **Semantic Similarity** | Cosine similarity of full answer ↔ question embeddings | Does the answer embed in the same semantic space as the question? |
| **Faithfulness** | Fraction of answer sentences with cosine-sim ≥ 0.35 to any context sentence | Are answer claims grounded in retrieved text? |
| **Answer Relevance** | Cosine similarity of answer ↔ question embeddings | How directly does the answer respond to the question? |
| **Context Recall** | Fraction of top-60 context sentences with cosine-sim ≥ 0.35 to any answer sentence | How much of the retrieved evidence is used in the answer? |
> **Why precision for ROUGE?** The retrieved context is ~6,000 tokens; a correct ~60-token answer
> has only ~4% unigram *recall* against that pool — even if every word came from the context.
> Precision asks the right question: *"Did the model use words that actually appear in the retrieved passages?"*
> **All metrics are reference-free** — they use the retrieved context and original query as the
> reference signal, so no annotated ground-truth is needed.
""")
    # ── Wire up ──────────────────────────────────────────────────────────────
    # gradio_wrapper is a generator, so each yield streams into both outputs.
    submit_btn.click(
        fn=gradio_wrapper,
        inputs=input_text,
        outputs=[answer_output, metrics_output],
    )
# Bind to all interfaces on the standard HF Spaces port; queue() enables streaming.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)