Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| from pathlib import Path | |
| from dataclasses import dataclass | |
| from typing import List, Optional | |
| from rank_bm25 import BM25Okapi | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
# 1. Auth & Models
# Hugging Face API token read from the environment; None when the variable is
# unset, which agentic_answer() reports to the user before any API call.
HF_TOKEN = os.getenv("HF_TOKEN")
# First model in the primary → fallback order used by agentic_answer().
PRIMARY_MODEL = "google/gemma-4-31B-it"
# Second model in the primary → fallback order used by agentic_answer().
FALLBACK_MODEL = "google/gemma-2-2b-it"
# Shared inference client, created once at import time with the token above.
client = InferenceClient(token=HF_TOKEN)
| # 2. RAG Utilities | |
@dataclass
class Chunk:
    """One retrievable passage of an uploaded document.

    The ``@dataclass`` decorator is required: instances are built with keyword
    arguments (``Chunk(chunk_id=..., source=..., page=..., text=...)``), which
    a plain class holding only bare annotations does not accept.
    """

    chunk_id: str  # unique label, e.g. "<file name>:c<paragraph number>"
    source: str  # name of the file the passage came from
    page: Optional[int]  # page number when known; None for plain-text sources
    text: str  # whitespace-normalised passage text
| def _clean(s: str) -> str: | |
| return " ".join((s or "").replace("\u00a0", " ").split()) | |
| def _tokenize(text: str) -> List[str]: | |
| text = "".join(ch.lower() if (ch.isalnum() or ch.isspace()) else " " for ch in text) | |
| return [t for t in text.split() if len(t) > 1] | |
def load_text(path: Path) -> List["Chunk"]:
    """Split a text file into paragraph-level chunks for better BM25 matching.

    The file is read as UTF-8 (undecodable bytes are ignored) and split on
    blank lines and before any line that starts with "Test Case".  Each piece
    is whitespace-normalised; pieces of 20 characters or fewer are skipped.

    Args:
        path: Location of the text file to load.

    Returns:
        The chunks in document order, or an empty list when the file cannot
        be read or chunked.  Failures are printed rather than raised so that
        one bad upload does not abort the whole indexing run.
    """
    try:
        raw = path.read_text(encoding="utf-8", errors="ignore")
        chunks = []
        # The running number counts every split piece, including skipped ones,
        # so chunk ids stay tied to the paragraph position in the file.
        for i, para in enumerate(re.split(r'\n\s*\n|\n(?=Test Case)', raw)):
            cleaned = _clean(para)
            if len(cleaned) > 20:  # skip tiny fragments
                chunks.append(Chunk(
                    chunk_id=f"{path.name}:c{i+1}",
                    source=path.name,
                    page=None,
                    text=cleaned,
                ))
        return chunks
    except Exception as exc:
        # Still best-effort, but no longer silent: an unexplained empty result
        # is indistinguishable from an empty file and hides real bugs.
        print(f"[DEBUG] Failed to load {path}: {exc!r}")
        return []
def grounding_score(answer: str, chunks: List["Chunk"]) -> float:
    """Estimate how well *answer* is grounded in the retrieved *chunks*.

    The score is ``0.7 * cited_fraction + 0.3 * word_overlap`` where

    * ``cited_fraction`` is the share of sentences containing both ``[`` and
      ``]`` (i.e. something that looks like a ``[n]`` citation), and
    * ``word_overlap`` is the share of distinct answer words that also occur
      in the combined chunk text (case-insensitive).

    Args:
        answer: Model answer to score; ``None`` or blank text scores 0.0.
        chunks: Retrieved chunks; only their ``text`` attribute is read.

    Returns:
        A value between 0.0 and 1.0.
    """
    answer = answer or ""
    # re.split never returns an empty list and yields "" for leading/trailing
    # whitespace, so blank fragments are dropped before counting sentences;
    # otherwise a trailing newline would deflate the cited fraction.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', answer.strip()) if s.strip()]
    if not sentences:
        return 0.0

    cited = sum(1 for s in sentences if "[" in s and "]" in s)
    reference_text = " ".join(c.text.lower() for c in chunks)
    words = set(re.findall(r"\w+", answer.lower()))
    ref_words = set(re.findall(r"\w+", reference_text))
    overlap = len(words & ref_words) / max(len(words), 1)
    return (cited / len(sentences) * 0.7) + (overlap * 0.3)
