ogflash commited on
Commit
e2e4cf3
·
verified ·
1 Parent(s): 36792d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -23
app.py CHANGED
@@ -1,38 +1,49 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
 
4
 
5
  # Load model and tokenizer
6
- model_name = "ogflash/yelp_review_classifier"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
-
10
- # Inference function
11
- def classify(text):
12
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
13
 
14
- # Remove token_type_ids for DistilBERT compatibility
15
- if "token_type_ids" in inputs:
16
- del inputs["token_type_ids"]
 
 
 
17
 
 
 
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
- logits = outputs.logits
21
- predicted_class = torch.argmax(logits, dim=1).item()
22
- score = torch.softmax(logits, dim=1)[0][predicted_class].item()
 
 
 
23
 
24
- label = model.config.id2label[predicted_class]
25
- return f"{label} ({round(score * 100, 2)}%)"
 
 
 
 
26
 
27
  # Gradio UI
28
  with gr.Blocks() as demo:
29
- gr.Markdown("## 📝 Yelp Review Sentiment Classifier")
 
 
30
  with gr.Row():
31
- text_input = gr.Textbox(label="Enter Review", placeholder="Type your Yelp review...", lines=4)
 
32
  with gr.Row():
33
- submit_btn = gr.Button("Classify")
34
- output = gr.Textbox(label="Prediction")
35
 
36
- submit_btn.click(fn=classify, inputs=text_input, outputs=output)
37
 
38
- demo.launch()
 
 
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import gradio as gr
4
 
5
  # Load model and tokenizer
6
+ model_path = "model" # Your local fine-tuned model directory
7
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
 
 
 
 
9
 
10
+ # Define label mapping
11
+ id2label = model.config.id2label or {
12
+ 0: "Negative",
13
+ 1: "Neutral",
14
+ 2: "Positive"
15
+ }
16
 
17
+ def classify(text):
18
+ inputs = tokenizer(text, return_tensors="pt")
19
  with torch.no_grad():
20
  outputs = model(**inputs)
21
+ logits = outputs.logits
22
+ probs = torch.softmax(logits, dim=1)[0]
23
+
24
+ predicted_class = torch.argmax(probs).item()
25
+ label = id2label.get(predicted_class, f"LABEL_{predicted_class}")
26
+ confidence = round(float(probs[predicted_class]) * 100, 2)
27
 
28
+ all_probs = {
29
+ id2label.get(i, f"LABEL_{i}"): f"{round(float(prob)*100, 2)}%"
30
+ for i, prob in enumerate(probs)
31
+ }
32
+
33
+ return f"Prediction: {label} ({confidence}%)", all_probs
34
 
35
  # Gradio UI
36
  with gr.Blocks() as demo:
37
+ gr.Markdown("# Yelp Review Sentiment Classifier")
38
+ with gr.Row():
39
+ input_box = gr.Textbox(lines=4, label="Enter a review")
40
  with gr.Row():
41
+ output_label = gr.Textbox(label="Predicted Sentiment")
42
+ output_probs = gr.JSON(label="All Class Probabilities")
43
  with gr.Row():
44
+ classify_btn = gr.Button("Classify")
 
45
 
46
+ classify_btn.click(fn=classify, inputs=input_box, outputs=[output_label, output_probs])
47
 
48
+ if __name__ == "__main__":
49
+ demo.launch()