TangYiJay commited on
Commit
0d60bc4
·
verified ·
1 Parent(s): dd30521
Files changed (1) hide show
  1. app.py +61 -4
app.py CHANGED
@@ -1,5 +1,62 @@
1
- # Use a pipeline as a high-level helper
2
- from transformers import pipeline
 
 
 
3
 
4
- pipe = pipeline("image-classification", model="yangy50/garbage-classification")
5
- pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
4
+ from PIL import Image
5
+ import torch
6
 
7
+ # ---------------- 配置模型 ----------------
8
+ MODEL_MAIN = "prithivMLmods/Trash-Net" # 主模型
9
+ MODEL_SECOND = "yangy50/garbage-classification" # 验证模型
10
+
11
+ # ---------------- 加载模型 ----------------
12
+ print("🔹 正在加载主模型...")
13
+ processor1 = AutoImageProcessor.from_pretrained(MODEL_MAIN)
14
+ model1 = AutoModelForImageClassification.from_pretrained(MODEL_MAIN)
15
+ model1.eval()
16
+ print("✅ 主模型加载完成")
17
+
18
+ print("🔹 正在加载验证模型...")
19
+ processor2 = AutoImageProcessor.from_pretrained(MODEL_SECOND)
20
+ model2 = AutoModelForImageClassification.from_pretrained(MODEL_SECOND)
21
+ model2.eval()
22
+ print("✅ 验证模型加载完成")
23
+
24
+ # ---------------- 推理函数 ----------------
25
+ def classify_image(image: Image.Image):
26
+ # 主模型预测
27
+ inputs1 = processor1(images=image, return_tensors="pt")
28
+ with torch.no_grad():
29
+ outputs1 = model1(**inputs1)
30
+ pred1 = outputs1.logits.argmax(-1).item()
31
+ label1 = model1.config.id2label[pred1]
32
+
33
+ # 验证模型预测
34
+ inputs2 = processor2(images=image, return_tensors="pt")
35
+ with torch.no_grad():
36
+ outputs2 = model2(**inputs2)
37
+ pred2 = outputs2.logits.argmax(-1).item()
38
+ label2 = model2.config.id2label[pred2]
39
+
40
+ # 双重验证
41
+ if label1 == label2:
42
+ result = f"✅ 双重验证一致,最终判定为:{label1.upper()}"
43
+ else:
44
+ result = f"⚠️ 双重验证不一致!\n主模型:{label1}\n验证模型:{label2}"
45
+
46
+ return label1, label2, result
47
+
48
+ # ---------------- Gradio 界面 ----------------
49
+ iface = gr.Interface(
50
+ fn=classify_image,
51
+ inputs=gr.Image(type="pil", label="上传图片"),
52
+ outputs=[
53
+ gr.Textbox(label="主模型结果"),
54
+ gr.Textbox(label="验证模型结果"),
55
+ gr.Textbox(label="双重验证判定")
56
+ ],
57
+ title="垃圾分类双模型检测",
58
+ description="使用 Trash-Net 和 Garbage-ViT 两个模型进行垃圾分类,结果一致才判定最终类别。"
59
+ )
60
+
61
+ if __name__ == "__main__":
62
+ iface.launch()