# 3. Global State
# BM25 index over CURRENT_CHUNKS; None until an upload has been indexed.
CURRENT_INDEX = None
# Chunks backing CURRENT_INDEX, in the same order as the indexed corpus.
CURRENT_CHUNKS: List["Chunk"] = []


def process_uploads(files):
    """Build the BM25 index from uploaded files and report the outcome.

    Args:
        files: Uploaded files as path strings, or objects exposing the path
            through a ``name`` attribute; ``None`` or empty clears the index.

    Returns:
        A Markdown status message: a summary with a preview of the first three
        chunks on success, otherwise the reason nothing was indexed.

    Side effects:
        Replaces the module-level ``CURRENT_INDEX`` / ``CURRENT_CHUNKS``.  On
        every failure path they are left cleared, so a rejected upload never
        keeps answering from a previous upload's documents.
    """
    global CURRENT_INDEX, CURRENT_CHUNKS
    # Clear first: any early return below must not leave a stale index behind.
    CURRENT_INDEX, CURRENT_CHUNKS = None, []
    if not files:
        return "No files uploaded."

    all_chunks = []
    for f in files:
        # Gradio 6 returns file paths as strings directly
        fp = f if isinstance(f, str) else getattr(f, "name", str(f))
        p = Path(fp)
        print(f"[DEBUG] Processing: {p} (exists={p.exists()}, size={p.stat().st_size if p.exists() else 'N/A'})")
        if p.suffix.lower() != ".txt":
            return f"Only .txt files are supported. Unsupported file: **{p.name}**"
        all_chunks.extend(load_text(p))

    if not all_chunks:
        return f" No text extracted from {len(files)} file(s). Check file content."

    corpus = [_tokenize(c.text) for c in all_chunks]
    # Build before publishing so a failing build cannot leave index and chunk
    # list out of step with each other.
    index = BM25Okapi(corpus)
    CURRENT_INDEX, CURRENT_CHUNKS = index, all_chunks

    preview = "\n".join(f" • Chunk {i+1}: {c.text[:80]}…" for i, c in enumerate(all_chunks[:3]))
    return f" Indexed **{len(all_chunks)} chunks** from {len(files)} file(s).\n\n{preview}"
| # 4. Agentic Pipeline with Reasoning Trace | |
def _split_reasoning_trace(answer_text: str) -> tuple:
    """Separate the model's <reasoning_trace> section from its final answer."""
    if "<reasoning_trace>" not in answer_text:
        trace = " *Standard inference performed — no explicit trace returned by model.*"
        return trace, answer_text
    # partition keeps everything after the FIRST closing tag, so a repeated
    # closing tag later in the reply no longer truncates the answer.
    before, _, after = answer_text.partition("</reasoning_trace>")
    trace = before.replace("<reasoning_trace>", "").strip()
    # An unclosed trace, or a reply cut off right after the trace, has no
    # answer text; show the placeholder instead of an empty panel.
    final_ans = after.strip() or "Analysis complete."
    return trace, final_ans


