# trashAI / app.py
# Hugging Face Space source file (commit 557db1b, "Update app.py", 7.26 kB).
import gradio as gr
import json
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import (
CLIPProcessor, CLIPModel,
AutoTokenizer, AutoModelForCausalLM
)
# =======================================
# 1. Load recycle data
# =======================================
def _label_text(entry):
    """Return the CLIP text label for one recycle entry.

    The label is "<english_name>, <name>" when an English name is present,
    otherwise just the value of "name" (empty string when that is missing).
    """
    zh_name = entry.get("name", "")
    en_name = entry.get("english_name") or ""
    return f"{en_name}, {zh_name}" if en_name else zh_name


# Read the catalogue through a context manager so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
with open("recycle_data.json", "r", encoding="utf-8") as recycle_file:
    recycle_data = json.load(recycle_file)

# `items` and `label_texts` are index-aligned: label_texts[i] describes items[i].
items = list(recycle_data)
label_texts = [_label_text(entry) for entry in items]
# =======================================
# 2. Load Q&A (RAG)
# =======================================
# Read the Q&A pairs through a context manager so the file handle is closed
# deterministically once the JSON has been parsed.
with open("qas.json", "r", encoding="utf-8") as qa_file:
    qas = json.load(qa_file)

# Only the questions are embedded; the matching answer is looked up by index.
qa_questions = [qa["question"] for qa in qas]
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
qa_embeddings = embedder.encode(qa_questions, convert_to_tensor=True)
# =======================================
# 3. CLIP 用於圖片分類
# =======================================
# Both the model and its processor come from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Embed every label once at start-up; inference only needs the image side.
with torch.no_grad():
    t_inputs = clip_processor(text=label_texts, return_tensors="pt", padding=True)
    _raw_text_embeds = clip_model.get_text_features(
        input_ids=t_inputs["input_ids"],
        attention_mask=t_inputs["attention_mask"],
    )
    # Scale each row to unit L2 norm so a dot product is a cosine similarity.
    _text_norms = _raw_text_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = _raw_text_embeds / _text_norms
