TangYiJay committed on
Commit 43da16d · verified · 1 Parent(s): a5f452d
Files changed (1)
  1. app.py +87 -61
app.py CHANGED
@@ -1,67 +1,93 @@
  import gradio as gr
- from PIL import Image, ImageChops
- from transformers import BlipProcessor, BlipForQuestionAnswering
  import torch

- # Load BLIP model
- model_name = "Salesforce/blip-vqa-base"
- processor = BlipProcessor.from_pretrained(model_name)
- model = BlipForQuestionAnswering.from_pretrained(model_name)
-
- # Ensure device
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)
-
- def crop_difference(base_img: Image.Image, trash_img: Image.Image) -> Image.Image:
-     # Convert to same mode
-     base_img = base_img.convert("RGB")
-     trash_img = trash_img.convert("RGB")
-
-     # Compute difference
-     diff = ImageChops.difference(trash_img, base_img)
-     # Crop to non-zero bbox
-     bbox = diff.getbbox()
-     if bbox:
-         cropped = trash_img.crop(bbox)
-         return cropped
-     else:
-         return trash_img  # fallback if no difference
-
- def identify_material(base_img: Image.Image, trash_img: Image.Image) -> str:
-     if base_img is None or trash_img is None:
-         return "Please upload both base and trash images."
-
-     cropped = crop_difference(base_img, trash_img)
-
-     question = "What material is this? Choose from: plastic, metal, paper, cardboard, glass, trash."
-     inputs = processor(cropped, question, return_tensors="pt").to(device)
-     out = model.generate(**inputs)
-     answer = processor.decode(out[0], skip_special_tokens=True)
-
-     valid_classes = ["plastic", "metal", "paper", "cardboard", "glass", "trash"]
-     result = next((c for c in valid_classes if c in answer.lower()), "trash")
-     return result.capitalize()
-
- title = "Smart Waste Material Detector"
- description = """
- Upload two images:
- 1. Base image (empty background)
- 2. Trash image (object placed on background)
-
- The AI will detect the difference and classify the material:
- plastic, metal, paper, cardboard, glass, or trash.
- """
-
- demo = gr.Interface(
-     fn=identify_material,
-     inputs=[
-         gr.Image(type="pil", label="Upload Base Image (Empty)"),
-         gr.Image(type="pil", label="Upload Trash Image")
-     ],
-     outputs=gr.Textbox(label="Detected Material"),
-     title=title,
-     description=description,
-     allow_flagging="never"
  )

  if __name__ == "__main__":
      demo.launch()
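
For reference, a minimal sketch (not part of the commit) of what the removed ImageChops.difference + getbbox() step does: the difference of two images is zero wherever they agree, so getbbox() returns the bounding box of the region where the object appeared.

from PIL import Image, ImageChops

base = Image.new("RGB", (100, 100), "white")
scene = base.copy()
scene.paste(Image.new("RGB", (20, 20), "red"), (40, 40))  # simulated trash object

diff = ImageChops.difference(scene, base)
print(diff.getbbox())  # (40, 40, 60, 60): bounding box of the pasted object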
 
  import gradio as gr
+ from PIL import Image
+ import numpy as np
  import torch
+ from transformers import (
+     SamModel,
+     SamProcessor,
+     AutoFeatureExtractor,
+     AutoModelForImageClassification,
+ )
+
+ # === ① Load SAM model for segmentation ===
+ sam_model_id = "facebook/sam-vit-base"
+ processor_sam = SamProcessor.from_pretrained(sam_model_id)
+ model_sam = SamModel.from_pretrained(sam_model_id)
+
+ # === ② Load garbage classification model ===
+ cls_model_id = "yangy50/garbage-classification"
+ extractor = AutoFeatureExtractor.from_pretrained(cls_model_id)
+ cls_model = AutoModelForImageClassification.from_pretrained(cls_model_id)
+
+ base_img = None  # Global memory for base image
+
+ # === Step 1: Set base ===
+ def set_base(image):
+     global base_img
+     if image is None:
+         return "Please upload an empty bin image."
+     base_img = image.convert("RGB")
+     return "✅ Base image saved successfully."
+
+ # === Step 2: Detect and classify trash ===
+ def detect_trash(image):
+     global base_img
+     if base_img is None:
+         return "Please set a base image first."
+
+     current_img = image.convert("RGB")
+     if current_img.size != base_img.size:
+         current_img = current_img.resize(base_img.size)
+
+     # Convert to numpy
+     base_np = np.array(base_img).astype(np.float32)
+     current_np = np.array(current_img).astype(np.float32)
+
+     # Difference mask: mean absolute per-channel change, thresholded
+     diff = np.abs(current_np - base_np).mean(axis=2)
+     ys, xs = np.where(diff > 40)  # threshold on the 0-255 scale
+     if len(xs) == 0 or len(ys) == 0:
+         return "No significant object detected."
+
+     # Bounding box of the changed region, used as a box prompt for SAM
+     box = [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]
+
+     # Use SAM to refine the mask
+     inputs = processor_sam(images=current_img, input_boxes=[[box]], return_tensors="pt")
+     with torch.no_grad():
+         outputs = model_sam(**inputs)
+
+     # Upscale SAM's low-resolution masks back to the original image size
+     masks = processor_sam.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(),
+         inputs["original_sizes"].cpu(),
+         inputs["reshaped_input_sizes"].cpu(),
+     )
+     best = outputs.iou_scores[0, 0].argmax().item()  # highest-scoring mask proposal
+     seg = masks[0][0, best].numpy()
+
+     # Crop bounding box of detected trash
+     ys, xs = np.where(seg)
+     if len(xs) == 0 or len(ys) == 0:
+         return "No significant object detected."
+     cropped = current_img.crop((xs.min(), ys.min(), xs.max(), ys.max()))
+
+     # Classify the cropped object
+     cls_inputs = extractor(images=cropped, return_tensors="pt")
+     with torch.no_grad():
+         cls_out = cls_model(**cls_inputs)
+     probs = torch.nn.functional.softmax(cls_out.logits, dim=-1)
+     pred_idx = torch.argmax(probs, dim=-1).item()
+     pred_class = cls_model.config.id2label[pred_idx]
+
+     return f"🧩 Detected Material: {pred_class}"
+
+ # === Build UI ===
+ set_base_ui = gr.Interface(
+     fn=set_base,
+     inputs=gr.Image(type="pil", label="Upload Empty Bin (Base)"),
+     outputs=gr.Textbox(label="Status"),
+     title="🧩 Set Base",
+ )
+
+ detect_trash_ui = gr.Interface(
+     fn=detect_trash,
+     inputs=gr.Image(type="pil", label="Upload Trash Image"),
+     outputs=gr.Textbox(label="Detection Result"),
+     title="♻️ Detect & Classify Trash",
+ )

+ demo = gr.TabbedInterface(
+     [set_base_ui, detect_trash_ui],
+     ["Set Base", "Detect Trash"]
  )

  if __name__ == "__main__":
      demo.launch()
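
A quick way to sanity-check both steps without launching the Gradio UI is to call the handlers directly; the file names below are hypothetical stand-ins for local test images.

from PIL import Image

print(set_base(Image.open("empty_bin.jpg")))          # expected: ✅ Base image saved successfully.
print(detect_trash(Image.open("bin_with_item.jpg")))  # e.g. 🧩 Detected Material: plastic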