def agentic_answer(question: str):
    """Answer *question* from the indexed documents with a reasoning trace.

    Pipeline: BM25 retrieval of up to five positively scored chunks, a chat
    completion against the primary model (falling back to the secondary one),
    separation of the reasoning trace from the final answer, and computation
    of the trust and safety badges.

    Args:
        question: The user's free-text question.

    Returns:
        A 5-tuple ``(answer, trace, sources_markdown, trust_html,
        safety_html)``.  When the request cannot be served, the first element
        carries the user-facing message and the other four are empty strings.
    """
    if not HF_TOKEN:
        return " Missing HF_TOKEN! Add it in Space Settings → Secrets.", "", "", "", ""
    if CURRENT_INDEX is None:
        return "Upload document(s) first!", "", "", "", ""

    # Step 1: Retrieval (BM25)
    q_tokens = _tokenize(question)
    scores = CURRENT_INDEX.get_scores(q_tokens)
    top_indices = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)[:5]
    # Only chunks with a positive score count as matches.
    retrieved = [CURRENT_CHUNKS[i] for i, s in top_indices if s > 0]
    print(f"[DEBUG] Query tokens: {q_tokens[:10]}...")
    print(f"[DEBUG] Top scores: {[(i, f'{s:.2f}') for i, s in top_indices[:5]]}")
    print(f"[DEBUG] Retrieved {len(retrieved)} chunks")
    if not retrieved:
        return "No relevant info found. Try rephrasing your question.", "", "", "", ""
    # Number the chunks so the model can cite them as [1], [2], ...
    context = "\n".join(f"[{i+1}] {c.text}" for i, c in enumerate(retrieved))

    # Step 2: Generate Answer with Reasoning Trace
    messages = [
        {"role": "system", "content": (
            "You are a HIGH-TRUST health information assistant. "
            "Before providing your final answer, you MUST provide a <reasoning_trace> section. "
            "In this section:\n"
            "1. Analyze the key symptoms or topics mentioned in the user query.\n"
            "2. Map these symptoms/topics to the provided DOCUMENT chunks with [1] [2] etc.\n"
            "3. Check for safety flags (emergency symptoms, requests for diagnosis/prescription).\n"
            "4. Summarize your confidence level based on source coverage.\n"
            "Then close with </reasoning_trace> and provide your 'Final Answer'.\n\n"
            "RULES FOR THE FINAL ANSWER:\n"
            "- NO personal diagnosis. NO prescribing medication.\n"
            "- Use language like 'Sources suggest...' or 'According to [1]...'.\n"
            "- MANDATORY: Include [1] [2] etc. after every factual claim.\n"
            "- If the question sounds like a medical emergency, advise calling emergency services."
        )},
        {"role": "user", "content": f"SOURCES:\n{context}\n\nQUESTION: {question}\n\nPlease analyze this step-by-step."}
    ]

    # Try primary → fallback
    answer_text, model_used = "", ""
    for model_id, label in [(PRIMARY_MODEL, "Gemma 4-31B"), (FALLBACK_MODEL, "Gemma 2-2B")]:
        try:
            resp = client.chat_completion(model=model_id, messages=messages, max_tokens=700, temperature=0.1)
            content = resp.choices[0].message.content
        except Exception as e:
            print(f"[DEBUG] {label} failed: {e}")
            continue
        if content:
            answer_text, model_used = content, label
            break
        # An empty/None reply is as useless as an error: keep going so the
        # fallback model still gets its chance.
        print(f"[DEBUG] {label} returned an empty response")
    if not answer_text:
        return " Both models unavailable. Check your HF_TOKEN and model access.", "", "", "", ""

    # Step 3: Split Reasoning Trace from Final Answer
    trace, final_ans = _split_reasoning_trace(answer_text)

    # Step 4: Compute Trust & Safety Badges
    score = grounding_score(final_ans, retrieved)
    # Keyword heuristic: any of these words marks the answer as an advisory.
    is_refusal = any(kw in final_ans.lower() for kw in ["not a doctor", "professional", "emergency", "consult"])
    trust_cls = "high-trust" if score > 0.5 else "low-trust"
    safety_cls = "safety-refusal" if is_refusal else "safety-pass"
    safety_label = " Safety Advisory" if is_refusal else " Safety: PASSED"
    trust_html = f'<div class="badge {trust_cls}">Trust: {score:.2f} ({model_used})</div>'
    safety_html = f'<div class="badge {safety_cls}">{safety_label}</div>'
    sources_md = "\n---\n".join(f"**[{i+1}] {c.source}**\n> {c.text}" for i, c in enumerate(retrieved))
    return final_ans, trace, sources_md, trust_html, safety_html
