| | import json, torch, gradio as gr |
| | from transformers import AutoTokenizer, AutoModel |
| | from focus_area_model import LabelEmbCls |
| | from huggingface_hub import hf_hub_download |
| |
|
# Hugging Face Hub repository id. The tokenizer, the base encoder, and the
# pytorch_model.bin checkpoint are all fetched from this one repo.
MODEL = "mihir-s/medquad_classify"

tok = AutoTokenizer.from_pretrained(MODEL)

# Base encoder, switched to evaluation mode (eval() mutates in place and
# returns the module itself, so no reassignment is needed).
base = AutoModel.from_pretrained(MODEL)
base.eval()
| |
|
| | |
# Class-index -> focus-area label mapping. JSON object keys are always
# strings, which is why predict() looks labels up by str(index).
# Use a context manager so the file handle is closed deterministically, and
# pin the encoding so label text does not depend on the platform locale.
with open("id2label.json", encoding="utf-8") as id2label_file:
    id2label = json.load(id2label_file)

# Precomputed label embeddings passed to LabelEmbCls; loaded onto the CPU
# regardless of the device they were saved from.
label_embs = torch.load("label_embs.pt", map_location="cpu")
| |
|
| | |
# Build the classifier from the base encoder and the label embeddings, then
# overwrite its parameters with the checkpoint stored in the same Hub repo.
model = LabelEmbCls(base, label_embs)
model_path = hf_hub_download(repo_id=MODEL, filename="pytorch_model.bin")
# NOTE(review): torch.load unpickles the downloaded file. On torch versions
# whose default is weights_only=False this can execute arbitrary code from
# the checkpoint; consider passing weights_only=True if the installed torch
# supports it.
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # inference only: disable dropout and similar training behaviour
| |
|
| | |
def predict(text):
    """Classify *text* and return its predicted focus-area label.

    Args:
        text: Free-form medical text from the Gradio textbox. ``None`` is
            treated as an empty string instead of raising.

    Returns:
        The ``id2label`` entry for the highest-scoring class.
    """
    # Guard against a None payload (e.g. an API client sending null); the
    # original text.strip() raised AttributeError in that case.
    cleaned = (text or "").strip()
    inputs = tok(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=256,
        padding="max_length",  # every request is padded to 256 tokens
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(**inputs)
    # argmax over dim 1 followed by .item() presumes a single-row
    # (1, num_labels) output for this single input -- TODO confirm.
    predicted_index = logits.argmax(1).item()
    # JSON object keys are strings, so look up by the stringified index.
    return id2label[str(predicted_index)]
| |
|
| | |
# Single-textbox web UI: the user's text goes to predict(), and the returned
# focus-area label is shown in the output textbox.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter your medical text for classification"),
    outputs=gr.Textbox(label="Predicted Focus Area"),
    title="🧠 MedQuad Focus-Area Classifier",
)
demo.launch()