TangYiJay committed on
Commit
1530376
·
verified ·
1 Parent(s): 30d9fb0
Files changed (1) hide show
  1. app.py +76 -75
app.py CHANGED
@@ -1,94 +1,95 @@
1
  import gradio as gr
2
- from PIL import Image
3
  import numpy as np
4
- import torch
5
- from transformers import (
6
- AutoModelForImageSegmentation,
7
- AutoProcessor,
8
- AutoFeatureExtractor,
9
- AutoModelForImageClassification,
10
- )
11
 
12
- # === Load SAM model for segmentation ===
13
- sam_model_id = "facebook/sam-vit-base"
14
- processor_sam = AutoProcessor.from_pretrained(sam_model_id)
15
- model_sam = AutoModelForImageSegmentation.from_pretrained(sam_model_id)
 
 
16
 
17
- # === ② Load garbage classification model ===
18
- cls_model_id = "yangy50/garbage-classification"
19
- extractor = AutoFeatureExtractor.from_pretrained(cls_model_id)
20
- cls_model = AutoModelForImageClassification.from_pretrained(cls_model_id)
21
 
22
- base_img = None # Global memory for base image
 
23
 
24
- # === Step 1: Set base ===
25
  def set_base(image):
26
- global base_img
27
- if image is None:
28
- return "Please upload an empty bin image."
29
- base_img = image.convert("RGB")
30
- return "✅ Base image saved successfully."
31
 
32
- # === Step 2: Detect and classify trash ===
33
- def detect_trash(image):
34
- global base_img
35
- if base_img is None:
36
- return "Please set a base image first."
37
-
38
- current_img = image.convert("RGB")
39
 
40
  # Convert to numpy
41
- base_np = np.array(base_img).astype(np.float32)
42
- current_np = np.array(current_img).astype(np.float32)
43
-
44
- # Difference mask
45
- diff = np.abs(current_np - base_np).mean(axis=2)
46
- mask = (diff > 40).astype(np.uint8) * 255 # threshold
47
- mask_img = Image.fromarray(mask).convert("RGB")
48
-
49
- # Use SAM to refine the mask
50
- inputs = processor_sam(images=current_img, segmentation_maps=mask_img, return_tensors="pt")
51
- with torch.no_grad():
52
- outputs = model_sam(**inputs)
53
- seg = outputs.pred_masks[0].cpu().numpy()
54
-
55
- # Crop bounding box of detected trash
56
- ys, xs = np.where(seg > 0.5)
57
- if len(xs) == 0 or len(ys) == 0:
58
- return "No significant object detected."
59
-
60
- x1, x2, y1, y2 = xs.min(), xs.max(), ys.min(), ys.max()
61
- cropped = current_img.crop((x1, y1, x2, y2))
62
-
63
- # Classify the cropped object
64
- cls_inputs = extractor(images=cropped, return_tensors="pt")
65
- with torch.no_grad():
66
- cls_out = cls_model(**cls_inputs)
67
- probs = torch.nn.functional.softmax(cls_out.logits, dim=-1)
68
- pred_idx = torch.argmax(probs, dim=-1).item()
69
- pred_class = cls_model.config.id2label[pred_idx]
70
-
71
- return f"🧩 Detected Material: {pred_class}"
72
-
73
- # === Build UI ===
 
 
 
 
 
 
 
 
 
 
74
  set_base_ui = gr.Interface(
75
  fn=set_base,
76
- inputs=gr.Image(type="pil", label="Upload Empty Bin (Base)"),
77
- outputs=gr.Textbox(label="Status"),
78
- title="🧩 Set Base",
 
79
  )
80
 
81
  detect_trash_ui = gr.Interface(
82
  fn=detect_trash,
83
  inputs=gr.Image(type="pil", label="Upload Trash Image"),
84
- outputs=gr.Textbox(label="Detection Result"),
85
- title="♻️ Detect & Classify Trash",
86
- )
87
-
88
- demo = gr.TabbedInterface(
89
- [set_base_ui, detect_trash_ui],
90
- ["Set Base", "Detect Trash"]
91
  )
92
 
93
- if __name__ == "__main__":
94
- demo.launch()
 
1
import gradio as gr
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
from transformers import BlipProcessor, BlipForQuestionAnswering

# ===== 1️⃣ Load models =====
# Both models are loaded once at import time and shared by every request.
# SAM: promptable segmentation, used to refine the coarse difference mask.
sam_checkpoint = "sam_vit_b_01ec64.pth"  # checkpoint file uploaded to the Space
sam_model_type = "vit_b"
sam_model = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
sam_predictor = SamPredictor(sam_model)

