TangYiJay commited on
Commit
637a6ad
·
verified ·
1 Parent(s): 91c176f
Files changed (1) hide show
  1. app.py +12 -28
app.py CHANGED
@@ -8,13 +8,7 @@ from collections import Counter
8
  MODEL_LIST = [
9
  "prithivMLmods/Trash-Net",
10
  "yangy50/garbage-classification",
11
- "eunoiawiira-vgg-realwaste-classification"
12
- ]
13
-
14
- PRIORITY_ORDER = [
15
- "yangy50/garbage-classification",
16
- "eunoiawiira-vgg-realwaste-classification",
17
- "prithivMLmods/Trash-Net"
18
  ]
19
 
20
  models = []
@@ -22,7 +16,6 @@ processors = []
22
  devices = []
23
 
24
  print("Loading models...")
25
-
26
  for model_name in MODEL_LIST:
27
  try:
28
  processor = AutoImageProcessor.from_pretrained(model_name)
@@ -37,7 +30,7 @@ for model_name in MODEL_LIST:
37
  devices.append(next(model.parameters()).device)
38
  print(f"Loaded: {model_name}")
39
  except Exception as e:
40
- print(f"Failed to load {model_name}, error: {e}")
41
 
42
  def classify_image(image: Image.Image):
43
  results = {}
@@ -51,24 +44,15 @@ def classify_image(image: Image.Image):
51
  label = model.config.id2label[pred]
52
  results[model_name] = label
53
  except Exception as e:
54
- results[model_name] = f"error: {e}"
55
 
 
56
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
57
 
 
58
  valid_labels = [lbl for lbl in results.values() if not lbl.startswith("error")]
59
- label_counts = Counter(valid_labels)
60
-
61
- if len(label_counts) == 0:
62
- final_label = "Unknown"
63
- elif len(label_counts) == 1:
64
- final_label = valid_labels[0]
65
- else:
66
- for model_name in PRIORITY_ORDER:
67
- if model_name in results and not results[model_name].startswith("error"):
68
- final_label = results[model_name]
69
- break
70
-
71
- results_text += f"\n\nFinal Label: {final_label}"
72
  return results_text
73
 
74
  iface = gr.Interface(
@@ -77,11 +61,11 @@ iface = gr.Interface(
77
  outputs=[gr.Textbox(label="Model Predictions")],
78
  title="Multi-Model Trash Classification",
79
  description=(
80
- "Uploads an image and classifies trash using three models.\n"
81
- "All model predictions are displayed and the final label is selected by priority:\n"
82
- "1. yangy50/garbage-classification\n"
83
- "2. eunoiawiira-vgg-realwaste-classification\n"
84
- "3. prithivMLmods/Trash-Net"
85
  )
86
  )
87
 
 
8
  MODEL_LIST = [
9
  "prithivMLmods/Trash-Net",
10
  "yangy50/garbage-classification",
11
+ "eunoiawiira/vgg-realwaste-classification"
 
 
 
 
 
 
12
  ]
13
 
14
  models = []
 
16
  devices = []
17
 
18
  print("Loading models...")
 
19
  for model_name in MODEL_LIST:
20
  try:
21
  processor = AutoImageProcessor.from_pretrained(model_name)
 
30
  devices.append(next(model.parameters()).device)
31
  print(f"Loaded: {model_name}")
32
  except Exception as e:
33
+ print(f"Failed to load {model_name}: {e}")
34
 
35
  def classify_image(image: Image.Image):
36
  results = {}
 
44
  label = model.config.id2label[pred]
45
  results[model_name] = label
46
  except Exception as e:
47
+ results[model_name] = f"error:{e}"
48
 
49
+ # 格式化输出每个模型的结果
50
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
51
 
52
+ # 投票法计算最终标签
53
  valid_labels = [lbl for lbl in results.values() if not lbl.startswith("error")]
54
+ final_label = Counter(valid_labels).most_common(1)[0][0] if valid_labels else "Unknown"
55
+ results_text += f"\n\nFinal Label (voting): {final_label}"
 
 
 
 
 
 
 
 
 
 
 
56
  return results_text
57
 
58
  iface = gr.Interface(
 
61
  outputs=[gr.Textbox(label="Model Predictions")],
62
  title="Multi-Model Trash Classification",
63
  description=(
64
+ "Upload an image, and the following models will classify it:\n"
65
+ "1. prithivMLmods/Trash-Net\n"
66
+ "2. yangy50/garbage-classification\n"
67
+ "3. eunoiawiira/vgg-realwaste-classification\n"
68
+ "The final label is determined by majority vote."
69
  )
70
  )
71