tiya1012 commited on
Commit
1bdac09
·
verified ·
1 Parent(s): f0b590f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -6,12 +6,11 @@ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  processor = AutoImageProcessor.from_pretrained("tiya1012/vit-accident-image")
7
  model = AutoModelForImageClassification.from_pretrained("tiya1012/vit-accident-image")
8
 
9
- # Get class labels from model configuration (if available)
10
- if model.config.id2label:
11
- class_labels = [model.config.id2label[i] for i in range(len(model.config.id2label))]
12
- else:
13
- # Fallback labels if config is missing (adjust as necessary)
14
- class_labels = ["No Accident", "Accident Detected"]
15
 
16
  # Define the classification function
17
  def classify_accident_image(image):
@@ -30,14 +29,17 @@ def classify_accident_image(image):
30
  # Debug: Print logits for analysis
31
  print("Logits:", logits)
32
 
33
- # Get the predicted class index
34
  probabilities = torch.softmax(logits, dim=1)[0] # Softmax to get probabilities
35
  predicted_class_idx = torch.argmax(probabilities).item()
36
  print("Predicted Class Index:", predicted_class_idx)
37
  print("Probabilities:", probabilities)
38
 
39
- # Get the predicted class label and confidence
40
- predicted_label = class_labels[predicted_class_idx]
 
 
 
41
  confidence = probabilities[predicted_class_idx].item() * 100
42
 
43
  # Format the result
 
6
  processor = AutoImageProcessor.from_pretrained("tiya1012/vit-accident-image")
7
  model = AutoModelForImageClassification.from_pretrained("tiya1012/vit-accident-image")
8
 
9
+ # Define a label mapping for `LABEL_0` and `LABEL_1`
10
+ label_mapping = {
11
+ "LABEL_0": "No Accident",
12
+ "LABEL_1": "Accident Detected"
13
+ }
 
14
 
15
  # Define the classification function
16
  def classify_accident_image(image):
 
29
  # Debug: Print logits for analysis
30
  print("Logits:", logits)
31
 
32
+ # Get the predicted class index and label
33
  probabilities = torch.softmax(logits, dim=1)[0] # Softmax to get probabilities
34
  predicted_class_idx = torch.argmax(probabilities).item()
35
  print("Predicted Class Index:", predicted_class_idx)
36
  print("Probabilities:", probabilities)
37
 
38
+ # Map the model's label to human-readable label using label_mapping
39
+ predicted_label_key = model.config.id2label[predicted_class_idx]
40
+ predicted_label = label_mapping.get(predicted_label_key, "Unknown")
41
+
42
+ # Get the confidence score
43
  confidence = probabilities[predicted_class_idx].item() * 100
44
 
45
  # Format the result