johnnydang88 committed
Commit fb4b4a7 · verified · 1 Parent(s): 262e6e7
Files changed (3)
  1. README.md +26 -0
  2. app.py +212 -0
  3. requirements.txt +11 -0
README.md ADDED
@@ -0,0 +1,26 @@
---
title: Cardiology AI - Qwen
emoji: 🌌
colorFrom: purple
colorTo: indigo
sdk: gradio
sdk_version: "5.25.0"
app_file: app.py
pinned: false
hardware: zero-a10g
secrets:
  - HF_TOKEN
---

# 🌌 Cardiology AI Assistant — Qwen3-4B

RAG-based cardiology Q&A over the **2024 ESC Guidelines**.

- **Retriever:** MedCPT (CPU)
- **Reranker:** BAAI/bge-reranker-base
- **Generator:** Qwen/Qwen3-4B-Instruct-2507 (ZeroGPU)
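
At query time the three components run in sequence: retrieve, rerank, generate. Below is a condensed, illustrative sketch of the flow that `app.py` implements, using the same object names and top-k values as the app; the wrapper name `answer_question` is only for illustration (the real app streams status updates instead).

```python
# Condensed sketch of the query flow in app.py (same names and top-k values);
# see app.py for the full ZeroGPU handling and streaming status messages.
def answer_question(query: str) -> str:
    candidates = retriever.invoke(query)                      # MedCPT + FAISS, top 20
    scores = reranker.predict([[query, d.page_content] for d in candidates])
    ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
    top_docs = [doc for _, doc in ranked[:4]]                 # keep the 4 best chunks
    context = "\n\n".join(d.page_content for d in top_docs)
    messages = [
        {"role": "system", "content": f"Answer only from the context.\n\nContext:\n{context}"},
        {"role": "user", "content": query},
    ]
    return llm_generate(messages)                             # Qwen3-4B on ZeroGPU
```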

## Setup
1. Upload `2024ESC-compressed.pdf` to the Space repo root (see the upload sketch below this list).
2. Add `HF_TOKEN` in **Settings → Secrets** (Qwen is a gated model).
3. Hardware: ZeroGPU (requires HF Pro).
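
For step 1, the PDF can also be pushed from a script. A minimal sketch using `huggingface_hub`; the Space id below is a placeholder, replace it with your own:

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="2024ESC-compressed.pdf",
    path_in_repo="2024ESC-compressed.pdf",
    repo_id="your-username/cardiology-ai-qwen",  # placeholder Space id
    repo_type="space",
)
```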
app.py ADDED
@@ -0,0 +1,212 @@
"""
Cardiology AI Assistant — Alibaba Qwen3-4B-Instruct
Hugging Face ZeroGPU Space (free shared A100)

ZeroGPU rules applied:
- No bitsandbytes quantization
- Model loads to CPU at startup in float16
- @spaces.GPU decorator borrows GPU only during inference
"""
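# Illustrative sketch of the third rule above (not executed here; `rerank_docs` and
# `llm_generate` below are the real instances of this pattern):
#
#     @spaces.GPU                      # GPU is attached only while this call runs
#     def gpu_step(batch):
#         model.to("cuda")             # move weights onto the borrowed GPU
#         out = run_inference(batch)   # do the work
#         model.to("cpu")              # hand the GPU back
#         torch.cuda.empty_cache()
#         return out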

import os, gc, torch, warnings
import spaces  # ← ZeroGPU magic
from typing import List
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from sentence_transformers import CrossEncoder
import gradio as gr

warnings.filterwarnings("ignore")

HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"
PDF_PATH = "./2024ESC-compressed.pdf"

# ══════════════════════════════════════════════════════════════════════════════
# MEDCPT EMBEDDINGS (CPU)
# ══════════════════════════════════════════════════════════════════════════════
class MedCPTEmbeddings(Embeddings):
    def __init__(self, load_article_encoder: bool = True):
        print("⚙️ Initializing MedCPT on CPU...")
        self.models = {
            "qry_tok": AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder"),
            "qry_mod": AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder"),
        }
        if load_article_encoder:
            self.models["art_tok"] = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")
            self.models["art_mod"] = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")

    def embed_documents(self, texts):
        all_embeddings = []
        for i in range(0, len(texts), 8):
            batch = texts[i: i + 8]
            inputs = self.models["art_tok"](
                batch, max_length=512, padding=True, truncation=True, return_tensors="pt"
            )
            with torch.no_grad():
                out = self.models["art_mod"](**inputs)
            all_embeddings.extend(out.last_hidden_state[:, 0, :].tolist())
        return all_embeddings

    def embed_query(self, text):
        inputs = self.models["qry_tok"](
            [text], max_length=512, padding=True, truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            out = self.models["qry_mod"](**inputs)
        return out.last_hidden_state[:, 0, :][0].tolist()

    def unload_article_encoder(self):
        if "art_mod" in self.models:
            del self.models["art_mod"], self.models["art_tok"]
            gc.collect()

# ══════════════════════════════════════════════════════════════════════════════
# STARTUP
# ══════════════════════════════════════════════════════════════════════════════
print("📂 Loading PDF with PyMuPDF...")
loader = PyMuPDFLoader(PDF_PATH)
documents = loader.load()
print(f"✅ Loaded {len(documents)} pages.")

print("✂️ Splitting...")
splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
chunks = splitter.split_documents(documents)

print("🧠 Building MedCPT vector store (CPU)...")
emb = MedCPTEmbeddings(load_article_encoder=True)
vectorstore = FAISS.from_documents(chunks, emb)
retriever = vectorstore.as_retriever(search_kwargs={"k": 20})
emb.unload_article_encoder()
print("✅ Vector store ready.")

print("⚖️ Loading CrossEncoder (CPU init)...")
reranker = CrossEncoder("BAAI/bge-reranker-base", device="cpu")

print("🚀 Loading Qwen3-4B in float16 (CPU)...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, token=HF_TOKEN, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()
print("✅ Qwen3 ready (CPU). GPU borrowed per request via ZeroGPU.")

# ══════════════════════════════════════════════════════════════════════════════
# GPU FUNCTIONS
# ══════════════════════════════════════════════════════════════════════════════

@spaces.GPU
def rerank_docs(query: str, docs):
    # ZeroGPU: move the reranker onto the borrowed GPU only for scoring, then release it.
    reranker.model.to("cuda")
    scores = reranker.predict([[query, d.page_content] for d in docs])
    reranker.model.to("cpu")
    torch.cuda.empty_cache()
    return scores

@spaces.GPU
def llm_generate(messages: list) -> str:
    # ZeroGPU: Qwen weights live on CPU between requests; borrow the GPU just for generation.
    model.to("cuda")
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            top_k=20,
            repetition_penalty=1.05,
        )
    input_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(generated_ids[0][input_len:], skip_special_tokens=True)
    del inputs, generated_ids
    model.to("cpu")
    torch.cuda.empty_cache()
    return answer

# ══════════════════════════════════════════════════════════════════════════════
# RAG PIPELINE
# ══════════════════════════════════════════════════════════════════════════════
def rag_query_stream(query: str):
    yield "⏳ **Status:** 🔍 Retrieving documents from VectorDB...\n\n---\n"
    candidates = retriever.invoke(query)

    yield "⏳ **Status:** 📊 Reranking with CrossEncoder (ZeroGPU)...\n\n---\n"
    scores = rerank_docs(query, candidates)
    ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
    top_docs = [doc for _, doc in ranked[:4]]
    context = "\n\n".join(d.page_content for d in top_docs)
    pages = ", ".join(str(d.metadata.get("page", "?")) for d in top_docs)

    yield "⏳ **Status:** 🧠 Generating with Qwen3 (ZeroGPU A100)...\n\n---\n"
    messages = [
        {
            "role": "system",
            "content": (
                "You are a medical expert assistant specialising in cardiology. "
                "Answer the user's question using ONLY the context provided. "
                "If the answer is not in the context, say you don't know.\n\n"
                f"Context:\n{context}"
            ),
        },
        {"role": "user", "content": query},
    ]
    answer = llm_generate(messages)
    yield f"### 🌌 Answer\n\n{answer}\n\n📄 **Source Pages:** {pages}\n"

# ══════════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════════
def gradio_wrapper(query):
    if not query or not query.strip():
        yield "⚠️ Please enter a valid question."
        return
    yield from rag_query_stream(query)

qwen_theme = gr.themes.Soft(
    primary_hue="violet",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "sans-serif"],
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
)

with gr.Blocks(theme=qwen_theme) as demo:
    gr.Markdown("# 🌌 Cardiology AI Assistant (ESC 2024)")
    gr.Markdown("### ⚡ Powered by Alibaba Qwen3-4B · HF ZeroGPU")
    gr.Markdown(
        "Ask questions based on the **2024 ESC Medical Guidelines**. "
        "Uses RAG with MedCPT embeddings, Cross-Encoder reranking, and Qwen3-4B generation."
    )
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Your Clinical Question",
                placeholder="e.g., What are the class I recommendations for anticoagulation in AF?",
                lines=3,
            )
            submit_btn = gr.Button("Analyze Guidelines", variant="primary")
    output_text = gr.Markdown(label="Assistant Response")
    gr.Examples(
        examples=[
            "What are the class I recommendations for anticoagulation in AF?",
            "Summarize the treatment algorithm for chronic heart failure.",
            "What is the target LDL-C for very high-risk patients?",
        ],
        inputs=input_text,
    )
    submit_btn.click(gradio_wrapper, inputs=input_text, outputs=output_text)

demo.queue().launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,11 @@
transformers>=4.41.2
accelerate
langchain
langchain-community
langchain-core
langchain-text-splitters
faiss-cpu
sentence-transformers
pymupdf
torch
huggingface_hub