MMuzamilAI commited on
Commit
a7a0c7a
·
verified ·
1 Parent(s): c316753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
- import joblib # Assuming label_encoder is saved as a .pkl file
5
 
6
  # Load model and tokenizer
7
  model_name = "mmuzamilai/distilbert-review-bug-classifier"
@@ -10,17 +9,17 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model.to(device)
12
 
13
- # Load label encoder
14
- label_encoder = joblib.load("label_encoder.pkl") # Adjust if you have another format
15
 
16
  # Classification function
17
  def classify_review(text):
18
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
19
  with torch.no_grad():
20
  outputs = model(**inputs)
21
- predicted_label = torch.argmax(outputs.logits).item()
22
- decoded_label = label_encoder.inverse_transform([predicted_label])[0]
23
- return decoded_label
24
 
25
  # Gradio interface
26
  iface = gr.Interface(
 
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
 
4
 
5
  # Load model and tokenizer
6
  model_name = "mmuzamilai/distilbert-review-bug-classifier"
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  model.to(device)
11
 
12
+ # Access label mapping directly from model config
13
+ id2label = model.config.id2label
14
 
15
  # Classification function
16
  def classify_review(text):
17
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
+ predicted_label_id = torch.argmax(outputs.logits).item()
21
+ predicted_label = id2label[predicted_label_id]
22
+ return predicted_label
23
 
24
  # Gradio interface
25
  iface = gr.Interface(