# =======================================
# 4. LLM(Qwen 0.5B)+回答模板
# =======================================
# Checkpoint used for the free-form fallback answers.
LLM = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(LLM)
# Load in float32 and place the model on the CPU.
llm = AutoModelForCausalLM.from_pretrained(LLM, torch_dtype=torch.float32)
llm = llm.to("cpu")
def llm_reply(prompt):
    """Generate a completion for *prompt* with the module-level LLM.

    Args:
        prompt: The full text prompt to feed to the model.

    Returns:
        Only the newly generated text (at most 200 new tokens), with special
        tokens removed and surrounding whitespace stripped.
    """
    inputs = tok(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = llm.generate(**inputs, max_new_tokens=200)
    # For a causal LM the generated sequence starts with the prompt tokens;
    # drop them so the instructions are not echoed back to the user.
    new_tokens = outputs[0][prompt_len:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()
# =======================================
# 5. 回答品質:加入「專業垃圾分類助理模板」
# =======================================
def expert_llm_reply(text):
prompt = f"""
你是一位「台灣垃圾分類專家助理」。
請用 **自然、生活化、清楚、條列式、友善語氣** 回答問題。
遵守規則:
- 使用台灣常見分類(紙類、塑膠類、鐵鋁罐、玻璃、其他可回收、一般垃圾、廚餘…)
- 如可能需要清洗 → 提醒「保持乾淨、不要油膩」
- 如可能需要壓扁、拆蓋 → 主動提醒
- 如不同縣市規則不同 → 說「各縣市略有差異」
- 最後提供 1 個附加小提醒
使用者問題:{text}
請直接回答:
"""
return llm_reply(prompt)
# =======================================
# 6. 額外知識庫(讓回答更像真人)
# =======================================
# Hand-written handling tips, keyed by the exact item name.
extra_rules = {
    "寶特瓶": [
        "瓶身要簡單沖洗乾淨",
        "可壓扁節省空間",
        "瓶蓋需旋開分開丟(塑膠類)",
        "標籤可保留或拆除都可以",
    ],
    "鋁箔包": [
        "要沖洗乾淨避免發臭",
        "記得壓扁更好回收",
        "屬於飲料紙容器類,可回收",
    ],
    "外帶杯": [
        "杯身要沖乾淨",
        "若是紙杯 → 紙類回收",
        "若是塑膠杯 → 塑膠類回收",
        "吸管為一般垃圾",
    ],
    "餐盒": [
        "若為乾淨塑膠 → 可回收",
        "若油膩、難清洗 → 一般垃圾",
        "盒蓋通常可回收(塑膠)",
    ],
}


def add_extra_tips(item_name):
    """Return a Markdown "tips" section for *item_name*.

    Args:
        item_name: Item name to look up in ``extra_rules``.

    Returns:
        A string starting with a newline and a bold heading, followed by one
        "- " bullet per tip; an empty string when the name has no entry.
    """
    rules = extra_rules.get(item_name)
    if rules is None:
        return ""
    bullet_lines = [f"- {rule}" for rule in rules]
    bullets = "\n".join(bullet_lines)
    return f"\n🔧 **小提醒:**\n{bullets}"
# =======================================
# 7. 圖片分類 + 回答模板
# =======================================
def classify_image(pil):
    """Pick the recycle label that best matches an image using CLIP.

    Args:
        pil: The image to classify, passed to the CLIP processor as ``images``.

    Returns:
        A tuple ``(idx, score)``: ``idx`` is the index of the best label in
        ``text_embeds`` and ``score`` is its softmax probability as a float.
    """
    inputs = clip_processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        img_emb = clip_model.get_image_features(**inputs)
        img_emb = img_emb / img_emb.norm(p=2, dim=-1, keepdim=True)
        # Both sides are unit-normalised, so the product is a cosine
        # similarity in [-1, 1]. Apply CLIP's learned temperature before the
        # softmax; without it the distribution is almost flat and the
        # reported score is close to 1 / number_of_labels for every image.
        logits = clip_model.logit_scale.exp() * (img_emb @ text_embeds.T)
        probs = logits.softmax(dim=-1)[0]
    idx = torch.argmax(probs).item()
    score = float(probs[idx])
    return idx, score
def smart_answer(item, score):
    """Format the Markdown reply shown for a recognised item.

    Args:
        item: Recycle entry; "name" is required, "recyclable" and "notes"
            fall back to an empty string when absent.
        score: Confidence value, rendered with two decimals.

    Returns:
        A multi-line Markdown string with the item name, score, recyclability,
        notes, any extra tips for that name and a closing prompt.
    """
    name = item["name"]
    recyclable_text = item.get("recyclable", "")
    note_text = item.get("notes", "")
    tips_section = add_extra_tips(name)
    return f"""
🟢 **辨識結果**
我推測這張照片中的物品是 **{name}**
(相似度:**{score:.2f}**)
♻ **是否可回收**
{recyclable_text}
📌 **補充說明**
{note_text}
{tips_section}
有需要我可以繼續告訴你:
- 要不要清洗?
- 要不要壓扁?
- 某些配件要不要拆?
都可以問我喔!
"""
# =======================================
# 8. 搜尋 recycle_data 名稱
# =======================================
def search_recycle_name(text):
    """Find the first recycle entry whose name occurs inside *text*.

    Args:
        text: The user's message.

    Returns:
        The first entry in ``items`` whose non-empty "name" is a substring of
        *text*, or ``None`` when nothing matches or *text* is empty.
    """
    if not text:
        return None
    for item in items:
        name = item.get("name")
        # Skip entries without a usable name: an empty string is a substring
        # of every message and would otherwise always be returned.
        if name and name in text:
            return item
    return None
# =======================================
# 9. RAG 搜尋官方 Q&A
# =======================================
def rag_search(text, threshold=0.70):
    """Return the stored answer whose question is most similar to *text*.

    Args:
        text: The user's question.
        threshold: Minimum cosine similarity (exclusive) the best match must
            exceed to be accepted. Defaults to 0.70.

    Returns:
        The "answer" of the best-matching Q&A pair, or ``None`` when the best
        similarity does not exceed *threshold*, *text* is empty, or there are
        no Q&A pairs.
    """
    # Nothing to compare against; argmax over an empty score row would raise.
    if not text or not qas:
        return None
    q_emb = embedder.encode(text, convert_to_tensor=True)
    scores = util.cos_sim(q_emb, qa_embeddings)[0]
    best_idx = torch.argmax(scores).item()
    if float(scores[best_idx]) > threshold:
        return qas[best_idx]["answer"]
    return None
# =======================================
# 10. Chatbot 主邏輯
# =======================================
# Most recently uploaded image and the index (into `items`) it was classified
# as. NOTE(review): module-level state is shared by every chat session.
global_image = None
_global_item_idx = None


def _extract_image(message):
    """Return the image carried by a multimodal *message* dict, or None.

    Supports the legacy ``"image"`` key (an array) as well as the ``"files"``
    list that a multimodal Gradio chat sends; only the first file is used.
    Raises ``OSError`` when the file cannot be opened as an image.
    """
    legacy = message.get("image")
    if legacy is not None:
        return Image.fromarray(legacy)
    files = message.get("files") or []
    if not files:
        return None
    first = files[0]
    # Depending on the Gradio version an entry is a path string or a
    # dict/object exposing "path".
    if isinstance(first, dict):
        path = first.get("path")
    else:
        path = getattr(first, "path", first)
    if not path:
        return None
    return Image.open(path).convert("RGB")


def bot(message, history):
    """Chat handler: answer an uploaded image or a text question.

    Args:
        message: Either a plain string, or a dict with "text" and an uploaded
            image ("files" list or legacy "image" array).
        history: Previous turns supplied by the chat UI; not used here.

    Returns:
        A Markdown string. An image is classified and described. Text is
        answered, in order, from the last classified item, the recycle
        catalogue, the Q&A set, and finally the LLM.
    """
    global global_image, _global_item_idx

    fallback = "我好像不太理解你的訊息,可以再說一次嗎?"

    if isinstance(message, dict):
        # "text" may be missing or None when only a file is sent.
        text = (message.get("text") or "").strip()
        try:
            img = _extract_image(message)
        except OSError:
            return "這個檔案我讀不出來,請改傳 JPG 或 PNG 圖片喔!"
        if img is not None:
            # Remember the image and its label so follow-up questions do not
            # need to run CLIP again.
            global_image = img
            idx, score = classify_image(global_image)
            _global_item_idx = idx
            return smart_answer(items[idx], score)
        # No image: treat the payload as a plain text message.
        message = text

    if isinstance(message, str):
        text = message.strip()
        if not text:
            return fallback

        # Follow-up about the previously uploaded image.
        if global_image is not None:
            if _global_item_idx is None:
                _global_item_idx, _ = classify_image(global_image)
            current_item = items[_global_item_idx]
            current_name = current_item.get("name")
            # An empty name would match every message, so require a real one.
            if current_name and current_name in text:
                return smart_answer(current_item, 0.99)

        # Exact name lookup in the recycle catalogue.
        item = search_recycle_name(text)
        if item:
            return smart_answer(item, 0.99)

        # Semantic lookup in the Q&A set.
        ans = rag_search(text)
        if ans:
            return f"📘 **官方資料:**\n{ans}"

        # Last resort: free-form LLM answer.
        return expert_llm_reply(text)

    return fallback
# =======================================
# 11. Gradio Chat UI
# =======================================
# Keyword arguments for the chat UI, gathered in one place for readability.
_CHAT_OPTIONS = {
    "fn": bot,
    "title": "台南垃圾分類智慧助理(圖片 + 多輪聊天)",
    "description": "你可以傳圖片或提問,我會查看 270+ 類回收資料 + 官方 Q&A + 多輪對話記憶。",
    "multimodal": True,
}

# Build the multimodal chat interface around `bot` and start serving it.
ui = gr.ChatInterface(**_CHAT_OPTIONS)
ui.launch()