ogflash committed on
Commit
3dbfb13
·
verified ·
1 Parent(s): e2e4cf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -41
app.py CHANGED
@@ -2,48 +2,43 @@ import torch
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import gradio as gr
4
 
5
# Load the fine-tuned model and tokenizer from a local directory.
model_path = "model"  # Your local fine-tuned model directory
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Class-id -> human-readable label mapping; fall back to a default
# three-class sentiment scheme when the config carries no mapping.
id2label = model.config.id2label or {
    0: "Negative",
    1: "Neutral",
    2: "Positive",
}
16
 
 
17
def classify(text):
    """Classify *text* and return ``(summary, per-class probabilities)``.

    Uses the module-level ``tokenizer``, ``model`` and ``id2label``.
    Returns a ``"Prediction: <label> (<pct>%)"`` string plus a dict
    mapping every label name to its probability as a percent string.
    """
    encoded = tokenizer(text, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        probs = torch.softmax(model(**encoded).logits, dim=1)[0]

    best = torch.argmax(probs).item()
    label = id2label.get(best, f"LABEL_{best}")
    confidence = round(float(probs[best]) * 100, 2)

    # Percentage string for every class, not just the winner.
    all_probs = {}
    for idx, prob in enumerate(probs):
        all_probs[id2label.get(idx, f"LABEL_{idx}")] = f"{round(float(prob)*100, 2)}%"

    return f"Prediction: {label} ({confidence}%)", all_probs
34
-
35
# Gradio Blocks front-end: a review textbox in, the predicted sentiment
# plus the full per-class probability breakdown out.
with gr.Blocks() as demo:
    gr.Markdown("# Yelp Review Sentiment Classifier")
    with gr.Row():
        input_box = gr.Textbox(lines=4, label="Enter a review")
    with gr.Row():
        output_label = gr.Textbox(label="Predicted Sentiment")
        output_probs = gr.JSON(label="All Class Probabilities")
    with gr.Row():
        classify_btn = gr.Button(value="Classify")

    # Wire the button to the classifier; two outputs match the two returns.
    classify_btn.click(classify, inputs=input_box, outputs=[output_label, output_probs])

if __name__ == "__main__":
    demo.launch()
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import gradio as gr
4
 
5
# Load model & tokenizer from the Hugging Face Hub (or a local path).
model_name = "ogflash/yelp_review_classifier"  # Change if needed
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
 
 
 
 
9
 
10
# DistilBERT-style models take no token_type_ids; strip them before forward().
def classify(text):
    """Classify *text* with the module-level ``model``/``tokenizer``.

    Returns a string like ``"Positive (97.31%)"`` — the predicted label
    with its softmax confidence as a percentage.
    """
    import inspect  # local import: only needed for the signature probe below

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Drop token_type_ids when the model's forward() does not accept them.
    # inspect.signature follows __wrapped__, so this also works when forward
    # has been wrapped (hooks, decorators) — unlike peeking at
    # model.forward.__code__.co_varnames, which raises AttributeError for
    # bound builtins and reads the wrapper's code object otherwise.
    try:
        supports_tti = "token_type_ids" in inspect.signature(model.forward).parameters
    except (TypeError, ValueError):
        supports_tti = True  # cannot introspect — leave the inputs untouched
    if "token_type_ids" in inputs and not supports_tti:
        del inputs["token_type_ids"]

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    top_class = torch.argmax(probs, dim=1).item()
    confidence = probs[0][top_class].item() * 100

    # Label mapping: fall back to a default 3-class scheme when the config
    # carries none (empty/None or not a dict).
    id2label = model.config.id2label
    if not isinstance(id2label, dict) or not id2label:
        id2label = {
            0: "Negative",
            1: "Neutral",
            2: "Positive",
        }

    # Configs deserialized from JSON may key id2label by str ("0"), not int —
    # try both before falling back to a generic LABEL_<n> name.
    label_name = id2label.get(top_class, id2label.get(str(top_class), f"LABEL_{top_class}"))
    return f"{label_name} ({confidence:.2f}%)"
34
+
35
# Gradio front-end: one textbox in, the formatted sentiment string out.
review_input = gr.Textbox(lines=3, placeholder="Enter text to analyze...")
iface = gr.Interface(
    fn=classify,
    inputs=review_input,
    outputs="text",
    title="Sentiment Classifier",
    description="Predicts sentiment using a BERT-based model.",
)

iface.launch(share=True)