Alexend commited on
Commit
8b7d822
·
verified ·
1 Parent(s): f0470a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -114
app.py CHANGED
@@ -1,132 +1,111 @@
1
- # ✅ app.py - 升級 TinyLlama-1.1B-Chat 版本
2
-
3
- import json
4
- import os
5
  import gradio as gr
6
- import faiss
7
  import torch
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
9
  from sentence_transformers import SentenceTransformer
10
-
11
- # ✅ 檔案與模型設定
12
- QA_FILE = "qa.json"
13
- TEXT_FILE = "web_data.txt"
14
- DOCS_FILE = "docs.json"
15
- VECTOR_FILE = "faiss_index.faiss"
16
- EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
17
- GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
18
-
19
- # ✅ 自動建構向量資料庫(若不存在)
20
- if not (os.path.exists(VECTOR_FILE) and os.path.exists(DOCS_FILE)):
21
- print("⚙️ 未偵測到向量資料庫,開始自動建構...")
22
- with open(TEXT_FILE, "r", encoding="utf-8") as f:
23
- content = f.read()
24
- docs = [chunk.strip() for chunk in content.split("\n\n") if chunk.strip()]
25
- embedder = SentenceTransformer(EMBED_MODEL)
26
- embeddings = embedder.encode(docs, show_progress_bar=True)
27
- index = faiss.IndexFlatL2(embeddings[0].shape[0])
28
- index.add(embeddings)
29
- faiss.write_index(index, VECTOR_FILE)
30
- with open(DOCS_FILE, "w", encoding="utf-8") as f:
31
- json.dump(docs, f, ensure_ascii=False, indent=2)
32
- print("✅ 嵌入建構完成,共儲存段落:", len(docs))
33
-
34
- # ✅ 載入資料與模型
35
- with open(QA_FILE, "r", encoding="utf-8") as f:
36
- qa_data = json.load(f)
37
- with open(DOCS_FILE, "r", encoding="utf-8") as f:
38
- docs = json.load(f)
39
- index = faiss.read_index(VECTOR_FILE)
40
- embedder = SentenceTransformer(EMBED_MODEL)
41
- tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL, trust_remote_code=True)
42
- model = AutoModelForCausalLM.from_pretrained(GEN_MODEL, trust_remote_code=True).to("cuda" if torch.cuda.is_available() else "cpu")
43
- model.eval()
44
-
45
- # ✅ QA 快速匹配
46
-
47
- def retrieve_qa_context(user_input):
48
  for item in qa_data:
49
  if item["match"] == "OR":
50
- if any(k in user_input for k in item["keywords"]):
51
  return item["response"]
52
  elif item["match"] == "AND":
53
- if all(k in user_input for k in item["keywords"]):
54
  return item["response"]
55
  return None
56
 
57
- # ✅ 向量檢索 top-k 段落
58
-
59
- def search_context_faiss(user_input, top_k=3):
60
- vec = embedder.encode([user_input])
61
- D, I = index.search(vec, top_k)
62
- return "\n".join([docs[i] for i in I[0] if i < len(docs)])
63
-
64
- # ✅ 使用 Few-shot Prompt 生成答案
65
-
66
- def generate_answer(user_input, context):
67
- prompt = f"""
68
- 你是一位了解南臺科技大學的智慧語音助理。請根據以下資料回答問題,僅用一至兩句話,以繁體中文表達,回答需清楚具體,不重複問題,不加入身份說明。
69
-
70
- [範例格式]
71
- 問題:學校地址在哪裡?
72
- 回答:南臺科技大學位於台南市永康區南台街一號。
73
-
74
- 問題:學校電話是多少?
75
- 回答:總機電話是 06-2533131,電機工程系分機為 3301。
76
-
77
- 問題:電機工程系辦公室在哪?
78
- 回答:電機工程系辦公室位於 B 棟 B101。
79
-
80
- 問題:電機工程系有哪些組別?
81
- 回答:電機系設有控制組、生醫電子系統組與電能資訊組三個方向。
82
-
83
- 問題:學生社團活動如何?
84
- 回答:南臺有超過 80 個學生社團,涵蓋學術、康樂、服務、體育與藝術領域。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- 問題:圖書館提供哪些服務?
87
- 回答:圖書館提供借書、自修空間、期刊查詢與電子資源服務。
 
88
 
89
- 問題:師資如何?
90
- 回答:本校師資陣容堅強,擁有 30 多位教授、副教授與助理教授。
 
91
 
92
- 問題:悠活館是做什麼的?
93
- 回答:悠活館是學生休閒與運動中心,設有羽球場、健身房、桌球室等設施。
 
94
 
95
- 問題:怎麼到南臺科技大學?
96
- 回答:可從台南火車站搭乘公車,或經永康交流道開車約 10 分鐘抵達。
97
 
98
- [資料]
99
- {context}
100
 
101
- [問題]
102
- {user_input}
103
- """
104
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
105
- outputs = model.generate(**inputs, max_new_tokens=150)
106
- response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
107
- for line in response.splitlines()[::-1]:
108
- if len(line.strip()) > 10 and not line.startswith("你是"):
109
- return line.strip()
110
- return response[-90:]
111
 
