MMuzamilAI commited on
Commit
4afc6b6
·
verified ·
1 Parent(s): 675ccd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -9,15 +9,12 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  model.to(device)
11
 
12
- # Model's id2label from config
13
- id2label = model.config.id2label
14
-
15
- # Custom label mapping
16
  label_map = {
17
- "label_0": "Graphical Issue",
18
- "label_1": "Network Issue",
19
- "label_2": "No Bug ✅",
20
- "label_3": "Performance Issue"
21
  }
22
 
23
  # Classification function
@@ -26,9 +23,7 @@ def classify_review(text):
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
  predicted_label_id = torch.argmax(outputs.logits).item()
29
- hf_label = id2label[predicted_label_id]
30
- custom_label = label_map.get(hf_label, "Unknown")
31
- return custom_label
32
 
33
  # Gradio interface
34
  iface = gr.Interface(
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  model.to(device)
11
 
12
+ # Custom label mapping using integer keys
 
 
 
13
  label_map = {
14
+ 0: "graphi",
15
+ 1: "netw",
16
+ 2: "no",
17
+ 3: "perf"
18
  }
19
 
20
  # Classification function
 
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
  predicted_label_id = torch.argmax(outputs.logits).item()
26
+ return label_map.get(predicted_label_id, "Unknown")
 
 
27
 
28
  # Gradio interface
29
  iface = gr.Interface(