|
|
|
|
|
|
|
|
import os |
|
|
import traceback |
|
|
import pandas as pd |
|
|
import torch |
|
|
import gradio as gr |
|
|
from transformers import ( |
|
|
logging, |
|
|
AutoProcessor, |
|
|
AutoTokenizer, |
|
|
AutoModelForImageTextToText |
|
|
) |
|
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
|
|
|
# Keep transformers quiet in the Space log: only errors are worth surfacing.
logging.set_verbosity_error()

# Gemma is a gated model, so an HF access token is mandatory — fail fast if absent.
if not (HF_TOKEN := os.environ.get("HF_TOKEN")):
    raise RuntimeError("Missing HF_TOKEN in env vars – set it under Space Settings → Secrets")

MODEL_ID = "google/gemma-3n-e2b-it"
|
|
|
|
|
|
|
|
# Shared preprocessing objects, loaded once at start-up so every request reuses them.
# NOTE(review): `tokenizer` is not referenced anywhere in this view — presumably kept
# for future use; confirm before removing.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
|
|
|
|
|
|
|
|
def generate_and_export(num_notes: int = 20, test_size: int = 5):
    """Generate synthetic doctor notes with Gemma, convert them to SOAP format,
    split into train/test and export both splits as downloadable files.

    Args:
        num_notes: total number of synthetic progress notes to generate.
        test_size: how many notes are held out for the eval split.

    Returns:
        Tuple of (status message, path to inference.tsv, path to eval.csv);
        on failure the two paths are None and the message carries the error.
    """
    model = None
    try:
        # Loaded per request rather than at import time; released in `finally`
        # so repeated button clicks don't accumulate model weights in memory.
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        device = next(model.parameters()).device

        def to_soap(text: str) -> str:
            """Run one chat turn through the model and return the decoded reply.

            NOTE(review): the same system prompt serves both note generation and
            SOAP conversion; for conversion the user text is just the raw note —
            confirm the model reliably infers the SOAP task from it.
            """
            inputs = processor.apply_chat_template(
                [
                    {"role": "system", "content": [{"type": "text", "text": "You are a medical AI assistant."}]},
                    {"role": "user", "content": [{"type": "text", "text": text}]},
                ],
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
            ).to(device)
            out = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,
                top_p=0.95,
                temperature=0.1,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=False,  # trades generation speed for lower peak memory
            )
            # Drop the prompt tokens: only the newly generated tail is the answer.
            prompt_len = inputs["input_ids"].shape[-1]
            return processor.batch_decode(
                out[:, prompt_len:], skip_special_tokens=True
            )[0].strip()

        # Generate the raw notes and their SOAP "ground truth" counterparts.
        docs, gts = [], []
        for i in range(1, num_notes + 1):
            doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
            docs.append(doc)
            gts.append(to_soap(doc))
            # Periodically release cached CUDA blocks to keep peak memory down.
            if i % 5 == 0:
                torch.cuda.empty_cache()

        df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
        # Fixed seed keeps the split reproducible across runs.
        train_df, test_df = train_test_split(df, test_size=test_size, random_state=42)

        os.makedirs("outputs", exist_ok=True)

        # Train split: id + ground truth + fresh model prediction, tab-separated.
        train_preds = [to_soap(d) for d in train_df["doc_note"]]
        inf = train_df.reset_index(drop=True).copy()
        inf["id"] = inf.index + 1
        inf["predicted_soap"] = train_preds
        inf[["id", "ground_truth_soap", "predicted_soap"]].to_csv(
            "outputs/inference.tsv", sep="\t", index=False
        )

        # Eval split: predictions only (ground truth deliberately withheld).
        test_preds = [to_soap(d) for d in test_df["doc_note"]]
        pd.DataFrame({
            "id": range(1, len(test_preds) + 1),
            "predicted_soap": test_preds,
        }).to_csv("outputs/eval.csv", index=False)

        return (
            f"✅ Done with {num_notes} notes ({num_notes - test_size} train / {test_size} test)!",
            "outputs/inference.tsv",
            "outputs/eval.csv",
        )

    except Exception as e:
        # UI boundary: surface the error in the status box instead of crashing.
        traceback.print_exc()
        return (f"❌ Error: {e}", None, None)

    finally:
        # Free the model so repeated clicks don't leak GPU/CPU memory.
        if model is not None:
            del model
            torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
# ── Gradio UI ──────────────────────────────────────────────────────────────
# `demo` must keep this name: Hugging Face Spaces discovers it at module level.
with gr.Blocks() as demo:
    gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
    run_button = gr.Button("Generate & Export 20 Notes")
    status_box = gr.Textbox(interactive=False, label="Status")
    inference_download = gr.File(label="Download inference.tsv")
    eval_download = gr.File(label="Download eval.csv")

    run_button.click(
        fn=generate_and_export,
        inputs=None,
        outputs=[status_box, inference_download, eval_download],
    )

if __name__ == "__main__":
    demo.launch()
|
|
|