# QWEN3 / app.py
# (Hugging Face Space page header residue, kept as comments so the file parses)
# johnnydang88's picture — Update app.py — 4d7bb7b verified
"""
Cardiology AI Assistant β€” Alibaba Qwen3-4B-Instruct
Hugging Face ZeroGPU Space
Includes: BERTScore F1, ROUGE-N, Semantic Similarity, Faithfulness, Answer Relevance, Context Recall
Same metric stack as the Llama-3 and Phi-3 versions β€” all fixes applied:
β€’ SentenceTransformer forced to CPU (prevents stale CUDA zero-vector bug)
β€’ ROUGE uses precision (overlap / answer_ngrams), not recall vs huge context
β€’ Context capped at 60 sentences before embedding (prevents OOM)
β€’ Per-metric try/except so one failure never kills the whole panel
"""
import os, gc, re, torch, warnings, pdfplumber
import numpy as np
import spaces
from collections import Counter
from typing import List, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from sentence_transformers import CrossEncoder, SentenceTransformer
import gradio as gr
warnings.filterwarnings("ignore")
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"
PDF_PATH = "./2024ESC-compressed.pdf"
# ══════════════════════════════════════════════════════════════════════════════
# MEDCPT EMBEDDINGS (CPU)
# ══════════════════════════════════════════════════════════════════════════════
class MedCPTEmbeddings(Embeddings):
    """LangChain Embeddings adapter for the NCBI MedCPT bi-encoders (CPU only).

    MedCPT ships two separate encoders: a query encoder (always loaded) and an
    article encoder (optional — only needed while building the vector store).
    Each embedding is the [CLS] token's last hidden state.
    """

    def __init__(self, load_article_encoder: bool = True, batch_size: int = 8):
        """
        Args:
            load_article_encoder: also load the article encoder (required by
                ``embed_documents``); skip it to save memory at query time.
            batch_size: documents encoded per forward pass in
                ``embed_documents`` (default 8, matching the original code).
        """
        print("⚙️ Initializing MedCPT on CPU...", flush=True)
        self.batch_size = batch_size
        self.models = {
            "qry_tok": AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder"),
            "qry_mod": AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder"),
        }
        if load_article_encoder:
            self.models["art_tok"] = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
            self.models["art_mod"] = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")

    def embed_documents(self, texts):
        """Embed document chunks with the article encoder, in batches.

        Raises:
            RuntimeError: if called after ``unload_article_encoder()`` —
                previously this surfaced as an opaque KeyError.
        """
        if "art_mod" not in self.models:
            raise RuntimeError(
                "Article encoder has been unloaded; re-create MedCPTEmbeddings "
                "with load_article_encoder=True to embed documents."
            )
        all_embeddings = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i: i + self.batch_size]
            inputs = self.models["art_tok"](
                batch, max_length=512, padding=True, truncation=True, return_tensors="pt"
            )
            with torch.no_grad():
                out = self.models["art_mod"](**inputs)
            # [CLS] (first token) hidden state is the MedCPT embedding.
            all_embeddings.extend(out.last_hidden_state[:, 0, :].tolist())
        return all_embeddings

    def embed_query(self, text):
        """Embed a single query string with the query encoder."""
        inputs = self.models["qry_tok"](
            [text], max_length=512, padding=True, truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            out = self.models["qry_mod"](**inputs)
        return out.last_hidden_state[:, 0, :][0].tolist()

    def unload_article_encoder(self):
        """Free the article encoder once the vector store has been built."""
        if "art_mod" in self.models:
            del self.models["art_mod"], self.models["art_tok"]
            gc.collect()
# ══════════════════════════════════════════════════════════════════════════════
# STARTUP
# ══════════════════════════════════════════════════════════════════════════════
print("πŸ“‚ Loading PDF with pdfplumber...", flush=True)
docs = []
with pdfplumber.open(PDF_PATH) as pdf:
for i, page in enumerate(pdf.pages):
text = page.extract_text() or ""
tables = page.extract_tables()
table_str = ""
if tables:
for t in tables:
table_str += "\n" + "\n".join(
["| " + " | ".join([str(c).replace("\n", " ") if c else "" for c in row]) + " |"
for row in t]
)
docs.append(Document(
page_content=f"{text}\n{table_str}",
metadata={"page": i + 1, "source": os.path.basename(PDF_PATH)}
))
print(f"βœ… Loaded {len(docs)} pages.", flush=True)
print("βœ‚οΈ Splitting documents...", flush=True)
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
chunks = splitter.split_documents(docs)
print("🧠 Building MedCPT vector store (CPU)...", flush=True)
emb = MedCPTEmbeddings(load_article_encoder=True)
vectorstore = FAISS.from_documents(chunks, emb)
emb.unload_article_encoder()
print("βœ… Vector store ready.", flush=True)
print("βš–οΈ Loading CrossEncoder (CPU)...", flush=True)
reranker = CrossEncoder("BAAI/bge-reranker-base", device="cpu")
# Explicitly load on CPU β€” after ZeroGPU releases the GPU, auto-device detection
# can latch onto a stale CUDA context and silently return zero vectors.
print("πŸ“ Loading metrics SentenceTransformer (CPU)...", flush=True)
metrics_st = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
print("βœ… Metrics encoder ready.", flush=True)
print("πŸš€ Loading Qwen3-4B in float16 (CPU)...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME, token=HF_TOKEN, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
token=HF_TOKEN,
torch_dtype=torch.float16,
trust_remote_code=True,
)
model.eval()
print("βœ… Qwen3 ready (CPU). GPU borrowed per request via ZeroGPU.", flush=True)
# ══════════════════════════════════════════════════════════════════════════════
# MULTI-QUERY EXPANSION
# ══════════════════════════════════════════════════════════════════════════════
# Hand-curated expansions: when a question mentions one of these topics, extra
# guideline-phrased sub-queries are issued to improve retrieval recall.
QUERY_EXPANSIONS = {
    "AF-CARE": [
        "AF-CARE framework atrial fibrillation treatment pillars",
        "AF-CARE acronym components",
        "comorbidity risk factor stroke rhythm rate AF management",
    ],
    "heart failure": [
        "heart failure treatment algorithm pillars",
        "HF management guideline directed medical therapy",
        "heart failure pharmacological treatment steps",
    ],
    "LDL": [
        "LDL-C target very high risk patients",
        "low density lipoprotein cholesterol cardiovascular risk",
        "lipid lowering therapy statin target ESC guidelines",
    ],
}


def expand_query(query: str) -> List[str]:
    """Return *query* plus the expansions of the first keyword it mentions.

    Matching is case-insensitive substring search; queries that mention no
    curated keyword are returned unchanged (as a one-element list).
    """
    lowered = query.lower()
    matches = (exps for kw, exps in QUERY_EXPANSIONS.items() if kw.lower() in lowered)
    first_match = next(matches, None)
    return [query] if first_match is None else [query] + first_match
def retrieve_with_expansion(query: str, k_per_query: int = 10) -> List[Document]:
    """Similarity-search the query and all its expansions, merging results.

    Results are deduplicated on the first 120 characters of page content,
    preserving first-seen order across sub-queries.
    """
    merged: List[Document] = []
    seen_prefixes = set()
    for sub_query in expand_query(query):
        for candidate in vectorstore.similarity_search(sub_query, k=k_per_query):
            prefix = candidate.page_content[:120]
            if prefix in seen_prefixes:
                continue
            seen_prefixes.add(prefix)
            merged.append(candidate)
    return merged
def rerank_docs(query: str, docs):
    """Score each (query, document-text) pair with the CrossEncoder.

    Returns the raw relevance scores in the same order as *docs*
    (higher = more relevant); the caller pairs them back up for sorting.
    """
    pairs = [[query, doc.page_content] for doc in docs]
    return reranker.predict(pairs)
# ══════════════════════════════════════════════════════════════════════════════
# EVALUATION METRICS
# All reference-free β€” uses retrieved context + query as the reference signal.
# Identical implementation to the Llama-3 and Phi-3 versions for consistency.
# ══════════════════════════════════════════════════════════════════════════════
def _sent_tokenize(text: str) -> List[str]:
"""Lightweight sentence splitter β€” no NLTK required."""
sents = re.split(r'(?<=[.!?])\s+', text.strip())
return [s.strip() for s in sents if len(s.strip()) > 10]
def _encode(texts: List[str]) -> np.ndarray:
    """Embed *texts* with the metrics SentenceTransformer, pinned to CPU.

    After ZeroGPU releases the GPU, auto-device detection can latch onto a
    stale CUDA context and silently return zero vectors; forcing device="cpu"
    guarantees correct, L2-normalized, non-zero embeddings every time.
    """
    encode_kwargs = dict(
        normalize_embeddings=True,
        show_progress_bar=False,
        device="cpu",
        convert_to_numpy=True,
    )
    return metrics_st.encode(texts, **encode_kwargs)
def _ngrams(tokens: List[str], n: int) -> Counter:
return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))
def rouge_n(hypothesis: str, reference: str, n: int = 1) -> float:
    """
    ROUGE-N *precision*: fraction of answer n-grams present in the context.

    Precision (not recall) because the context is thousands of tokens long —
    recall of a ~60-token answer against that pool is tiny even when every
    word of the answer came straight from the context.
    """
    def grams(words):
        # Clipped n-gram counts via zip of n shifted views.
        return Counter(zip(*(words[i:] for i in range(n))))

    hyp_counts = grams(hypothesis.lower().split())
    ref_counts = grams(reference.lower().split())
    total = sum(hyp_counts.values())
    if total == 0:
        return 0.0
    matched = sum((hyp_counts & ref_counts).values())
    return round(matched / total, 4)