| # 5. UI | |
# Dark-theme stylesheet handed to gr.Blocks(css=...). The id selectors
# (#header, #ans, #src, #trace-accordion) correspond to the elem_id values set
# in the layout, and the .badge / *-trust / safety-* classes are the ones
# emitted in the badge HTML built by agentic_answer().
CUSTOM_CSS = """
:root { --p: #2c5282; --bg: #0f1117; --accent: #805ad5; }
.gradio-container { background: var(--bg) !important; color: #e2e8f0 !important; font-family: system-ui, sans-serif; }
#header {
background: linear-gradient(135deg, var(--p) 0%, var(--accent) 100%);
color: white; padding: 28px 24px; border-radius: 14px;
margin-bottom: 24px; text-align: center;
box-shadow: 0 6px 24px rgba(44,82,130,0.4);
}
#header h1, #header h3, #header * { color: #fff !important; }
.gradio-container *, .gradio-container label,
.gradio-container h1, .gradio-container h2, .gradio-container h3,
.gradio-container li, .gradio-container td, .gradio-container p,
.gradio-container span { color: #e2e8f0 !important; }
#ans *, #src * { color: #e2e8f0 !important; }
blockquote { border-left: 3px solid var(--accent); padding-left: 12px; color: #cbd5e0 !important; }
.badge {
padding: 6px 16px; border-radius: 20px; font-weight: 700;
display: inline-block; margin-right: 8px; margin-bottom: 8px;
font-size: 0.85rem; letter-spacing: 0.02em;
}
.high-trust, .safety-pass { background: #22543d !important; color: #c6f6d5 !important; }
.low-trust, .safety-refusal { background: #742a2a !important; color: #fed7d7 !important; }
#trace-accordion { border-left: 3px solid var(--accent); background: #1a1f2e !important;
border-radius: 10px; margin: 10px 0; }
#trace-accordion * { color: #e2e8f0 !important; }
textarea, input[type="text"] { background: #1a202c !important; color: #e2e8f0 !important;
border: 1px solid #4a5568 !important; border-radius: 8px !important; }
.panel, .form { background: #1a202c !important; border-color: #2d3748 !important; }
button.primary { background: linear-gradient(135deg, #2c5282, #805ad5) !important;
border: none !important; font-weight: 700 !important; border-radius: 8px !important; }
"""
# Page layout: a header, then a two-column row (inputs on the left, results on
# the right), followed by the event wiring.
# NOTE(review): the comment in process_uploads mentions Gradio 6; some Gradio
# releases expect css/theme on launch() rather than Blocks() — confirm against
# the installed version.
# NOTE(review): confirm the nesting of the layout below matches the intended
# design (which components sit in which column).
with gr.Blocks(css=CUSTOM_CSS, title="Medical RAG Studio", theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="header"):
        gr.Markdown("# Medical RAG Studio\n### High-Trust Agentic Intelligence for Health Discovery")
    with gr.Row():
        # Left column: knowledge-base upload, question box and demo queries.
        with gr.Column(scale=2):
            files = gr.File(label="Upload Knowledge Base (.txt)", file_count="multiple", file_types=[".txt"])
            status = gr.Markdown("No documents indexed.")
            q = gr.Textbox(label="Research Question", lines=3,
                           placeholder="e.g., I feel nauseous with stomach cramps after eating...")
            btn = gr.Button("Analyze with Gemma", variant="primary")
            # Demo questions for judges; selecting one fills the question box.
            gr.Markdown("### Try These Demo Queries")
            gr.Examples(
                examples=[
                    ["What are the symptoms of a common cold according to the documents?"],
                    ["What medication should I take for my headache?"],
                    ["I have crushing chest pain spreading to my jaw and trouble breathing"],
                    ["I ate dinner four hours ago and now I feel nauseous with stomach cramps and vomiting"],
                    ["My eyes are itchy and watery every spring, what could this be?"],
                ],
                inputs=[q],
                label="",
                examples_per_page=5,
            )
        # Right column: badges, answer, reasoning trace and retrieved sources.
        with gr.Column(scale=3):
            with gr.Row():
                # Placeholder badges shown until the first analysis replaces them.
                t_badge = gr.HTML('<div class="badge low-trust">Trust: 0.00</div>')
                s_badge = gr.HTML('<div class="badge safety-pass">Safety: Idle</div>')
            ans = gr.Markdown(elem_id="ans")
            # Reasoning-trace accordion (explainable AI for judges); collapsed by default.
            with gr.Accordion(" View Gemma's Reasoning Trace", open=False, elem_id="trace-accordion"):
                trace_output = gr.Markdown(
                    value="*Reasoning trace will appear here after analysis...*"
                )
            src = gr.Markdown(label="Evidence Dashboard", elem_id="src")
    # Re-index whenever the uploaded file list changes.
    files.change(process_uploads, [files], [status])
    # Output order matches the 5-tuple returned by agentic_answer().
    btn.click(agentic_answer, [q], [ans, trace_output, src, t_badge, s_badge])
if __name__ == "__main__":
    demo.queue().launch()