MMuzamilAI commited on
Commit
675ccd3
·
verified ·
1 Parent(s): a7a0c7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -9,17 +9,26 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
# Run on GPU when one is present; otherwise stay on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Label mapping taken straight from the model's config.
id2label = model.config.id2label
14
 
 
 
 
 
 
 
 
 
15
# Classification function
def classify_review(text):
    """Classify a review string and return the model's predicted label name."""
    # Tokenize and move the tensors onto the same device as the model.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    best_id = torch.argmax(logits).item()
    return id2label[best_id]
 
23
 
24
  # Gradio interface
25
  iface = gr.Interface(
 
9
# Run on GPU when one is present; otherwise stay on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Model's id2label from config
id2label = model.config.id2label

# Custom label mapping: config label name -> human-readable category.
label_map = {
    "label_0": "Graphical Issue",
    "label_1": "Network Issue",
    "label_2": "No Bug ✅",
    "label_3": "Performance Issue",
}
22
+
23
# Classification function
def classify_review(text):
    """Classify a review and return a human-readable issue category.

    Runs the sequence-classification model on ``text`` and maps the
    predicted config label name onto the friendly names in ``label_map``.
    Returns "Unknown" if the model's label is not in the mapping.
    """
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    ).to(device)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_label_id = torch.argmax(outputs.logits).item()
    hf_label = id2label[predicted_label_id]
    # BUG FIX: Transformers' default config.id2label uses uppercase names
    # ("LABEL_0", "LABEL_1", ...), while label_map's keys are lowercase —
    # the original lookup then always fell through to "Unknown". Normalize
    # case before the lookup so both spellings resolve correctly; already-
    # lowercase labels are unaffected (backward compatible).
    custom_label = label_map.get(hf_label.lower(), "Unknown")
    return custom_label
32
 
33
  # Gradio interface
34
  iface = gr.Interface(