def bertscore_f1(answer: str, context_sents: List[str]) -> float:
    """
    Sentence-level BERTScore-style F1 between answer and capped context.

    Precision = mean over answer sentences of the best cosine match in the
    context; Recall = mean over context sentences of the best match in the
    answer; F1 is their harmonic mean (epsilon-guarded). Uses the pre-capped
    context list so at most 60 sentences are ever encoded. Returns 0.0 on
    empty input or any encoding failure.
    """
    answer_sents = _sent_tokenize(answer)
    if not answer_sents or not context_sents:
        return 0.0
    try:
        sim_matrix = _encode(answer_sents) @ _encode(context_sents).T
        precision = float(sim_matrix.max(axis=1).mean())
        recall = float(sim_matrix.max(axis=0).mean())
        f1 = (2 * precision * recall) / (precision + recall + 1e-9)
        return round(max(f1, 0.0), 4)
    except Exception as e:
        print(f"⚠️ bertscore_f1 error: {e}", flush=True)
        return 0.0
def semantic_similarity(answer: str, query: str) -> float:
    """Cosine similarity between answer and query embeddings, clamped at 0.

    Embeddings are L2-normalized by _encode, so the dot product is cosine
    similarity. Returns 0.0 on any encoding failure.
    """
    try:
        answer_vec, query_vec = _encode([answer, query])
        return round(max(float(answer_vec @ query_vec), 0.0), 4)
    except Exception as e:
        print(f"⚠️ semantic_similarity error: {e}", flush=True)
        return 0.0
