# saandip5's picture
# Update app.py
# 1328127 verified
import os
import re
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional
from rank_bm25 import BM25Okapi
import gradio as gr
from huggingface_hub import InferenceClient
# 1. Auth & Models
# Hugging Face API token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Model that agentic_answer tries first.
PRIMARY_MODEL = "google/gemma-4-31B-it"
# Model that agentic_answer tries when the primary call raises.
FALLBACK_MODEL = "google/gemma-2-2b-it"
# Single inference client shared by every request, created at import time.
client = InferenceClient(token=HF_TOKEN)
# 2. RAG Utilities
@dataclass(frozen=True)
class Chunk:
    """One immutable passage of an uploaded document, used as a retrieval unit."""

    # Identifier; load_text builds it as "<file name>:c<paragraph number>".
    chunk_id: str
    # Name of the file the passage was taken from.
    source: str
    # Page number of the passage; load_text always stores None.
    page: Optional[int]
    # Passage text after whitespace normalisation.
    text: str
def _clean(s: str) -> str:
    """Return *s* with non-breaking spaces replaced and whitespace runs collapsed.

    A falsy value (``None`` or ``""``) yields the empty string.
    """
    if not s:
        return ""
    without_nbsp = s.replace("\u00a0", " ")
    words = without_nbsp.split()
    return " ".join(words)
def _tokenize(text: str) -> List[str]:
    """Lower-case *text*, treat punctuation as separators, and drop 1-character tokens."""
    pieces = []
    for ch in text:
        is_word_or_space = ch.isalnum() or ch.isspace()
        pieces.append(ch.lower() if is_word_or_space else " ")
    normalized = "".join(pieces)
    return [token for token in normalized.split() if len(token) >= 2]
def load_text(path: Path) -> List[Chunk]:
    """Split a text file into paragraph-level chunks for better BM25 matching.

    The file is decoded as UTF-8 with undecodable bytes dropped, then split
    on blank lines and in front of every line that starts with "Test Case".
    Each piece is whitespace-normalised; pieces of 20 characters or fewer
    are skipped.

    Args:
        path: Location of the text file to read.

    Returns:
        The chunks in file order. An empty list is returned when the file
        cannot be read, so one bad upload does not abort the whole batch.
    """
    # Only the read can fail here; keep the try body to that single call so
    # unrelated programming errors are not swallowed.
    try:
        raw = path.read_text(encoding="utf-8", errors="ignore")
    except (OSError, ValueError) as exc:
        # Best-effort loader: report why the file was skipped instead of
        # failing silently, then contribute no chunks.
        print(f"[DEBUG] Could not read {path}: {exc}")
        return []

    chunks: List[Chunk] = []
    paragraphs = re.split(r'\n\s*\n|\n(?=Test Case)', raw)
    # Numbering counts every paragraph, including skipped ones, so a chunk id
    # always points at the paragraph's position in the file.
    for number, para in enumerate(paragraphs, start=1):
        cleaned = _clean(para)
        if len(cleaned) <= 20:  # skip tiny fragments
            continue
        chunks.append(Chunk(
            chunk_id=f"{path.name}:c{number}",
            source=path.name,
            page=None,
            text=cleaned,
        ))
    return chunks
def grounding_score(answer: str, chunks: List["Chunk"]) -> float:
    """Estimate how well *answer* is grounded in the retrieved *chunks*.

    The score is ``0.7 * cited_ratio + 0.3 * overlap`` where

    * ``cited_ratio`` is the share of sentences that contain both ``[`` and
      ``]`` (i.e. something that looks like a ``[n]`` citation), and
    * ``overlap`` is the share of the answer's distinct words that also
      occur somewhere in the chunk texts (case-insensitive).

    Args:
        answer: Model answer to evaluate; ``None`` is treated as empty.
        chunks: Retrieved chunks; only their ``text`` attribute is read.

    Returns:
        A value between 0.0 and 1.0; 0.0 for an empty or blank answer.
    """
    text = (answer or "").strip()
    # re.split keeps empty fragments (e.g. after a trailing newline); they
    # must not be counted as uncited sentences or they deflate the score.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
    if not sentences:
        return 0.0

    cited_count = sum(1 for s in sentences if "[" in s and "]" in s)

    all_chunk_text = " ".join(c.text.lower() for c in chunks)
    words = set(re.findall(r"\w+", text.lower()))
    ref_words = set(re.findall(r"\w+", all_chunk_text))
    overlap = len(words & ref_words) / max(len(words), 1)

    return (cited_count / len(sentences) * 0.7) + (overlap * 0.3)
