MonaHamid committed on
Commit
d06f60f
·
verified ·
1 Parent(s): b182e6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -2,17 +2,25 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
3
  import gradio as gr
4
 
 
5
  model_dir = "saved_model"
6
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
7
- model = AutoModelForSequenceClassification.from_pretrained(model_dir)
8
 
 
 
 
 
 
 
 
 
9
  def classify(text):
10
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
11
  outputs = model(**inputs)
12
  probs = torch.softmax(outputs.logits, dim=1)
13
- labels = ["toxic", "non-toxic"] # <-- corrected label order
14
  return {labels[i]: float(probs[0][i]) for i in range(len(labels))}
15
 
 
16
  gr.Interface(fn=classify, inputs="text", outputs="label").launch()
17
 
18
-
 
2
  import torch
3
  import gradio as gr
4
 
5
# Checkpoint directory produced by the fine-tuning run.
model_dir = "saved_model"

# Tokenizer and model are loaded once at import time and shared by every
# request handled by the Gradio app.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Pin the label mappings explicitly so prediction indices always resolve to
# the intended class names (0 -> non-toxic, 1 -> toxic).
model = AutoModelForSequenceClassification.from_pretrained(
    model_dir,
    id2label={0: "non-toxic", 1: "toxic"},
    label2id={"non-toxic": 0, "toxic": 1},
)
15
+
16
def classify(text):
    """Classify *text* as toxic / non-toxic.

    Parameters
    ----------
    text : str
        Free-form input text from the Gradio textbox.

    Returns
    -------
    dict[str, float]
        Mapping of each class label to its softmax probability, the shape
        Gradio's "label" output component expects.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Inference only: disable gradient tracking to avoid building an
    # autograd graph per request.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)
    # Read labels from the model config instead of a hard-coded list so the
    # order always matches the id2label mapping set when the model was loaded.
    return {model.config.id2label[i]: float(probs[0][i]) for i in range(probs.shape[1])}
23
 
24
# Wire the classifier into a simple web UI (text in, label probabilities out)
# and start serving.
demo = gr.Interface(fn=classify, inputs="text", outputs="label")
demo.launch()
26