def faithfulness(answer: str, context_sents: List[str], threshold: float = 0.35) -> float:
    """
    Fraction of answer sentences grounded in the retrieved context.

    A sentence counts as grounded when its best cosine similarity to any
    context sentence reaches *threshold* (0.35 rather than 0.40 so grounded
    paraphrases still count). Returns 0.0 on empty input or encoding failure.
    """
    answer_sents = _sent_tokenize(answer)
    if not answer_sents or not context_sents:
        return 0.0
    try:
        sims = _encode(answer_sents) @ _encode(context_sents).T
        best_per_answer_sent = sims.max(axis=1)
        supported = int((best_per_answer_sent >= threshold).sum())
        return round(supported / len(answer_sents), 4)
    except Exception as e:
        print(f"⚠️ faithfulness error: {e}", flush=True)
        return 0.0
def answer_relevance(answer: str, query: str) -> float:
    """Does the answer actually address what was asked?

    Alias for semantic_similarity(answer, query), kept as a separate name so
    the metric panel reads like the RAGAS metric it mirrors.
    """
    return semantic_similarity(answer, query)
def context_recall(answer: str, context_sents: List[str], threshold: float = 0.35) -> float:
    """
    Fraction of context sentences reflected in the answer.

    A context sentence counts as "recalled" when its best cosine similarity
    to any answer sentence reaches *threshold* — a reference-free analogue of
    RAGAS Context Recall. Returns 0.0 on empty input or encoding failure.
    """
    answer_sents = _sent_tokenize(answer)
    if not answer_sents or not context_sents:
        return 0.0
    try:
        sims = _encode(answer_sents) @ _encode(context_sents).T
        best_per_context_sent = sims.max(axis=0)
        reflected = int((best_per_context_sent >= threshold).sum())
        return round(reflected / len(context_sents), 4)
    except Exception as e:
        print(f"⚠️ context_recall error: {e}", flush=True)
        return 0.0
