File size: 5,042 Bytes
c7e3aa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# app.py

import os
import traceback
import pandas as pd
import torch
import gradio as gr
from transformers import (
    logging,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForImageTextToText
)
from sklearn.model_selection import train_test_split

# ─── Silence irrelevant warnings ───────────────────────────────────────────────
# Suppress transformers' info/warning chatter so Space logs stay readable.
logging.set_verbosity_error()

# ─── Configuration ────────────────────────────────────────────────────────────
# HF_TOKEN must be provided as a Space secret; the gated Gemma model cannot be
# downloaded without it, so fail fast at import time with a clear message.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HF_TOKEN in env vars – set it under Space Settings → Secrets")
MODEL_ID = "google/gemma-3n-e2b-it"

# ─── Fast startup: load only processor & tokenizer ─────────────────────────────
# Only the lightweight processor/tokenizer are loaded at import time; the
# multi-GB model weights are deferred until the user clicks the button
# (see generate_and_export), keeping Space startup fast.
processor = AutoProcessor.from_pretrained(
    MODEL_ID, trust_remote_code=True, token=HF_TOKEN
)
# NOTE(review): `tokenizer` is loaded but not referenced elsewhere in this
# file — presumably kept for future use; confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, trust_remote_code=True, token=HF_TOKEN
)

# ─── Heavy work runs on button click ───────────────────────────────────────────
def generate_and_export():
    try:
        # 1) Lazy‑load the full FP16 model
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        device = next(model.parameters()).device

        # 2) Text→SOAP helper
        def to_soap(text: str) -> str:
            inputs = processor.apply_chat_template(
                [
                    {"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
                    {"role":"user",  "content":[{"type":"text","text":text}]}
                ],
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_dict=True
            ).to(device)
            out = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,
                top_p=0.95,
                temperature=0.1,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=False
            )
            prompt_len = inputs["input_ids"].shape[-1]
            return processor.batch_decode(
                out[:, prompt_len:], skip_special_tokens=True
            )[0].strip()

        # 3) Generate 20 doc notes + ground truths
        docs, gts = [], []
        for i in range(1, 21):
            doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
            docs.append(doc)
            gts.append(to_soap(doc))
            if i % 5 == 0:
                torch.cuda.empty_cache()

        # 4) Split into 15 train / 5 test
        df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
        train_df, test_df = train_test_split(df, test_size=5, random_state=42)

        os.makedirs("outputs", exist_ok=True)

        # 5) Inference on train split → outputs/inference.tsv
        train_preds = [to_soap(d) for d in train_df["doc_note"]]
        inf = train_df.reset_index(drop=True).copy()
        inf["id"]             = inf.index + 1
        inf["predicted_soap"] = train_preds
        inf[["id","ground_truth_soap","predicted_soap"]].to_csv(
            "outputs/inference.tsv", sep="\t", index=False
        )

        # 6) Inference on test split → outputs/eval.csv
        test_preds = [to_soap(d) for d in test_df["doc_note"]]
        pd.DataFrame({
            "id":             range(1, len(test_preds) + 1),
            "predicted_soap": test_preds
        }).to_csv("outputs/eval.csv", index=False)

        # 7) Return status + file paths for download
        return (
            "✅ Done with 20 notes (15 train / 5 test)!",
            "outputs/inference.tsv",
            "outputs/eval.csv"
        )

    except Exception as e:
        traceback.print_exc()
        return (f"❌ Error: {e}", None, None)

# ─── Gradio UI ─────────────────────────────────────────────────────────────────
# Single-button interface: one click runs the full generate-and-export
# pipeline and exposes the two result files for download.
with gr.Blocks() as demo:
    gr.Markdown("# Gemma‑3n SOAP Generator 🩺")

    run_button = gr.Button("Generate & Export 20 Notes")
    status_box = gr.Textbox(interactive=False, label="Status")
    inference_download = gr.File(label="Download inference.tsv")
    eval_download = gr.File(label="Download eval.csv")

    # No inputs: the pipeline is fully self-contained; outputs fill the
    # status box and the two download slots.
    run_button.click(
        generate_and_export,
        None,
        [status_box, inference_download, eval_download],
    )

if __name__ == "__main__":
    demo.launch()