File size: 5,153 Bytes
9a3a6bd
 
 
66f5de7
 
9a3a6bd
 
da9b217
66f5de7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50862fc
66f5de7
 
 
50862fc
 
66f5de7
 
 
50862fc
 
66f5de7
 
 
 
 
 
 
 
 
9a3a6bd
 
 
 
 
66f5de7
9a3a6bd
 
 
66f5de7
9a3a6bd
 
 
 
 
 
 
66f5de7
 
 
 
9a3a6bd
 
 
 
 
 
 
 
 
 
 
66f5de7
9a3a6bd
 
66f5de7
 
 
 
da9b217
 
c635b45
9a3a6bd
 
66f5de7
9a3a6bd
 
 
 
66f5de7
9a3a6bd
 
 
 
 
66f5de7
 
 
9a3a6bd
66f5de7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a3a6bd
 
 
 
66f5de7
9a3a6bd
 
66f5de7
9a3a6bd
66f5de7
9a3a6bd
66f5de7
 
9a3a6bd
66f5de7
9a3a6bd
 
66f5de7
 
9a3a6bd
66f5de7
9a3a6bd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json
import os
import tempfile

import clip
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
# =========================================================
# 1. 讀取台南垃圾分類資料庫 JSON
# =========================================================
# Load the Tainan recycling database once at import time. Presumably a list of
# dicts; the rest of the file reads the keys "name", "english_name",
# "recyclable" and "notes" from each entry.
with open("tainan_recycle_data.json", encoding="utf-8") as db_file:
    tainan_db = json.load(db_file)

def lookup_item(name, db=None):
    """Return the first database entry whose Chinese or English name contains *name*.

    Matching is a case-insensitive substring test: *name* matches an entry when
    it occurs anywhere inside ``item["name"]`` or ``item["english_name"]``.
    Entries are scanned in database order and, per entry, the Chinese name is
    checked before the English one, so the first hit wins.

    Args:
        name: Query text, e.g. a phrase produced by the object detector.
        db: Iterable of entry dicts to search. Defaults to the module-level
            ``tainan_db`` loaded from ``tainan_recycle_data.json``.

    Returns:
        The matching entry dict, or ``None`` if *name* is empty/blank or no
        entry matches.
    """
    query = name.strip().casefold() if name else ""
    if not query:
        # "" is a substring of every string, so an empty query would otherwise
        # "match" the very first entry regardless of what was detected.
        return None
    if db is None:
        db = tainan_db
    for item in db:
        # Chinese name first, then English name; a missing/None field is
        # treated as empty instead of raising KeyError.
        for field in ("name", "english_name"):
            if query in str(item.get(field) or "").casefold():
                return item
    return None


# =========================================================
# 2. 載入模型:GroundingDINO (從你的 Dataset Repo)
# =========================================================
# ⚠️ 請把這裡的 repo_id 換成你的 HuggingFace Dataset
# HuggingFace *dataset* repo that hosts the GroundingDINO config + checkpoint.
# NOTE: replace with your own repo if you host the weights elsewhere.
REPO_ID = "idkWhatToUse/groundingdino-weights"

config_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="GroundingDINO_SwinT_OGC.py",
    repo_type="dataset",
)
checkpoint_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="groundingdino_swint_ogc.pth",
    repo_type="dataset",
)

from groundingdino.util.inference import load_model, load_image, predict, annotate

# load_model() defaults to device="cuda"; pass the device actually available
# so the app also starts on CPU-only hosts (e.g. free-tier Spaces).
dino_model = load_model(
    config_path,
    checkpoint_path,
    device="cuda" if torch.cuda.is_available() else "cpu",
)


# =========================================================
# 3. 載入 CLIP(分類判斷)+ LLM(分類理由)
# =========================================================
# Use the GPU when one is visible; the CLIP and LLM models below go there.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# CLIP ViT-B/32 zero-shot classifier plus its matching image preprocessing.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Candidate waste categories; classify_clip picks one of these by index.
labels = [
    "一般垃圾",
    "紙類",
    "塑膠類",
    "金屬類",
    "玻璃類",
    "食物垃圾",
    "電池",
    "電子產品",
]

# phi-2 causal LM, used by generate_reason to phrase a short justification.
_LLM_NAME = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(_LLM_NAME)
llm = AutoModelForCausalLM.from_pretrained(_LLM_NAME)
llm = llm.to(device)


def generate_reason(item_name, category):
    """Ask the phi-2 LLM for a one-sentence reason why *item_name* is in *category*.

    Args:
        item_name: Name of the detected object.
        category: Waste category chosen for it.

    Returns:
        Only the newly generated text (the prompt itself is not echoed back),
        stripped of surrounding whitespace.
    """
    prompt = f"物品「{item_name}」被歸類為「{category}」。請用一句話簡單解釋原因:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = llm.generate(
        **inputs,
        max_new_tokens=40,
        # phi-2's tokenizer defines no pad token; reuse EOS so open-ended
        # generation is well-defined (and does not warn on every call).
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decoder-only models return prompt + continuation; keep only the new part.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


# =========================================================
# 4. CLIP 分類(查不到官方資料時使用)
# =========================================================
# English text prompt scored by CLIP for each (Chinese) category label.
# CLIP ViT-B/32 was trained on English captions, so scoring the raw Chinese
# labels gives close-to-random zero-shot results.
_CLIP_PROMPTS = {
    "一般垃圾": "a photo of general household trash",
    "紙類": "a photo of paper or cardboard waste",
    "塑膠類": "a photo of plastic waste",
    "金屬類": "a photo of metal waste such as a metal can",
    "玻璃類": "a photo of glass waste such as a glass bottle",
    "食物垃圾": "a photo of food waste",
    "電池": "a photo of a battery",
    "電子產品": "a photo of an electronic device",
}


