Zynaly commited on
Commit
687791d
·
verified ·
1 Parent(s): b74bf68

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+ # ---------------------------------------------------
6
+ # 1. Load Model + Tokenizer
7
+ # ---------------------------------------------------
8
+ model_path = r"D:\\Assignment\\fine_tuned_deberta_model"
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
11
+ model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True)
12
+
13
+ model.eval()
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ model.to(device)
16
+
17
+ labels = ["Entailment", "Neutral", "Contradiction"]
18
+
19
+ # ---------------------------------------------------
20
+ # 2. Prediction Function
21
+ # ---------------------------------------------------
22
+ def predict_nli(premise, hypothesis):
23
+ if not premise.strip() or not hypothesis.strip():
24
+ return "Error: Inputs cannot be empty.", None
25
+
26
+ text = premise + " [SEP] " + hypothesis
27
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
28
+
29
+ with torch.no_grad():
30
+ logits = model(**inputs).logits
31
+ probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
32
+
33
+ pred_idx = int(probs.argmax())
34
+ pred_label = labels[pred_idx]
35
+
36
+ prob_dict = {labels[i]: f"{probs[i] * 100:.2f}%" for i in range(len(labels))}
37
+
38
+ return pred_label, prob_dict
39
+
40
+ # ---------------------------------------------------
41
+ # 3. Gradio Interface
42
+ # ---------------------------------------------------
43
+ interface = gr.Interface(
44
+ fn=predict_nli,
45
+ inputs=[
46
+ gr.Textbox(label="Premise", placeholder="Enter the premise sentence..."),
47
+ gr.Textbox(label="Hypothesis", placeholder="Enter the hypothesis sentence..."),
48
+ ],
49
+ outputs=[
50
+ gr.Label(label="Prediction"),
51
+ gr.JSON(label="Confidence Scores (%)"),
52
+ ],
53
+ title="NLI Model (DeBERTa + CCT Ready)",
54
+ description="Enter a premise and hypothesis to get Entailment, Neutral, or Contradiction with confidence scores.",
55
+ flagging_mode="never" # <-- UPDATED
56
+ )
57
+
58
+ interface.launch(server_name="0.0.0.0", server_port=7860)