# 3. Global State
# BM25 index over the uploaded chunks; None while no documents are indexed.
CURRENT_INDEX = None
# Chunks behind CURRENT_INDEX; agentic_answer looks them up by BM25 score position.
CURRENT_CHUNKS: List[Chunk] = []
def process_uploads(files):
    """Build the BM25 index from uploaded files.

    Replaces the module-level ``CURRENT_INDEX`` / ``CURRENT_CHUNKS``. The
    previous index is discarded before anything else happens, so a rejected
    upload or one that yields no usable text never leaves questions being
    answered from documents that are no longer shown as uploaded.

    Args:
        files: Uploaded files, each either a path string or an object that
            exposes its path as ``.name``. ``None`` or an empty list clears
            the index.

    Returns:
        A Markdown status message describing the outcome.
    """
    global CURRENT_INDEX, CURRENT_CHUNKS
    # Reset first: every early return below must leave no stale index behind.
    CURRENT_INDEX, CURRENT_CHUNKS = None, []
    if not files:
        return "No files uploaded."

    all_chunks: List[Chunk] = []
    for f in files:
        # Gradio 6 returns file paths as strings directly
        fp = f if isinstance(f, str) else getattr(f, "name", str(f))
        p = Path(fp)
        exists = p.exists()
        size = p.stat().st_size if exists else 'N/A'
        print(f"[DEBUG] Processing: {p} (exists={exists}, size={size})")
        if p.suffix.lower() != ".txt":
            return f"Only .txt files are supported. Unsupported file: **{p.name}**"
        all_chunks.extend(load_text(p))

    # Index only chunks that produce at least one token. BM25Okapi averages
    # IDF over the vocabulary, so a corpus with no tokens at all would raise
    # ZeroDivisionError, and token-less chunks can never match a query anyway.
    tokenized = [(chunk, _tokenize(chunk.text)) for chunk in all_chunks]
    indexable = [(chunk, tokens) for chunk, tokens in tokenized if tokens]
    if not indexable:
        return f" No text extracted from {len(files)} file(s). Check file content."

    # Chunks and corpus are built from the same list so BM25 score positions
    # line up with CURRENT_CHUNKS positions.
    kept_chunks = [chunk for chunk, _ in indexable]
    CURRENT_INDEX = BM25Okapi([tokens for _, tokens in indexable])
    CURRENT_CHUNKS = kept_chunks
    preview = "\n".join(f" • Chunk {i+1}: {c.text[:80]}…" for i, c in enumerate(kept_chunks[:3]))
    return f" Indexed **{len(kept_chunks)} chunks** from {len(files)} file(s).\n\n{preview}"
# 4. Agentic Pipeline with Reasoning Trace
_TRACE_OPEN = "<reasoning_trace>"
_TRACE_CLOSE = "</reasoning_trace>"
_NO_TRACE_NOTE = " *Standard inference performed — no explicit trace returned by model.*"


def _build_messages(context: str, question: str):
    """Assemble the system and user chat messages sent to the model."""
    system_prompt = (
        "You are a HIGH-TRUST health information assistant. "
        "Before providing your final answer, you MUST provide a <reasoning_trace> section. "
        "In this section:\n"
        "1. Analyze the key symptoms or topics mentioned in the user query.\n"
        "2. Map these symptoms/topics to the provided DOCUMENT chunks with [1] [2] etc.\n"
        "3. Check for safety flags (emergency symptoms, requests for diagnosis/prescription).\n"
        "4. Summarize your confidence level based on source coverage.\n"
        "Then close with </reasoning_trace> and provide your 'Final Answer'.\n\n"
        "RULES FOR THE FINAL ANSWER:\n"
        "- NO personal diagnosis. NO prescribing medication.\n"
        "- Use language like 'Sources suggest...' or 'According to [1]...'.\n"
        "- MANDATORY: Include [1] [2] etc. after every factual claim.\n"
        "- If the question sounds like a medical emergency, advise calling emergency services."
    )
    user_prompt = f"SOURCES:\n{context}\n\nQUESTION: {question}\n\nPlease analyze this step-by-step."
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def _generate_answer(messages):
    """Try the primary model, then the fallback; return ``(text, label)``.

    A model is skipped when the call raises or when it returns no text, so
    an empty reply from the primary model still gives the fallback a chance.
    Returns ``("", "")`` when neither model produced text.
    """
    for model_id, label in [(PRIMARY_MODEL, "Gemma 4-31B"), (FALLBACK_MODEL, "Gemma 2-2B")]:
        try:
            resp = client.chat_completion(model=model_id, messages=messages, max_tokens=700, temperature=0.1)
            text = resp.choices[0].message.content
        except Exception as e:
            # Deliberately broad: any failure of one model should fall
            # through to the next one rather than crash the request.
            print(f"[DEBUG] {label} failed: {e}")
            continue
        if text and text.strip():
            return text, label
        print(f"[DEBUG] {label} returned an empty response")
    return "", ""