# BLIP: visual question answering, used to name the material of the crop.
blip_model_name = "Salesforce/blip-vqa-base"
blip_processor = BlipProcessor.from_pretrained(blip_model_name)
blip_model = BlipForQuestionAnswering.from_pretrained(blip_model_name)

# ===== 2️⃣ Global base image =====
# Reference (empty-bin) photo set via set_base() and read by detect_trash().
# NOTE(review): module-level mutable state — shared across all users of the
# Space session; confirm single-user usage is acceptable.
base_image = None
21
 
22
+ # ===== 3️⃣ Set base =====
23
def set_base(image):
    """Store *image* as the reference (empty-bin) photo for detect_trash().

    Args:
        image: PIL.Image uploaded through Gradio, or None when the user
            submitted without uploading anything.

    Returns:
        A status string shown in the UI.
    """
    global base_image
    # Reject empty uploads: the previous behavior stored None and still
    # reported success, leaving detect_trash() with no usable base.
    if image is None:
        return "Please upload a base image."
    base_image = image
    return "Base image saved successfully."
 
 
27
 
28
+ # ===== 4️⃣ Detect trash =====
29
def detect_trash(trash_image):
    """Find the object added on top of the stored base image and name its material.

    Pipeline: per-pixel difference vs. the base image -> bounding box ->
    SAM mask refinement -> crop -> BLIP visual question answering,
    constrained to a fixed set of material classes.

    Args:
        trash_image: PIL.Image of the bin containing the new item.

    Returns:
        A result string for the UI: a capitalized material class, or an
        error/status message.
    """
    global base_image
    if base_image is None:
        return "Please upload a base image first."

    # Normalize both images to 3-channel RGB; uploads may be RGBA or
    # grayscale, and a channel mismatch breaks the subtraction below.
    trash_rgb = trash_image.convert("RGB")
    base_np = np.array(base_image.convert("RGB").resize(trash_rgb.size))
    trash_np = np.array(trash_rgb)

    # Coarse change mask: absolute difference summed over channels.
    # int16 avoids uint8 wrap-around on subtraction.
    diff = np.abs(trash_np.astype(np.int16) - base_np.astype(np.int16))
    mask = (diff.sum(axis=2) > 50).astype(np.uint8)  # binary mask

    # Bounding box of the changed region.
    coords = np.argwhere(mask)
    if coords.size == 0:
        return "No difference detected."
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # SamPredictor.predict expects a single XYXY box as a length-4 array.
    # (The previous `boxes=` keyword does not exist and raised TypeError.)
    box = np.array([x0, y0, x1, y1])

    # Refine with SAM and keep the highest-scoring candidate mask
    # (previously masks[0] was taken despite the "largest mask" intent).
    sam_predictor.set_image(trash_np)
    masks, scores, logits = sam_predictor.predict(box=box)
    mask_refined = masks[int(np.argmax(scores))]

    # Crop the refined region; +1 makes the max row/col inclusive
    # (fixes an off-by-one that dropped the last row and column).
    ys, xs = np.where(mask_refined)
    if ys.size == 0:
        return "SAM did not find any object."
    cropped = trash_np[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    # Convert to PIL for BLIP.
    cropped_img = Image.fromarray(cropped)

    # Ask BLIP a constrained VQA question about the material.
    question = "What material is this? Choose from plastic, metal, paper, cardboard, glass, trash."
    inputs = blip_processor(cropped_img, question, return_tensors="pt")
    out = blip_model.generate(**inputs)
    answer = blip_processor.decode(out[0], skip_special_tokens=True)

    # Map the free-form answer onto the allowed classes; default to "trash".
    valid_classes = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]
    result = next((c for c in valid_classes if c in answer.lower()), "trash")

    return result.capitalize()
76
+
77
+ # ===== 5️⃣ Gradio UI =====
78
  set_base_ui = gr.Interface(
79
  fn=set_base,
80
+ inputs=gr.Image(type="pil", label="Upload Base Image"),
81
+ outputs=gr.Textbox(label="Result"),
82
+ title="Set Base Image",
83
+ api_name="/set_base"
84
  )
85
 
86
  detect_trash_ui = gr.Interface(
87
  fn=detect_trash,
88
  inputs=gr.Image(type="pil", label="Upload Trash Image"),
89
+ outputs=gr.Textbox(label="Detected Material"),
90
+ title="Detect Trash Material",
91
+ api_name="/detect_trash"
 
 
 
 
92
  )
93
 
94
+ demo = gr.TabbedInterface([set_base_ui, detect_trash_ui], ["Set Base", "Detect Trash"])
95
+ demo.launch()