Stust / app.py
Alexend's picture
Update app.py
e664788 verified
# ✅ app.py - 升級 TinyLlama-1.1B-Chat 版本
import json
import os
import gradio as gr
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
# ✅ 檔案與模型設定
QA_FILE = "qa.json"
TEXT_FILE = "web_data.txt"
DOCS_FILE = "docs.json"
VECTOR_FILE = "faiss_index.faiss"
EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# ✅ 自動建構向量資料庫(若不存在)
if not (os.path.exists(VECTOR_FILE) and os.path.exists(DOCS_FILE)):
print("⚙️ 未偵測到向量資料庫,開始自動建構...")
with open(TEXT_FILE, "r", encoding="utf-8") as f:
content = f.read()
docs = [chunk.strip() for chunk in content.split("\n\n") if chunk.strip()]
embedder = SentenceTransformer(EMBED_MODEL)
embeddings = embedder.encode(docs, show_progress_bar=True)
index = faiss.IndexFlatL2(embeddings[0].shape[0])
index.add(embeddings)
faiss.write_index(index, VECTOR_FILE)
with open(DOCS_FILE, "w", encoding="utf-8") as f:
json.dump(docs, f, ensure_ascii=False, indent=2)
print("✅ 嵌入建構完成,共儲存段落:", len(docs))
# ✅ 載入資料與模型
with open(QA_FILE, "r", encoding="utf-8") as f:
qa_data = json.load(f)
with open(DOCS_FILE, "r", encoding="utf-8") as f:
docs = json.load(f)
index = faiss.read_index(VECTOR_FILE)
embedder = SentenceTransformer(EMBED_MODEL)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(GEN_MODEL, trust_remote_code=True).to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# ✅ QA 快速匹配
def retrieve_qa_context(user_input):
for item in qa_data:
if item["match"] == "OR":
if any(k in user_input for k in item["keywords"]):
return item["response"]
elif item["match"] == "AND":
if all(k in user_input for k in item["keywords"]):
return item["response"]
return None
# ✅ 向量檢索 top-k 段落
def search_context_faiss(user_input, top_k=3):
vec = embedder.encode([user_input])
D, I = index.search(vec, top_k)
return "\n".join([docs[i] for i in I[0] if i < len(docs)])
# ✅ 使用 Few-shot Prompt 生成答案
def generate_answer(user_input, context):
prompt = f"""
你是一位了解南臺科技大學的智慧語音助理。請根據以下資料回答問題,僅用一至兩句話,以繁體中文表達,回答需清楚具體,不重複問題,不加入身份說明。
[範例格式]
問題:學校地址在哪裡?
回答:南臺科技大學位於台南市永康區南台街一號。
問題:學校電話是多少?
回答:總機電話是 06-2533131,電機工程系分機為 3301。
問題:電機工程系辦公室在哪?
回答:電機工程系辦公室位於 B 棟 B101。
問題:電機工程系有哪些組別?
回答:電機系設有控制組、生醫電子系統組與電能資訊組三個方向。
問題:學生社團活動如何?
回答:南臺有超過 80 個學生社團,涵蓋學術、康樂、服務、體育與藝術領域。
問題:圖書館提供哪些服務?
回答:圖書館提供借書、自修空間、期刊查詢與電子資源服務。
問題:師資如何?
回答:本校師資陣容堅強,擁有 30 多位教授、副教授與助理教授。
問題:悠活館是做什麼的?
回答:悠活館是學生休閒與運動中心,設有羽球場、健身房、桌球室等設施。
問題:怎麼到南臺科技大學?
回答:可從台南火車站搭乘公車,或經永康交流道開車約 10 分鐘抵達。
[資料]
{context}
[問題]
{user_input}
"""
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=150)
response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
for line in response.splitlines()[::-1]:
if len(line.strip()) > 10 and not line.startswith("你是"):
return line.strip()
return response[-90:]
# ✅ 問答主流程
def answer(user_input):
direct = retrieve_qa_context(user_input)
if direct:
return direct
else:
context = search_context_faiss(user_input)
return generate_answer(user_input, context)
# ✅ Gradio 介面
interface = gr.Interface(
fn=answer,
inputs=gr.Textbox(lines=2, placeholder="請輸入與南臺科技大學相關的問題..."),
outputs="text",
title="南臺科技大學 問答機器人(TinyLlama 1.1B)",
description="支援 QA 關鍵字與語意檢索,自動建立嵌入庫,輸出繁體中文自然回答。",
theme="default"
)
interface.launch()