def compute_all_metrics(query: str, answer: str, context: str) -> Dict[str, float]:
    """
    Run the full reference-free metric panel for one Q/A turn.

    The context is sentence-tokenised once and capped at 60 sentences
    (top-ranked chunks come first) for all embedding-based metrics; the two
    ROUGE scores use the raw context string (pure token overlap, no matrices).
    """
    every_ctx_sent = _sent_tokenize(context)
    capped_ctx = every_ctx_sent[:60]
    print(f"📏 Metrics: answer={len(_sent_tokenize(answer))} sents, "
          f"ctx={len(capped_ctx)}/{len(every_ctx_sent)} sents", flush=True)
    return {
        "BERTScore F1": bertscore_f1(answer, capped_ctx),
        "ROUGE-1": rouge_n(answer, context, n=1),
        "ROUGE-2": rouge_n(answer, context, n=2),
        "Semantic Similarity": semantic_similarity(answer, query),
        "Faithfulness": faithfulness(answer, capped_ctx),
        "Answer Relevance": answer_relevance(answer, query),
        "Context Recall": context_recall(answer, capped_ctx),
    }
# ── Display helpers ───────────────────────────────────────────────────────────
# One-sentence blurb per metric, rendered in the "Notes" column of the table.
_METRIC_DESCRIPTIONS = {
    "BERTScore F1": "Sentence-level semantic overlap F1 between answer and top context sentences.",
    "ROUGE-1": "Fraction of answer unigrams found in retrieved context (precision).",
    "ROUGE-2": "Fraction of answer bigrams found in retrieved context (precision).",
    "Semantic Similarity": "Cosine similarity between answer and question embeddings.",
    "Faithfulness": "Fraction of answer sentences semantically supported by the retrieved context.",
    "Answer Relevance": "How directly the answer addresses the original question.",
    "Context Recall": "Fraction of top context sentences reflected in the answer.",
}
_THRESHOLDS = {
# (warn_below, ok_below, good_above)
"BERTScore F1": (0.50, 0.65, 0.80),
"ROUGE-1": (0.15, 0.30, 0.45),
"ROUGE-2": (0.05, 0.15, 0.25),
"Semantic Similarity": (0.40, 0.60, 0.75),
"Faithfulness": (0.50, 0.70, 0.85),
"Answer Relevance": (0.40, 0.60, 0.75),
"Context Recall": (0.15, 0.30, 0.50),
}
def _colour(name: str, value: float) -> str:
warn, ok, good = _THRESHOLDS.get(name, (0.3, 0.6, 0.8))
if value >= good: return "🟒"
if value >= ok: return "🟑"
return "πŸ”΄"
def _bar(value: float, width: int = 20) -> str:
filled = int(round(value * width))
return "β–ˆ" * filled + "β–‘" * (width - filled)
def format_metrics_markdown(metrics: Dict[str, float]) -> str:
    """Render the metric dict as a Markdown table with bars and status icons."""
    out = [
        "## 📊 Evaluation Metrics\n",
        "> Metrics are **reference-free** and computed against the retrieved context "
        "and original query — no labelled ground truth required.\n",
        "| Metric | Score | Bar | Status | Notes |",
        "|--------|------:|-----|--------|-------|",
    ]
    for metric_name, score in metrics.items():
        out.append(
            f"| **{metric_name}** | {score:.2%} | `{_bar(score)}` "
            f"| {_colour(metric_name, score)} | {_METRIC_DESCRIPTIONS.get(metric_name, '')} |"
        )
    out.append("\n**Colour key:** 🟢 Good · 🟡 Acceptable · 🔴 Needs attention")
    return "\n".join(out)
