ogflash commited on
Commit
c3bea6e
·
verified ·
1 Parent(s): 3dbfb13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -39
app.py CHANGED
@@ -1,44 +1,42 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import gradio as gr
 
 
4
 
5
- # Load model & tokenizer from HF or local path
6
- model_name = "ogflash/yelp_review_classifier" # Change if needed
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
- # Fix for DistilBERT models that don't accept token_type_ids
11
  def classify(text):
12
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
13
- # Remove token_type_ids if not supported
14
- if "token_type_ids" in inputs and "token_type_ids" not in model.forward.__code__.co_varnames:
15
- del inputs["token_type_ids"]
16
-
17
- with torch.no_grad():
18
- outputs = model(**inputs)
19
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
20
- top_class = torch.argmax(probs, dim=1).item()
21
- confidence = probs[0][top_class].item() * 100
22
-
23
- # Reliable label mapping
24
- id2label = model.config.id2label
25
- if not id2label or not isinstance(id2label, dict) or len(id2label) == 0:
26
- id2label = {
27
- 0: "Negative",
28
- 1: "Neutral",
29
- 2: "Positive"
30
- }
31
-
32
- label_name = id2label.get(top_class, f"LABEL_{top_class}")
33
- return f"{label_name} ({confidence:.2f}%)"
34
-
35
- # UI with Gradio
36
- iface = gr.Interface(
37
- fn=classify,
38
- inputs=gr.Textbox(lines=3, placeholder="Enter text to analyze..."),
39
- outputs="text",
40
- title="Sentiment Classifier",
41
- description="Predicts sentiment using a BERT-based model.",
42
- )
43
-
44
- iface.launch(share=True)
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
 
5
+ # Load the tokenizer and model from local path (or HF if internet is available)
6
+ model = AutoModelForSequenceClassification.from_pretrained("ogflash/yelp_review_classifier")
7
+ tokenizer = AutoTokenizer.from_pretrained("ogflash/yelp_review_classifier")
 
8
 
9
+ # Prediction function
10
  def classify(text):
11
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
12
+
13
+ # Remove token_type_ids if using DistilBERT
14
+ if "token_type_ids" in inputs:
15
+ inputs.pop("token_type_ids")
16
+
17
+ outputs = model(**inputs)
18
+ logits = outputs.logits
19
+ predicted_class_id = torch.argmax(logits, dim=1).item()
20
+ score = torch.softmax(logits, dim=1)[0][predicted_class_id].item()
21
+
22
+ # Map labels using if-elif-else
23
+ label = f"LABEL_{predicted_class_id}"
24
+ if label == "LABEL_0":
25
+ label_name = "Negative"
26
+ elif label == "LABEL_1":
27
+ label_name = "Neutral"
28
+ elif label == "LABEL_2":
29
+ label_name = "Positive"
30
+ else:
31
+ label_name = label # fallback
32
+
33
+ return f"{label_name} ({score * 100:.2f}%)"
34
+
35
+ # Gradio UI
36
+ iface = gr.Interface(fn=classify,
37
+ inputs=gr.Textbox(lines=2, placeholder="Enter your review here..."),
38
+ outputs="text",
39
+ title="Sentiment Classifier",
40
+ description="Classifies text into Positive, Neutral, or Negative.")
41
+
42
+ iface.launch()