Spaces:
Sleeping
Sleeping
File size: 5,153 Bytes
9a3a6bd 66f5de7 9a3a6bd da9b217 66f5de7 50862fc 66f5de7 50862fc 66f5de7 50862fc 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 da9b217 c635b45 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd 66f5de7 9a3a6bd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | import torch
import gradio as gr
from PIL import Image
import json
import clip
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
import numpy as np
# =========================================================
# 1. Load the Tainan recycling-classification database (JSON)
# =========================================================
# The whole file is read as UTF-8 text and parsed into `tainan_db`,
# which lookup_item() iterates as a sequence of entry dicts.
with open("tainan_recycle_data.json", mode="r", encoding="utf-8") as f:
    tainan_db = json.loads(f.read())
def lookup_item(name, db=None):
    """Fuzzy-match an item name (Chinese or English) against the database.

    The match is a case-insensitive substring test. An entry matches when
    ``name`` occurs inside its Chinese ``"name"`` or its ``"english_name"``.
    Entries are scanned in order and the first match wins.

    Args:
        name: Item name to look up, e.g. a GroundingDINO phrase such as
            ``"bottle"``. ``None`` or blank strings never match.
        db: Optional sequence of entry dicts to search. When omitted, the
            module-level ``tainan_db`` is used.

    Returns:
        The first matching entry dict, or ``None`` if nothing matches.
    """
    needle = (name or "").strip().lower()
    # An empty needle is a substring of every string, so without this guard
    # a blank phrase would "match" the very first database entry.
    if not needle:
        return None
    entries = tainan_db if db is None else db
    for item in entries:
        # Chinese name first, then English name. .get() lets an entry
        # that lacks a field be skipped instead of raising KeyError.
        if needle in (item.get("name") or "").lower():
            return item
        if needle in (item.get("english_name") or "").lower():
            return item
    return None
# =========================================================
# 2. Load model: GroundingDINO (weights hosted in your Dataset repo)
# =========================================================
# ⚠️ Replace this repo_id with your own Hugging Face Dataset repo.
REPO_ID = "idkWhatToUse/groundingdino-weights"
# Model config file (a Python file) for the Swin-T OGC variant.
config_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="GroundingDINO_SwinT_OGC.py",
    repo_type="dataset",
)
# Matching checkpoint weights.
checkpoint_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="groundingdino_swint_ogc.pth",
    repo_type="dataset",
)
from groundingdino.util.inference import load_model, load_image, predict, annotate

# groundingdino's load_model() defaults to device="cuda". Pass the device
# that is actually available so the model is configured correctly on
# CPU-only hardware as well.
_dino_device = "cuda" if torch.cuda.is_available() else "cpu"
dino_model = load_model(config_path, checkpoint_path, device=_dino_device)
# =========================================================
# 3. Load CLIP (category classification) + LLM (explanation text)
# =========================================================
# Prefer the GPU when one is visible to torch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Candidate waste categories. classify_clip() maps the argmax index back
# into this list, so order matters.
labels = [
    "一般垃圾",
    "紙類",
    "塑膠類",
    "金屬類",
    "玻璃類",
    "食物垃圾",
    "電池",
    "電子產品",
]

# LLM (phi-2) used to phrase a short reason for a CLIP classification.
_llm_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(_llm_name)
llm = AutoModelForCausalLM.from_pretrained(_llm_name)
llm = llm.to(device)
def generate_reason(item_name, category):
    """Ask the LLM for a one-sentence explanation of a classification.

    Args:
        item_name: Name of the detected object.
        category: Waste category the object was assigned to.

    Returns:
        Only the newly generated text, with surrounding whitespace
        stripped. The prompt itself is not echoed back.
    """
    prompt = f"物品「{item_name}」被歸類為「{category}」。請用一句話簡單解釋原因:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = llm.generate(
        **inputs,
        max_new_tokens=40,
        # phi-2's tokenizer defines no pad token; reuse EOS explicitly
        # instead of relying on generate()'s fallback (and its warning).
        pad_token_id=tokenizer.eos_token_id,
    )
    # For a decoder-only model, generate() returns prompt + continuation.
    # Drop the prompt tokens so the UI shows only the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# =========================================================
# 4. CLIP classification (fallback when the official database has no match)
# =========================================================
# English text prompts scored by CLIP for each Chinese category label.
# CLIP ViT-B/32 was trained on English captions, so raw Chinese label
# strings are a poor text query. Labels missing here fall back to the
# label text itself.
CLIP_LABEL_PROMPTS = {
    "一般垃圾": "a photo of general household trash",
    "紙類": "a photo of paper or cardboard waste",
    "塑膠類": "a photo of plastic waste",
    "金屬類": "a photo of a metal can or metal scrap",
    "玻璃類": "a photo of a glass bottle or glass waste",
    "食物垃圾": "a photo of food waste",
    "電池": "a photo of a battery",
    "電子產品": "a photo of an electronic device",
}


