# Hugging Face Space "notes" — app.py (uploaded by Bonosa2, revision 3314cdc, verified)
# NOTE: the lines below replace non-Python file-viewer chrome ("raw / history /
# blame", size 3.72 kB) captured when this file was scraped from the HF web UI.
# app.py
import os
import pandas as pd
import gradio as gr
import torch
from transformers import (
AutoProcessor,
AutoTokenizer,
AutoModelForImageTextToText
)
from sklearn.model_selection import train_test_split
# 1) Retrieve your HF_TOKEN from environment (set in Space Settings → Secrets)
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast at startup: every from_pretrained call below requires the token,
    # so a missing secret should abort before the UI ever comes up.
    raise RuntimeError("Missing HF_TOKEN env var! Please add it in your Space settings → Secrets.")

# Gated multimodal model id — the account behind HF_TOKEN must have been
# granted access for the downloads below to succeed.
MODEL_ID = "google/gemma-3n-e2b-it"
# 2) Eagerly load the small bits (processor & tokenizer) so the UI starts fast
# Only the lightweight preprocessing components are fetched here; the
# multi-GB model weights are deferred until the button handler runs.
processor = AutoProcessor.from_pretrained(
MODEL_ID, trust_remote_code=True, token=HF_TOKEN
)
# NOTE(review): `tokenizer` is never referenced later in this file —
# presumably kept for future use or external import; confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID, trust_remote_code=True, token=HF_TOKEN
)
def generate_all_and_split():
    """Button handler: lazily load the full model, then generate & save the dataset.

    Generates 100 synthetic doctor's notes plus ground-truth SOAP summaries,
    splits them 70/30, runs inference on both splits, and writes:
      - outputs/inference.tsv  (train split: id, ground truth, prediction)
      - outputs/eval.csv       (test split: id, prediction)

    Returns:
        str: human-readable status message shown in the Gradio Textbox.
    """
    # a) Lazy-load the 8-bit quantized model (heavy: multi-GB download).
    # NOTE(review): `load_in_8bit=` is deprecated in recent transformers in
    # favour of quantization_config=BitsAndBytesConfig(load_in_8bit=True);
    # kept as-is here for compatibility with the Space's pinned environment.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        token=HF_TOKEN,
        load_in_8bit=True,
        device_map="auto",
    )
    device = next(model.parameters()).device

    def to_soap(text: str) -> str:
        """Run one chat-style generation and return only the decoded completion."""
        inputs = processor.apply_chat_template(
            [
                {"role": "system", "content": [{"type": "text", "text": "You are a medical AI assistant."}]},
                {"role": "user", "content": [{"type": "text", "text": text}]},
            ],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(device)
        out = model.generate(
            **inputs,
            max_new_tokens=400,
            do_sample=True,
            top_p=0.95,
            temperature=0.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
        # Slice off the prompt tokens so only newly generated text is decoded.
        prompt_len = inputs["input_ids"].shape[-1]
        return processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0].strip()

    # b) Generate 100 doc_notes + ground_truth SOAPs
    docs, gts = [], []
    for i in range(1, 101):
        doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
        docs.append(doc)
        gts.append(to_soap(doc))
        # Periodically release cached GPU memory during the long loop.
        # Guarded so the handler also runs cleanly on a CPU-only Space.
        if i % 20 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

    # c) Split 70/30 (fixed seed so the split is reproducible across runs)
    df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
    train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
    os.makedirs("outputs", exist_ok=True)

    # d) Inference on train → inference.tsv
    train_preds = [to_soap(d) for d in train_df["doc_note"]]
    inf = train_df.reset_index(drop=True).copy()
    inf["id"] = inf.index + 1
    inf["predicted_soap"] = train_preds
    inf[["id", "ground_truth_soap", "predicted_soap"]].to_csv(
        "outputs/inference.tsv", sep="\t", index=False
    )

    # e) Inference on test → eval.csv
    test_preds = [to_soap(d) for d in test_df["doc_note"]]
    pd.DataFrame({
        "id": range(1, len(test_preds) + 1),
        "predicted_soap": test_preds,
    }).to_csv("outputs/eval.csv", index=False)

    return (
        "✅ Done!\n"
        "• outputs/inference.tsv (70 rows with id, GT, pred)\n"
        "• outputs/eval.csv (30 rows with id, pred)"
    )
# 3) Gradio UI — the heavy model load happens only inside the click handler,
#    so the interface itself appears almost instantly on Space startup.
with gr.Blocks() as demo:
    gr.Markdown("## Gemma‑3n SOAP Generator 🩺")
    run_button = gr.Button("Generate & Save 100 Notes → 70/30 Split → inference & eval")
    status_box = gr.Textbox(interactive=False, label="Status")
    # No inputs: the handler takes no arguments and returns a status string.
    run_button.click(fn=generate_all_and_split, inputs=None, outputs=status_box)

if __name__ == "__main__":
    demo.launch()