# ══════════════════════════════════════════════════════════════════════════════
# GPU FUNCTION
# ══════════════════════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def llm_generate(messages: list) -> str:
    """Generate a chat completion on a borrowed ZeroGPU device.

    Moves the model to CUDA for the call and — even if tokenization or
    generation raises — always moves it back to CPU and clears the CUDA
    cache in a ``finally`` block, so a failed request cannot leave the model
    pinned to GPU memory for subsequent requests.

    Args:
        messages: chat messages in [{"role": ..., "content": ...}] format.

    Returns:
        The decoded assistant reply with the prompt tokens stripped.
    """
    print("🔥 GPU acquired, running generation...", flush=True)
    model.to("cuda")
    try:
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(text, return_tensors="pt").to("cuda")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.8,
                top_k=20,
                repetition_penalty=1.05,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(generated_ids[0][input_len:], skip_special_tokens=True)
    finally:
        # Always hand the GPU back clean, even on failure.
        model.to("cpu")
        torch.cuda.empty_cache()
    print("✅ Generation complete.", flush=True)
    return answer
# ══════════════════════════════════════════════════════════════════════════════
# RAG PIPELINE
# ══════════════════════════════════════════════════════════════════════════════
SYSTEM_PROMPT = (
"You are a medical expert assistant specialising in cardiology. "
"Answer the user's question using ONLY the context provided below. "
"If the context contains a list, framework, or set of pillars, enumerate ALL of them explicitly β€” "
"do NOT say they are not mentioned if they appear anywhere in the context. "
"Format list-type answers as numbered or lettered points. "
"Always cite the page number(s) from the context where the information appears. "
"If the answer is genuinely not in the context, say you don't know."
)
def rag_query_stream(query: str):
    """Streaming RAG pipeline: yields (answer_markdown, metrics_markdown).

    Yields an interim status tuple after each stage so the Gradio UI shows
    live progress; the final yield carries the answer plus the metric table.
    """
    # ── Step 1: retrieval ────────────────────────────────────────────────────
    yield (
        "⏳ **Status:** 🔍 Retrieving relevant documents (multi-query expansion)...\n\n---\n",
        ""
    )
    candidates = retrieve_with_expansion(query, k_per_query=10)
    # ── Step 2: rerank ───────────────────────────────────────────────────────
    yield (
        "⏳ **Status:** 📊 Reranking with CrossEncoder (CPU)...\n\n---\n",
        ""
    )
    scores = rerank_docs(query, candidates)
    # Keep the 8 highest-scoring chunks; prefix each with its page number so
    # the model can cite sources.
    ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
    top_docs = [doc for _, doc in ranked[:8]]
    context = "\n\n".join(
        f"[Page {d.metadata.get('page', '?')}]\n{d.page_content}" for d in top_docs
    )
    pages = ", ".join(str(d.metadata.get("page", "?")) for d in top_docs)
    # ── Step 3: generate ─────────────────────────────────────────────────────
    yield (
        "⏳ **Status:** 🧠 Generating with Qwen3 (ZeroGPU H200)...\n\n---\n",
        ""
    )
    messages = [
        {
            "role": "system",
            # Context rides in the system prompt so the user turn stays clean.
            "content": f"{SYSTEM_PROMPT}\n\nContext:\n{context}",
        },
        {"role": "user", "content": query},
    ]
    answer = llm_generate(messages)
    answer_md = f"### 🌌 Answer\n\n{answer}\n\n📄 **Source Pages:** {pages}\n"
    # ── Step 4: metrics ──────────────────────────────────────────────────────
    # Show the answer immediately; metrics follow in the final yield.
    yield (
        answer_md,
        "⏳ **Status:** 📏 Computing evaluation metrics (CPU)...\n"
    )
    metrics = compute_all_metrics(query, answer, context)
    metrics_md = format_metrics_markdown(metrics)
    yield (answer_md, metrics_md)
# ══════════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════════
def gradio_wrapper(query):
    """Validate the textbox input, then delegate to the streaming pipeline.

    Blank / whitespace-only / missing input short-circuits with a warning
    tuple instead of invoking retrieval.
    """
    if query and query.strip():
        yield from rag_query_stream(query)
    else:
        yield "⚠️ Please enter a valid question.", ""
