TangYiJay commited on
Commit
46f7c46
·
verified ·
1 Parent(s): c2ca069
Files changed (1) hide show
  1. app.py +23 -45
app.py CHANGED
@@ -1,64 +1,42 @@
1
- # app.py
2
  import gradio as gr
3
  from transformers import AutoModelForImageClassification, AutoImageProcessor
4
  from PIL import Image
5
  import torch
6
 
7
- # ---------------- 配置模型列表 ----------------
8
- MODEL_LIST = [
9
- "prithivMLmods/Trash-Net",
10
  "yangy50/garbage-classification",
11
- "ahmzakif/TrashNet-Classification"
 
12
  ]
13
 
14
- # ---------------- 加载模型 ----------------
15
- models = []
16
  processors = []
17
- loaded_model_names = []
18
- print("🔹 正在加载模型,请稍等...")
19
-
20
- for model_name in MODEL_LIST:
21
- try:
22
- processor = AutoImageProcessor.from_pretrained(model_name)
23
- model = AutoModelForImageClassification.from_pretrained(model_name)
24
- model.eval()
25
- processors.append(processor)
26
- models.append(model)
27
- loaded_model_names.append(model_name)
28
- print(f"✅ 加载成功: {model_name}")
29
- except Exception as e:
30
- print(f"❌ 加载失败: {model_name}, 错误: {e}")
31
 
32
- # ---------------- 推理函数 ----------------
33
  def classify_image(image: Image.Image):
34
  results = {}
35
- for model_name, processor, model in zip(loaded_model_names, processors, models):
36
- try:
37
- inputs = processor(images=image, return_tensors="pt")
38
- with torch.no_grad():
39
- outputs = model(**inputs)
40
- pred = outputs.logits.argmax(-1).item()
41
- if hasattr(model.config, "id2label"):
42
- label = model.config.id2label[pred]
43
- else:
44
- label = f"⚠️ 无内置 id2label,索引预测: {pred}"
45
- results[model_name] = label
46
- except Exception as e:
47
- results[model_name] = f"❌ 预测失败: {e}"
48
-
49
- results_text = "\n".join([f"{name}: {label}" for name, label in results.items()])
50
- return results_text
51
 
52
- # ---------------- Gradio 界面 ----------------
53
  iface = gr.Interface(
54
  fn=classify_image,
55
  inputs=gr.Image(type="pil", label="上传图片"),
56
- outputs=[gr.Textbox(label="所有模型预测结果")],
57
- title="垃圾分类全模型检测",
58
- description="上传图片后,每个模型独立输出预测结果,不做任何人工干预。",
59
- allow_flagging="never"
60
  )
61
 
62
- # ✅ 启用 API 模式(Space 可被外部调用)
63
  if __name__ == "__main__":
64
- iface.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForImageClassification, AutoImageProcessor
3
  from PIL import Image
4
  import torch
5
 
6
+ # 三个模型
7
+ MODELS = [
 
8
  "yangy50/garbage-classification",
9
+ "ahmzakif/TrashNet-Classification",
10
+ "harriskr14/trashnet-vit"
11
  ]
12
 
 
 
13
  processors = []
14
+ models = []
15
+ for name in MODELS:
16
+ p = AutoImageProcessor.from_pretrained(name)
17
+ m = AutoModelForImageClassification.from_pretrained(name)
18
+ m.eval()
19
+ processors.append(p)
20
+ models.append(m)
 
 
 
 
 
 
 
21
 
 
22
  def classify_image(image: Image.Image):
23
  results = {}
24
+ for name, p, m in zip(MODELS, processors, models):
25
+ inputs = p(images=image, return_tensors="pt")
26
+ with torch.no_grad():
27
+ outputs = m(**inputs)
28
+ pred = outputs.logits.argmax(-1).item()
29
+ label = m.config.id2label.get(pred, f"id_{pred}")
30
+ results[name] = label
31
+ return "\n".join([f"{k}: {v}" for k, v in results.items()])
 
 
 
 
 
 
 
 
32
 
 
33
  iface = gr.Interface(
34
  fn=classify_image,
35
  inputs=gr.Image(type="pil", label="上传图片"),
36
+ outputs="text",
37
+ title="三模型垃圾分类",
38
+ description="使用三个模型独立预测垃圾种类"
 
39
  )
40
 
 
41
  if __name__ == "__main__":
42
+ iface.launch()