def _split_reasoning_trace(answer_text: str):
    """Return ``(trace, final_answer)`` parsed from the raw model output."""
    if _TRACE_OPEN not in answer_text:
        return _NO_TRACE_NOTE, answer_text
    # partition() cuts at the first closing tag only, so a repeated closing
    # tag later in the output cannot truncate the final answer.
    head, _, tail = answer_text.partition(_TRACE_CLOSE)
    trace = head.replace(_TRACE_OPEN, "").strip()
    # Missing closing tag or nothing after it: show a placeholder answer.
    return trace, tail.strip() or "Analysis complete."


def agentic_answer(question: str):
    """Answer *question* from the indexed documents and report trust signals.

    Pipeline: BM25 retrieval of up to five positively scored chunks, one
    chat completion (primary model, then fallback), separation of the
    reasoning trace from the final answer, and computation of the trust and
    safety badges.

    Args:
        question: The user's question; ``None`` is treated as empty.

    Returns:
        A 5-tuple ``(final_answer, trace, sources_markdown, trust_html,
        safety_html)``. On any early exit the first item carries the message
        for the user and the other four are empty strings.
    """
    if not HF_TOKEN:
        return " Missing HF_TOKEN! Add it in Space Settings → Secrets.", "", "", "", ""
    if CURRENT_INDEX is None:
        return "Upload document(s) first!", "", "", "", ""

    # Step 1: Retrieval (BM25)
    q_tokens = _tokenize(question or "")
    scores = CURRENT_INDEX.get_scores(q_tokens)
    top_indices = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:5]
    # NOTE(review): BM25Okapi IDF can be zero or negative for terms that occur
    # in half or more of the chunks, so a very small corpus may produce no
    # positive score at all — confirm against rank_bm25 if retrieval is empty.
    retrieved = [CURRENT_CHUNKS[i] for i, s in top_indices if s > 0]
    print(f"[DEBUG] Query tokens: {q_tokens[:10]}...")
    print(f"[DEBUG] Top scores: {[(i, f'{s:.2f}') for i, s in top_indices]}")
    print(f"[DEBUG] Retrieved {len(retrieved)} chunks")
    if not retrieved:
        return "No relevant info found. Try rephrasing your question.", "", "", "", ""
    # The [n] numbers here are the citation markers the model is told to use.
    context = "\n".join(f"[{i+1}] {c.text}" for i, c in enumerate(retrieved))

    # Step 2: Generate Answer with Reasoning Trace (primary → fallback)
    answer_text, model_used = _generate_answer(_build_messages(context, question))
    if not answer_text:
        return " Both models unavailable. Check your HF_TOKEN and model access.", "", "", "", ""

    # Step 3: Split Reasoning Trace from Final Answer
    trace, final_ans = _split_reasoning_trace(answer_text)

    # Step 4: Compute Trust & Safety Badges
    score = grounding_score(final_ans, retrieved)
    lowered = final_ans.lower()
    # Keyword heuristic: any of these words marks the answer as an advisory.
    is_refusal = any(kw in lowered for kw in ("not a doctor", "professional", "emergency", "consult"))
    trust_cls = "high-trust" if score > 0.5 else "low-trust"
    safety_cls = "safety-refusal" if is_refusal else "safety-pass"
    safety_label = " Safety Advisory" if is_refusal else " Safety: PASSED"
    trust_html = f'<div class="badge {trust_cls}">Trust: {score:.2f} ({model_used})</div>'
    safety_html = f'<div class="badge {safety_cls}">{safety_label}</div>'
    sources_md = "\n---\n".join(f"**[{i+1}] {c.source}**\n> {c.text}" for i, c in enumerate(retrieved))
    return final_ans, trace, sources_md, trust_html, safety_html
