File size: 3,608 Bytes
38df248
 
 
 
17fa973
 
38df248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17fa973
38df248
 
 
 
 
17fa973
38df248
 
 
 
 
17fa973
38df248
 
 
 
 
17fa973
38df248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17fa973
38df248
 
 
 
 
17fa973
38df248
 
 
 
17fa973
38df248
 
 
 
17fa973
38df248
 
 
 
17fa973
38df248
d499f78
38df248
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import re
from functools import lru_cache

import gradio as gr

# --- Space-friendly settings ---
# Persist model caches across restarts (when Space storage is enabled) and
# opt out of Gradio usage telemetry. setdefault() keeps any values already
# supplied by the hosting environment.
for _key, _default in (
    ("HF_HOME", "/data/huggingface"),
    ("GRADIO_ANALYTICS_ENABLED", "false"),
):
    os.environ.setdefault(_key, _default)

# ---------- Lazy model loaders (no heavy work at import time) ----------
@lru_cache(maxsize=1)
def get_embedder():
    """Return the shared sentence-embedding model, loading it on first use.

    The import lives inside the function so that importing this module stays
    fast and the Gradio frontend is not blocked by the model download.
    """
    from sentence_transformers import SentenceTransformer

    model_id = "sentence-transformers/all-MiniLM-L6-v2"
    return SentenceTransformer(model_id)

@lru_cache(maxsize=1)
def get_ref_emb():
    """Return the cached embedding tensor for the anaemia reference sentence.

    Computed once (lru_cache) from a fixed reference text and used by
    ml_based() as the cosine-similarity target.
    """
    # NOTE: the previous version imported sentence_transformers.util here
    # but never used it; util is imported where it is actually needed
    # (in ml_based), so the dead import is removed.
    embedder = get_embedder()
    reference_text = "Patient diagnosed with anaemia and low haemoglobin."
    return embedder.encode(reference_text, convert_to_tensor=True)

@lru_cache(maxsize=1)
def get_llm():
    """Build (once) and return a lightweight text-generation pipeline.

    GPT-2 is small enough to run on CPU, keeping the Space usable without a
    GPU; the transformers import is deferred to first call.
    """
    from transformers import pipeline

    return pipeline(
        "text-generation",
        model="gpt2",
        max_new_tokens=80,
    )

# ------------------------ Functions per approach ------------------------
def rule_based(text: str) -> str:
    if re.search(r"\banae?mia\b", text, flags=re.IGNORECASE):
        return "Anaemia detected (keyword match)"
    return "Anaemia not detected (no keyword match)"

def ml_based(text: str) -> str:
    """Score *text* against the anaemia reference via cosine similarity.

    Embedder and reference embedding load lazily on first call. A score
    above 0.45 is reported as 'Anaemia likely'.
    """
    from sentence_transformers import util

    embedder = get_embedder()
    reference = get_ref_emb()
    text_emb = embedder.encode(text, convert_to_tensor=True)
    similarity = float(util.cos_sim(text_emb, reference).item())
    label = "Anaemia likely" if similarity > 0.45 else "Anaemia unlikely"
    return f"Similarity: {similarity:.2f}\n{label}"

def ai_based(text: str) -> str:
    """Ask the generative model whether *text* suggests anaemia.

    Returns the raw generated text from the pipeline's first candidate;
    note this includes the prompt itself (transformers default behaviour).
    """
    prompt = (
        "Determine if the following clinical note suggests anaemia. "
        "Answer clearly in one short paragraph with reasoning.\n\n"
        f"{text}\n"
    )
    generator = get_llm()
    candidates = generator(prompt)
    return candidates[0]["generated_text"]

# ----------------------------- UI -----------------------------
def build_ui():
    """Assemble and return the Gradio Blocks demo: one shared input, three tabs."""
    with gr.Blocks(title="RB vs ML vs AI — Anaemia Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Clinical Text Understanding — Three Approaches")

        with gr.Row():
            default_note = "The patient presents with fatigue and very low haemoglobin."
            note_box = gr.Textbox(value=default_note, label="Clinical note", lines=4)

        # One tab per approach; each button runs its function on the shared
        # note box and writes into its own result textbox.
        with gr.Tabs():
            with gr.Tab("Rule-Based"):
                run_rule = gr.Button("Run Rule-Based")
                rule_result = gr.Textbox(label="Result", lines=3)
                run_rule.click(fn=rule_based, inputs=note_box, outputs=rule_result)

            with gr.Tab("Machine Learning (Embeddings)"):
                run_ml = gr.Button("Run ML")
                ml_result = gr.Textbox(label="Result", lines=4)
                run_ml.click(fn=ml_based, inputs=note_box, outputs=ml_result)

            with gr.Tab("AI / Foundation Model"):
                run_ai = gr.Button("Run AI")
                ai_result = gr.Textbox(label="Result", lines=8)
                run_ai.click(fn=ai_based, inputs=note_box, outputs=ai_result)

        gr.Markdown(
            "Notes: models are loaded lazily on first run to keep the UI responsive in Spaces."
        )
    return demo

demo = build_ui()

# Queue + launch. SSR off avoids the blank/broken view some Spaces show.
# BUG FIX: `queue(concurrency_count=...)` was removed in Gradio 4.0; its
# replacement is `default_concurrency_limit`. Since `ssr_mode` (used below)
# only exists in Gradio 5+, the old keyword would raise TypeError here.
demo.queue(default_concurrency_limit=1, max_size=10).launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssr_mode=False,
    show_error=True,
    inbrowser=False,
    share=False,
)