"""Gradio demo: classify a piece of medical text into a MedQuad focus area.

At import time this script loads the tokenizer and base encoder for
``MODEL``, the label id -> name mapping from ``id2label.json``, the label
embeddings from ``label_embs.pt``, and the fine-tuned weights
(``pytorch_model.bin``) downloaded from the Hugging Face Hub. It then serves
a single-textbox Gradio interface backed by :func:`predict`.
"""

import json

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

from focus_area_model import LabelEmbCls

MODEL = "mihir-s/medquad_classify"

# Load tokenizer + base model (eval mode: inference only).
tok = AutoTokenizer.from_pretrained(MODEL)
base = AutoModel.from_pretrained(MODEL).eval()

# Load label data + weights.
# The context manager closes the file deterministically; the encoding is
# explicit so label names decode the same way on every platform.
with open("id2label.json", encoding="utf-8") as label_file:
    id2label = json.load(label_file)  # keys are looked up as str(class index)

# NOTE(review): torch.load unpickles its input. label_embs.pt ships with the
# app and pytorch_model.bin is fetched from the Hub; if both files hold only
# tensors, consider passing weights_only=True -- confirm before changing.
label_embs = torch.load("label_embs.pt", map_location="cpu")

# Load custom head on top of the base encoder, then restore trained weights.
model = LabelEmbCls(base, label_embs)
model_path = hf_hub_download(repo_id=MODEL, filename="pytorch_model.bin")
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()


# One-box input logic
def predict(text: str) -> str:
    """Return the predicted focus-area label for ``text``.

    Args:
        text: Free-form text from the Gradio textbox. Leading and trailing
            whitespace is ignored.

    Returns:
        The ``id2label`` entry for the highest-scoring class, or an empty
        string when ``text`` is ``None`` or blank (there is nothing to
        classify, so no label is reported).

    Raises:
        KeyError: If the predicted class index has no entry in ``id2label``.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ""

    # Truncate/pad to a fixed 256-token window, exactly as before.
    inputs = tok(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=256,
        padding="max_length",
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(**inputs)

    # Single input row: argmax over dim 1 yields one class index.
    class_index = logits.argmax(1).item()
    return id2label[str(class_index)]


# Gradio interface with 1 text box
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter your medical text for classification"),
    outputs=gr.Textbox(label="Predicted Focus Area"),
    title="🧠 MedQuad Focus-Area Classifier",
)
demo.launch()