# 5. UI
# Dark-theme stylesheet passed to gr.Blocks below. The .badge, .high-trust,
# .low-trust, .safety-pass and .safety-refusal classes are the ones emitted in
# the badge HTML built by agentic_answer; #header, #ans, #src and
# #trace-accordion match the elem_id values used in the layout.
CUSTOM_CSS = """
:root { --p: #2c5282; --bg: #0f1117; --accent: #805ad5; }
.gradio-container { background: var(--bg) !important; color: #e2e8f0 !important; font-family: system-ui, sans-serif; }
#header {
background: linear-gradient(135deg, var(--p) 0%, var(--accent) 100%);
color: white; padding: 28px 24px; border-radius: 14px;
margin-bottom: 24px; text-align: center;
box-shadow: 0 6px 24px rgba(44,82,130,0.4);
}
#header h1, #header h3, #header * { color: #fff !important; }
.gradio-container *, .gradio-container label,
.gradio-container h1, .gradio-container h2, .gradio-container h3,
.gradio-container li, .gradio-container td, .gradio-container p,
.gradio-container span { color: #e2e8f0 !important; }
#ans *, #src * { color: #e2e8f0 !important; }
blockquote { border-left: 3px solid var(--accent); padding-left: 12px; color: #cbd5e0 !important; }
.badge {
padding: 6px 16px; border-radius: 20px; font-weight: 700;
display: inline-block; margin-right: 8px; margin-bottom: 8px;
font-size: 0.85rem; letter-spacing: 0.02em;
}
.high-trust, .safety-pass { background: #22543d !important; color: #c6f6d5 !important; }
.low-trust, .safety-refusal { background: #742a2a !important; color: #fed7d7 !important; }
#trace-accordion { border-left: 3px solid var(--accent); background: #1a1f2e !important;
border-radius: 10px; margin: 10px 0; }
#trace-accordion * { color: #e2e8f0 !important; }
textarea, input[type="text"] { background: #1a202c !important; color: #e2e8f0 !important;
border: 1px solid #4a5568 !important; border-radius: 8px !important; }
.panel, .form { background: #1a202c !important; border-color: #2d3748 !important; }
button.primary { background: linear-gradient(135deg, #2c5282, #805ad5) !important;
border: none !important; font-weight: 700 !important; border-radius: 8px !important; }
"""
# Two-column layout: uploads, question box and demo queries on the left;
# badges, answer, reasoning trace and sources on the right.
# NOTE(review): a comment in process_uploads mentions Gradio 6 — confirm that
# the installed Gradio version still accepts css/theme on gr.Blocks.
with gr.Blocks(css=CUSTOM_CSS, title="Medical RAG Studio", theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="header"):
        gr.Markdown("# Medical RAG Studio\n### High-Trust Agentic Intelligence for Health Discovery")
    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=2):
            files = gr.File(label="Upload Knowledge Base (.txt)", file_count="multiple", file_types=[".txt"])
            # Shows the status message returned by process_uploads.
            status = gr.Markdown("No documents indexed.")
            q = gr.Textbox(label="Research Question", lines=3,
                           placeholder="e.g., I feel nauseous with stomach cramps after eating...")
            btn = gr.Button("Analyze with Gemma", variant="primary")
            # Demo questions for judges; choosing one fills the question box.
            gr.Markdown("### Try These Demo Queries")
            gr.Examples(
                examples=[
                    ["What are the symptoms of a common cold according to the documents?"],
                    ["What medication should I take for my headache?"],
                    ["I have crushing chest pain spreading to my jaw and trouble breathing"],
                    ["I ate dinner four hours ago and now I feel nauseous with stomach cramps and vomiting"],
                    ["My eyes are itchy and watery every spring, what could this be?"],
                ],
                inputs=[q],
                label="",
                examples_per_page=5,
            )
        # Right column: outputs.
        with gr.Column(scale=3):
            with gr.Row():
                # Placeholder badges shown until the first analysis replaces them.
                t_badge = gr.HTML('<div class="badge low-trust">Trust: 0.00</div>')
                s_badge = gr.HTML('<div class="badge safety-pass">Safety: Idle</div>')
            ans = gr.Markdown(elem_id="ans")
            # Reasoning-trace accordion (explainable AI for judges), collapsed by default.
            with gr.Accordion(" View Gemma's Reasoning Trace", open=False, elem_id="trace-accordion"):
                trace_output = gr.Markdown(
                    value="*Reasoning trace will appear here after analysis...*"
                )
            src = gr.Markdown(label="Evidence Dashboard", elem_id="src")
    # Re-index whenever the uploaded file list changes.
    files.change(process_uploads, [files], [status])
    # Output order matches the 5-tuple returned by agentic_answer.
    btn.click(agentic_answer, [q], [ans, trace_output, src, t_badge, s_badge])
# Launch the queued app only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.queue().launch()