TangYiJay commited on
Commit
91c176f
·
verified ·
1 Parent(s): 606734b
Files changed (1) hide show
  1. app.py +30 -18
app.py CHANGED
@@ -5,19 +5,23 @@ from PIL import Image
5
  import torch
6
  from collections import Counter
7
 
8
- # ---------------- 模型列表 ----------------
9
  MODEL_LIST = [
10
  "prithivMLmods/Trash-Net",
11
  "yangy50/garbage-classification",
12
  "eunoiawiira-vgg-realwaste-classification"
13
  ]
14
 
15
- # ---------------- 加载模型 ----------------
 
 
 
 
 
16
  models = []
17
  processors = []
18
  devices = []
19
 
20
- print("🔹 正在加载模型,请稍等...")
21
 
22
  for model_name in MODEL_LIST:
23
  try:
@@ -31,11 +35,10 @@ for model_name in MODEL_LIST:
31
  processors.append(processor)
32
  models.append(model)
33
  devices.append(next(model.parameters()).device)
34
- print(f"✅ 加载完成: {model_name}")
35
  except Exception as e:
36
- print(f" 加载失败: {model_name}, 错误: {e}")
37
 
38
- # ---------------- 推理函数 ----------------
39
  def classify_image(image: Image.Image):
40
  results = {}
41
  for model_name, processor, model, device in zip(MODEL_LIST, processors, models, devices):
@@ -52,24 +55,33 @@ def classify_image(image: Image.Image):
52
 
53
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
54
 
55
- # 计算最终标签(投票法)
56
  valid_labels = [lbl for lbl in results.values() if not lbl.startswith("error")]
57
- final_label = Counter(valid_labels).most_common(1)[0][0] if valid_labels else "Unknown"
58
- results_text += f"\n\n最终标签: {final_label}"
 
 
 
 
 
 
 
 
 
 
 
59
  return results_text
60
 
61
- # ---------------- Gradio 界面 ----------------
62
  iface = gr.Interface(
63
  fn=classify_image,
64
- inputs=gr.Image(type="pil", label="上传图片"),
65
- outputs=[gr.Textbox(label="模型预测结果")],
66
- title="垃圾分类多模型检测",
67
  description=(
68
- "上传图片后,使用以下模型进行垃圾分类,每个模型结果单独输出:\n"
69
- "1. prithivMLmods/Trash-Net\n"
70
- "2. yangy50/garbage-classification\n"
71
- "3. eunoiawiira-vgg-realwaste-classification\n"
72
- "最终标签通过投票法生成"
73
  )
74
  )
75
 
 
5
  import torch
6
  from collections import Counter
7
 
 
8
  MODEL_LIST = [
9
  "prithivMLmods/Trash-Net",
10
  "yangy50/garbage-classification",
11
  "eunoiawiira-vgg-realwaste-classification"
12
  ]
13
 
14
+ PRIORITY_ORDER = [
15
+ "yangy50/garbage-classification",
16
+ "eunoiawiira-vgg-realwaste-classification",
17
+ "prithivMLmods/Trash-Net"
18
+ ]
19
+
20
  models = []
21
  processors = []
22
  devices = []
23
 
24
+ print("Loading models...")
25
 
26
  for model_name in MODEL_LIST:
27
  try:
 
35
  processors.append(processor)
36
  models.append(model)
37
  devices.append(next(model.parameters()).device)
38
+ print(f"Loaded: {model_name}")
39
  except Exception as e:
40
+ print(f"Failed to load {model_name}, error: {e}")
41
 
 
42
  def classify_image(image: Image.Image):
43
  results = {}
44
  for model_name, processor, model, device in zip(MODEL_LIST, processors, models, devices):
 
55
 
56
  results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
57
 
 
58
  valid_labels = [lbl for lbl in results.values() if not lbl.startswith("error")]
59
+ label_counts = Counter(valid_labels)
60
+
61
+ if len(label_counts) == 0:
62
+ final_label = "Unknown"
63
+ elif len(label_counts) == 1:
64
+ final_label = valid_labels[0]
65
+ else:
66
+ for model_name in PRIORITY_ORDER:
67
+ if model_name in results and not results[model_name].startswith("error"):
68
+ final_label = results[model_name]
69
+ break
70
+
71
+ results_text += f"\n\nFinal Label: {final_label}"
72
  return results_text
73
 
 
74
  iface = gr.Interface(
75
  fn=classify_image,
76
+ inputs=gr.Image(type="pil", label="Upload Image"),
77
+ outputs=[gr.Textbox(label="Model Predictions")],
78
+ title="Multi-Model Trash Classification",
79
  description=(
80
+ "Uploads an image and classifies trash using three models.\n"
81
+ "All model predictions are displayed and the final label is selected by priority:\n"
82
+ "1. yangy50/garbage-classification\n"
83
+ "2. eunoiawiira-vgg-realwaste-classification\n"
84
+ "3. prithivMLmods/Trash-Net"
85
  )
86
  )
87