def classify_clip(image, obj_name):
    """Pick the most likely waste category for *image* with CLIP.

    Args:
        image: PIL image to classify. Ideally this is a crop of the
            detected object.
        obj_name: Detected object name; passed to the LLM for the reason.

    Returns:
        ``(category, reason)``. ``category`` is one of the Chinese
        ``labels`` and ``reason`` is the LLM-generated explanation.
    """
    prompts = [CLIP_LABEL_PROMPTS.get(lbl, lbl) for lbl in labels]
    text_inputs = clip.tokenize(prompts).to(device)
    img_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, _ = clip_model(img_tensor, text_inputs)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]
    # Index order of `prompts` matches `labels`, so map back to Chinese.
    category = labels[int(probs.argmax())]
    reason = generate_reason(obj_name, category)
    return category, reason
# =========================================================
# 5. Main flow: object detection -> database lookup -> show official info
# =========================================================
def _to_rgb_pil(image):
    """Return *image* as an RGB PIL image (Gradio may pass a numpy array)."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    # JPEG cannot store alpha/palette modes; normalise before saving.
    return image.convert("RGB")


def _crop_box(image, box):
    """Crop one GroundingDINO box (normalised cx, cy, w, h) from *image*.

    Returns the full image if the box collapses to zero area after clamping.
    """
    width, height = image.size
    cx, cy, bw, bh = (float(v) for v in box)
    left = max(0, int((cx - bw / 2) * width))
    top = max(0, int((cy - bh / 2) * height))
    right = min(width, int(round((cx + bw / 2) * width)))
    bottom = min(height, int(round((cy + bh / 2) * height)))
    if right <= left or bottom <= top:
        return image
    return image.crop((left, top, right, bottom))


def _format_db_match(match):
    """Render an official Tainan database hit as one result section."""
    return f"""
🧩 **物品:{match['name']}**
📘 英文名稱:{match['english_name']}
♻️ 回收指示:{match['recyclable']}
📖 官方說明:{match['notes']}
🌐 資料來源:台南市政府環保局
"""


def _format_clip_guess(obj, category, reason):
    """Render a CLIP + LLM fallback classification as one result section."""
    return f"""
🧩 偵測到:{obj}
📦 分類推論:{category}
💡 理由:{reason}
"""


def pipeline(image):
    """Detect objects, look them up, and build the result text.

    Args:
        image: Uploaded image from Gradio (numpy array or PIL image), or
            ``None`` when the user clicked without uploading.

    Returns:
        ``(annotated, text)``. ``annotated`` is an RGB numpy array with the
        detection boxes drawn (or ``None`` when there is no input), and
        ``text`` concatenates one section per detected phrase.
    """
    import os
    import tempfile

    if image is None:
        return None, "請先上傳一張照片。"
    image = _to_rgb_pil(image)

    # GroundingDINO's load_image() reads from a path. Use a unique temp
    # file per request: a fixed "temp.jpg" is shared by concurrent users.
    fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        image.save(tmp_path)
        img_np, img_tensor = load_image(tmp_path)
    finally:
        os.remove(tmp_path)

    # 1. GroundingDINO object detection
    boxes, logits, phrases = predict(
        model=dino_model,
        image=img_tensor,
        caption="bottle, can, box, plastic, cup, phone, battery, appliance, metal",
        box_threshold=0.3,
        text_threshold=0.25,
        # predict() defaults to device="cuda"; follow the runtime device.
        device=device,
    )
    # annotate() draws on an OpenCV-style BGR copy; flip to RGB for Gradio.
    annotated = np.ascontiguousarray(
        annotate(img_np, boxes, logits, phrases)[..., ::-1]
    )

    if len(phrases) == 0:
        return annotated, "未偵測到可辨識的物品。"

    # 2. Build the output, one section per detected phrase
    sections = []
    for box, obj in zip(boxes, phrases):
        obj_clean = obj.lower()
        # (A) Try the official Tainan database first.
        match = lookup_item(obj_clean)
        if match:
            sections.append(_format_db_match(match))
            continue
        # (B) No match: classify this object's crop with CLIP; the LLM
        # supplies the reason.
        category, reason = classify_clip(_crop_box(image, box), obj_clean)
        sections.append(_format_clip_guess(obj, category, reason))
    return annotated, "".join(sections)
# =========================================================
# 6. Gradio interface
# =========================================================
# One upload -> one button -> annotated image plus text results.
with gr.Blocks() as demo:
    gr.Markdown("# 🗑️ AI 垃圾分類助手(含台南官方資料庫)")
    gr.Markdown("上傳照片 → 物件偵測 → 查台南市環保局資料庫 → 顯示回收建議")
    img_input = gr.Image(label="上傳垃圾照片")
    btn = gr.Button("開始分析 🚀", variant="primary")
    img_output = gr.Image(label="物件偵測(GroundingDINO)")
    text_output = gr.Textbox(label="分類結果", lines=10)
    btn.click(pipeline, inputs=[img_input], outputs=[img_output, text_output])

# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module for reuse or tests does not block on launch().
if __name__ == "__main__":
    demo.launch()
|