def classify_clip(image, obj_name):
    """Zero-shot classify *image* into one of ``labels`` with CLIP, plus an LLM reason.

    Used when the detected object is not found in the Tainan database.

    Args:
        image: PIL image to classify (ideally the crop of the detected object).
        obj_name: Detector phrase for the object; only forwarded to the LLM.

    Returns:
        ``(category, reason)`` where ``category`` is an entry of ``labels``
        and ``reason`` is the text from ``generate_reason``.
    """
    # Score English prompts but report the original label; any label without a
    # prompt falls back to being scored verbatim.
    prompts = [_CLIP_PROMPTS.get(label, label) for label in labels]
    text_inputs = clip.tokenize(prompts).to(device)
    img_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits_per_image, _ = clip_model(img_tensor, text_inputs)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    category = labels[int(probs.argmax())]
    reason = generate_reason(obj_name, category)
    return category, reason


# =========================================================
# 5. 主流程:物件偵測 → 查資料庫 → 顯示官方資訊
# =========================================================
def _box_to_crop(image, box):
    """Crop PIL *image* to one GroundingDINO box (normalized cx, cy, w, h).

    The box is clamped to the image bounds; a degenerate (zero-area) box
    falls back to the whole image so CLIP always receives a valid input.
    """
    width, height = image.size
    cx, cy, bw, bh = (float(v) for v in box)
    left = max(0, int((cx - bw / 2) * width))
    top = max(0, int((cy - bh / 2) * height))
    right = min(width, int(round((cx + bw / 2) * width)))
    bottom = min(height, int(round((cy + bh / 2) * height)))
    if right <= left or bottom <= top:
        return image
    return image.crop((left, top, right, bottom))


def _format_db_entry(match):
    """Render an official Tainan database entry for the results textbox."""
    return f"""
🧩 **物品:{match['name']}**
📘 英文名稱:{match['english_name']}
♻️ 回收指示:{match['recyclable']}
📖 官方說明:{match['notes']}
🌐 資料來源:台南市政府環保局

"""


def _format_clip_entry(obj, category, reason):
    """Render a CLIP/LLM fallback result for the results textbox."""
    return f"""
🧩 偵測到:{obj}
📦 分類推論:{category}
💡 理由:{reason}

"""


def pipeline(image):
    """Detect objects, then report how each one should be disposed of.

    Flow: GroundingDINO detection -> look each phrase up in the Tainan
    database -> for misses, classify the object's crop with CLIP and ask the
    LLM for a short reason.

    Args:
        image: Upload from ``gr.Image``: a numpy array (Gradio's default), a
            PIL image, or ``None`` when nothing has been uploaded.

    Returns:
        ``(annotated, result_text)``: the detection overlay as an RGB array
        (``None`` if there was no input) and the formatted result text.
    """
    if image is None:
        return None, "請先上傳照片。"
    # Gradio hands us a numpy array by default; work with PIL from here on.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    # JPEG cannot store alpha/palette modes, so normalize before saving.
    image = image.convert("RGB")

    # load_image() only accepts a file path. Use a unique temp file per call
    # so overlapping requests cannot overwrite each other's upload.
    fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        image.save(tmp_path)
        img_np, img_tensor = load_image(tmp_path)
    finally:
        os.remove(tmp_path)

    boxes, logits, phrases = predict(
        model=dino_model,
        image=img_tensor,
        # GroundingDINO recommends separating category names with " . ".
        caption="bottle . can . box . plastic . cup . phone . battery . appliance . metal",
        box_threshold=0.3,
        text_threshold=0.25,
        # predict() defaults to "cuda"; use the device actually available.
        device=device,
    )
    # annotate() returns an OpenCV-style BGR frame; Gradio displays RGB.
    annotated = np.ascontiguousarray(annotate(img_np, boxes, logits, phrases)[..., ::-1])

    if not phrases:
        return annotated, "未偵測到可辨識的物品。"

    sections = []
    for obj, box in zip(phrases, boxes):
        obj_clean = obj.lower()

        # (A) Official Tainan database first. An empty phrase would substring-
        # match every entry, so it goes straight to the CLIP fallback.
        match = lookup_item(obj_clean) if obj_clean.strip() else None
        if match:
            sections.append(_format_db_entry(match))
            continue

        # (B) Not in the database: classify this detection's crop with CLIP.
        category, reason = classify_clip(_box_to_crop(image, box), obj_clean)
        sections.append(_format_clip_entry(obj, category, reason))

    return annotated, "".join(sections)


# =========================================================
# 6. Gradio 介面
# =========================================================
# Build the UI; components render top-to-bottom in creation order.
with gr.Blocks() as demo:
    gr.Markdown("# 🗑️ AI 垃圾分類助手(含台南官方資料庫)")
    gr.Markdown("上傳照片 → 物件偵測 → 查台南市環保局資料庫 → 顯示回收建議")

    # Input side: the photo and the trigger button.
    photo_in = gr.Image(label="上傳垃圾照片")
    run_button = gr.Button("開始分析 🚀", variant="primary")

    # Output side: annotated detections and the textual sorting advice.
    detection_out = gr.Image(label="物件偵測(GroundingDINO)")
    summary_out = gr.Textbox(label="分類結果", lines=10)

    # pipeline(image) -> (annotated image, result text)
    run_button.click(
        fn=pipeline,
        inputs=[photo_in],
        outputs=[detection_out, summary_out],
    )

demo.launch()