atharvasc27112001 commited on
Commit
eb39256
·
verified ·
1 Parent(s): 96bb33a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, torch, gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel
3
+ from focus_area_model import LabelEmbCls
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ MODEL = "mihir-s/medquad_classify"
7
+
8
+ # Load tokenizer + base model
9
+ tok = AutoTokenizer.from_pretrained(MODEL)
10
+ base = AutoModel.from_pretrained(MODEL).eval()
11
+
12
+ # Load label data + weights
13
+ id2label = json.load(open("id2label.json"))
14
+ label_embs = torch.load("label_embs.pt", map_location="cpu")
15
+
16
+ # Load custom head
17
+ model = LabelEmbCls(base, label_embs)
18
+ model_path = hf_hub_download(repo_id=MODEL, filename="pytorch_model.bin")
19
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
20
+ model.eval()
21
+
22
+ # ✅ One-box input logic
23
+ def predict(text):
24
+ inputs = tok(text.strip(), return_tensors="pt", truncation=True, max_length=256, padding="max_length")
25
+ with torch.no_grad():
26
+ logits = model(**inputs)
27
+ return id2label[str(logits.argmax(1).item())]
28
+
29
+ # Gradio interface with 1 text box
30
+ gr.Interface(
31
+ fn=predict,
32
+ inputs=gr.Textbox(label="Enter your medical text for classification"),
33
+ outputs=gr.Textbox(label="Predicted Focus Area"),
34
+ title="🧠 MedQuad Focus-Area Classifier"
35
+ ).launch()