| |
|
|
| import os |
| import pandas as pd |
| import gradio as gr |
| import torch |
| from transformers import ( |
| AutoProcessor, |
| AutoTokenizer, |
| AutoModelForImageTextToText |
| ) |
| from sklearn.model_selection import train_test_split |
|
|
| |
| HF_TOKEN = os.environ.get("HF_TOKEN") |
| if not HF_TOKEN: |
| raise RuntimeError("Missing HF_TOKEN env var! Please add it in your Space settings → Secrets.") |
|
|
| MODEL_ID = "google/gemma-3n-e2b-it" |
|
|
| |
| processor = AutoProcessor.from_pretrained( |
| MODEL_ID, trust_remote_code=True, token=HF_TOKEN |
| ) |
| tokenizer = AutoTokenizer.from_pretrained( |
| MODEL_ID, trust_remote_code=True, token=HF_TOKEN |
| ) |
|
|
| def generate_all_and_split(): |
| """Called when the user clicks the button—loads full model, generates & saves files.""" |
| |
| model = AutoModelForImageTextToText.from_pretrained( |
| MODEL_ID, |
| trust_remote_code=True, |
| token=HF_TOKEN, |
| load_in_8bit=True, |
| device_map="auto" |
| ) |
| device = next(model.parameters()).device |
|
|
| def to_soap(text: str) -> str: |
| inputs = processor.apply_chat_template( |
| [ |
| {"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]}, |
| {"role":"user", "content":[{"type":"text","text":text}]} |
| ], |
| add_generation_prompt=True, |
| tokenize=True, |
| return_tensors="pt", |
| return_dict=True |
| ).to(device) |
| out = model.generate( |
| **inputs, |
| max_new_tokens=400, |
| do_sample=True, |
| top_p=0.95, |
| temperature=0.1, |
| pad_token_id=processor.tokenizer.eos_token_id |
| ) |
| prompt_len = inputs["input_ids"].shape[-1] |
| return processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0].strip() |
|
|
| |
| docs, gts = [], [] |
| for i in range(1, 101): |
| doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.") |
| docs.append(doc) |
| gts.append(to_soap(doc)) |
| if i % 20 == 0: |
| torch.cuda.empty_cache() |
|
|
| |
| df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts}) |
| train_df, test_df = train_test_split(df, test_size=0.3, random_state=42) |
|
|
| os.makedirs("outputs", exist_ok=True) |
|
|
| |
| train_preds = [to_soap(d) for d in train_df["doc_note"]] |
| inf = train_df.reset_index(drop=True).copy() |
| inf["id"] = inf.index + 1 |
| inf["predicted_soap"] = train_preds |
| inf[["id","ground_truth_soap","predicted_soap"]].to_csv( |
| "outputs/inference.tsv", sep="\t", index=False |
| ) |
|
|
| |
| test_preds = [to_soap(d) for d in test_df["doc_note"]] |
| pd.DataFrame({ |
| "id": range(1, len(test_preds)+1), |
| "predicted_soap": test_preds |
| }).to_csv("outputs/eval.csv", index=False) |
|
|
| return ( |
| "✅ Done!\n" |
| f"• outputs/inference.tsv (70 rows with id, GT, pred)\n" |
| f"• outputs/eval.csv (30 rows with id, pred)" |
| ) |
|
|
| |
| with gr.Blocks() as demo: |
| gr.Markdown("## Gemma‑3n SOAP Generator 🩺") |
| btn = gr.Button("Generate & Save 100 Notes → 70/30 Split → inference & eval") |
| status = gr.Textbox(interactive=False, label="Status") |
| btn.click(fn=generate_all_and_split, inputs=None, outputs=status) |
|
|
| if __name__ == "__main__": |
| demo.launch() |
|
|