# Purple/indigo Soft theme (Qwen branding); darker primary button on hover.
qwen_theme = gr.themes.Soft(
    primary_hue="purple",
    secondary_hue="indigo",
    neutral_hue="slate",
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
)
# Gradio UI: header → input row → examples → answer pane → metrics pane → legend.
with gr.Blocks(theme=qwen_theme, title="Cardiology AI Assistant") as demo:
    # ── Header ───────────────────────────────────────────────────────────────
    gr.Markdown("# 🌌 Cardiology AI Assistant (ESC 2024)")
    gr.Markdown("### ⚡ Powered by Alibaba Qwen3-4B · ZeroGPU H200")
    gr.Markdown(
        "Ask questions based on the **2024 ESC Medical Guidelines**. "
        "Uses RAG with MedCPT embeddings, multi-query expansion, CrossEncoder reranking, "
        "Qwen3-4B generation, and **live evaluation metrics**."
    )
    # ── Input ────────────────────────────────────────────────────────────────
    with gr.Row():
        with gr.Column(scale=4):
            input_text = gr.Textbox(
                label="Your Clinical Question",
                placeholder="e.g., What are the four treatment pillars of AF-CARE?",
                lines=3,
            )
        with gr.Column(scale=1, min_width=160):
            submit_btn = gr.Button("🔍 Analyze Guidelines", variant="primary", size="lg")
    # ── Examples ─────────────────────────────────────────────────────────────
    gr.Examples(
        examples=[
            "What are the four treatment pillars of the AF-CARE framework?",
            "What are the class I recommendations for anticoagulation in AF?",
            "Summarize the treatment algorithm for chronic heart failure.",
            "What is the target LDL-C for very high-risk patients?",
        ],
        inputs=input_text,
        label="Example Questions",
    )
    gr.Markdown("---")
    # ── Answer output (full width) ────────────────────────────────────────────
    answer_output = gr.Markdown(
        label="Assistant Response",
        value="*Your answer will appear here after submission.*",
    )
    gr.Markdown("---")
    # ── Metrics output (full width, below answer) ─────────────────────────────
    metrics_output = gr.Markdown(
        label="Evaluation Metrics",
        value="*Metrics will appear here once the answer is generated.*",
    )
    gr.Markdown("---")
    # ── Metric legend ─────────────────────────────────────────────────────────
    with gr.Accordion("ℹ️ About the Evaluation Metrics", open=False):
        gr.Markdown("""
### How each metric is computed
| Metric | Method | Interpretation |
|--------|--------|---------------|
| **BERTScore F1** | Sentence-level cosine-sim F1 between answer sentences and top-60 context sentences using `all-MiniLM-L6-v2` (forced CPU) | Measures how semantically similar the answer is to the source context |
| **ROUGE-1** | **Precision**: fraction of answer unigrams that appear in the retrieved context | Are the words the model used actually in the retrieved passages? |
| **ROUGE-2** | **Precision**: fraction of answer bigrams that appear in the retrieved context | Are the phrases the model used actually in the retrieved passages? |
| **Semantic Similarity** | Cosine similarity of full answer ↔ question embeddings | Does the answer embed in the same semantic space as the question? |
| **Faithfulness** | Fraction of answer sentences with cosine-sim ≥ 0.35 to any context sentence | Are answer claims grounded in retrieved text? |
| **Answer Relevance** | Cosine similarity of answer ↔ question embeddings | How directly does the answer respond to the question? |
| **Context Recall** | Fraction of top-60 context sentences with cosine-sim ≥ 0.35 to any answer sentence | How much of the retrieved evidence is used in the answer? |
> **Why precision for ROUGE?** The retrieved context is ~8,000 tokens; a correct ~60-token answer
> has only ~4% unigram *recall* against that pool — even if every word came from the context.
> Precision asks the right question: *"Did the model use words that actually appear in the retrieved passages?"*
> **All metrics are reference-free** — they use the retrieved context and original query as the
> reference signal, so no annotated ground-truth is needed.
""")
    # ── Wire up ───────────────────────────────────────────────────────────────
    submit_btn.click(
        fn=gradio_wrapper,
        inputs=input_text,
        outputs=[answer_output, metrics_output],
    )

# queue() enables generator/streaming outputs; bind on all interfaces for Spaces.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)