112
- # 問答主流程
 
113
 
114
- def answer(user_input):
115
- direct = retrieve_qa_context(user_input)
116
- if direct:
117
- return direct
118
- else:
119
- context = search_context_faiss(user_input)
120
- return generate_answer(user_input, context)
121
-
122
- # ✅ Gradio 介面
123
- interface = gr.Interface(
124
- fn=answer,
125
- inputs=gr.Textbox(lines=2, placeholder="請輸入與南臺科技大學相關的問題..."),
126
- outputs="text",
127
- title="南臺科技大學 問答機器人(TinyLlama 1.1B)",
128
- description="支援 QA 關鍵字與語意檢索,自動建立嵌入庫,輸出繁體中文自然回答。",
129
- theme="default"
130
- )
131
-
132
- interface.launch()
 
 
 
 
 
1
  import gradio as gr
 
2
  import torch
3
+ import json
4
+ import tempfile
5
+ import faiss
6
+ from gtts import gTTS
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
8
  from sentence_transformers import SentenceTransformer
9
+ import numpy as np
10
+
11
+ # 模型
12
+ MODEL_NAME = "openbmb/MiniCPM-2B-sft-bf16"
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
14
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).eval()
15
+
16
+ # 語音辨識 Whisper
17
+ asr = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=0 if torch.cuda.is_available() else -1)
18
+
19
+ # 向量模型
20
+ encoder = SentenceTransformer("shibing624/text2vec-base-chinese")
21
+ index = faiss.read_index("vector_store.faiss")
22
+ with open("documents.json", "r", encoding="utf-8") as f:
23
+ documents = json.load(f)
24
+
25
+ # QA固定問答(可選)
26
+ try:
27
+ with open("qa.json", "r", encoding="utf-8") as f:
28
+ qa_data = json.load(f)
29
+ except:
30
+ qa_data = []
31
+
32
+ # QA match(選擇性)
33
+ def match_qa(user_input):
34
+ cleaned_input = user_input.replace(" ", "")
 
 
 
 
 
 
 
 
 
 
 
 
35
  for item in qa_data:
36
  if item["match"] == "OR":
37
+ if any(k.replace(" ", "") in cleaned_input for k in item["keywords"]):
38
  return item["response"]
39
  elif item["match"] == "AND":
40
+ if all(k.replace(" ", "") in cleaned_input for k in item["keywords"]):
41
  return item["response"]
42
  return None
43
 
44
+ # 文字生成
45
+ def generate_answer(text):
46
+ messages = [{"role": "user", "content": text}]
47
+ input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
48
+ with torch.no_grad():
49
+ outputs = model.generate(input_ids, max_new_tokens=200)
50
+ response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
51
+ return response.strip()
52
+
53
+ # 向量比對
54
+ def search_vector_db(query, top_k=1):
55
+ q_vec = encoder.encode([query])
56
+ D, I = index.search(np.array(q_vec), top_k)
57
+ results = [documents[i] for i in I[0] if i < len(documents)]
58
+ return results
59
+
60
+ # 回答邏輯整合
61
+ def answer(text):
62
+ # 1. QA 固定資料庫
63
+ fixed = match_qa(text)
64
+ if fixed:
65
+ return fixed
66
+
67
+ # 2. RAG 取資料輔助
68
+ related_docs = search_vector_db(text)
69
+ context = "\n".join(related_docs)
70
+ prompt = f"以下是一些關於南臺科技大學的資料:\n{context}\n\n根據上面的資料,請用中文簡短回答這個問題:{text}"
71
+ return generate_answer(prompt)
72
+
73
+ # TTS
74
+ def text_to_speech(text):
75
+ tts = gTTS(text, lang='zh')
76
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
77
+ tts.save(tmp.name)
78
+ return tmp.name
79
+
80
+ # 主流程
81
+ def voice_assistant(audio_input=None, text_input=None):
82
+ if audio_input:
83
+ result = asr(audio_input)
84
+ user_text = result["text"]
85
+ elif text_input:
86
+ user_text = text_input
87
+ else:
88
+ return "請輸入語音或文字", None
89
 
90
+ response = answer(user_text)
91
+ speech_file = text_to_speech(response)
92
+ return response, speech_file
93
 
94
+ # Gradio UI
95
+ with gr.Blocks() as demo:
96
+ gr.Markdown("## 🎓 南臺科技大學 AI 語音助理(MiniCPM + Whisper + 向量式 RAG)")
97
 
98
+ with gr.Row():
99
+ mic = gr.Audio(source="microphone", type="filepath", label="語音輸入")
100
+ text_input = gr.Textbox(label="文字輸入", placeholder="請輸入您的問題")
101
 
102
+ submit_btn = gr.Button("送出")
 
103
 
104
+ output_text = gr.Textbox(label="回答")
105
+ output_audio = gr.Audio(label="語音播放", type="filepath")
106
 
107
+ submit_btn.click(fn=voice_assistant, inputs=[mic, text_input], outputs=[output_text, output_audio])
 
 
 
 
 
 
 
 
 
108
 
109
+ if __name__ == "__main__":
110
+ demo.launch()
111