Alexend commited on
Commit
e664788
·
verified ·
1 Parent(s): f4f1821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -93
app.py CHANGED
@@ -1,110 +1,132 @@
1
- import gradio as gr
2
- import torch
3
  import json
4
- import tempfile
 
5
  import faiss
6
- from gtts import gTTS
7
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
8
  from sentence_transformers import SentenceTransformer
9
- import numpy as np
10
-
11
- # 模型
12
- MODEL_NAME = "openbmb/MiniCPM-2B-sft-bf16"
13
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
14
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).eval()
15
-
16
- # 語音辨識 Whisper
17
- asr = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=0 if torch.cuda.is_available() else -1)
18
-
19
- # 向量模型
20
- encoder = SentenceTransformer("shibing624/text2vec-base-chinese")
21
- index = faiss.read_index("vector_store.faiss")
22
- with open("documents.json", "r", encoding="utf-8") as f:
23
- documents = json.load(f)
24
-
25
- # QA固定問答(可選)
26
- try:
27
- with open("qa.json", "r", encoding="utf-8") as f:
28
- qa_data = json.load(f)
29
- except:
30
- qa_data = []
31
-
32
- # QA match(選擇性)
33
- def match_qa(user_input):
34
- cleaned_input = user_input.replace(" ", "")
 
 
 
 
 
 
 
 
 
 
 
 
35
  for item in qa_data:
36
  if item["match"] == "OR":
37
- if any(k.replace(" ", "") in cleaned_input for k in item["keywords"]):
38
  return item["response"]
39
  elif item["match"] == "AND":
40
- if all(k.replace(" ", "") in cleaned_input for k in item["keywords"]):
41
  return item["response"]
42
  return None
43
 
44
- # 文字生成
45
- def generate_answer(text):
46
- messages = [{"role": "user", "content": text}]
47
- input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
48
- with torch.no_grad():
49
- outputs = model.generate(input_ids, max_new_tokens=200)
50
- response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
51
- return response.strip()
52
-
53
- # 向量比對
54
- def search_vector_db(query, top_k=1):
55
- q_vec = encoder.encode([query])
56
- D, I = index.search(np.array(q_vec), top_k)
57
- results = [documents[i] for i in I[0] if i < len(documents)]
58
- return results
59
-
60
- # 回答邏輯整合
61
- def answer(text):
62
- # 1. QA 固定資料庫
63
- fixed = match_qa(text)
64
- if fixed:
65
- return fixed
66
-
67
- # 2. RAG 取資料輔助
68
- related_docs = search_vector_db(text)
69
- context = "\n".join(related_docs)
70
- prompt = f"以下是一些關於南臺科技大學的資料:\n{context}\n\n根據上面的資料,請用中文簡短回答這個問題:{text}"
71
- return generate_answer(prompt)
72
-
73
- # TTS
74
- def text_to_speech(text):
75
- tts = gTTS(text, lang='zh')
76
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
77
- tts.save(tmp.name)
78
- return tmp.name
79
-
80
- # 主流程
81
- def voice_assistant(audio_input=None, text_input=None):
82
- if audio_input:
83
- result = asr(audio_input)
84
- user_text = result["text"]
85
- elif text_input:
86
- user_text = text_input
87
- else:
88
- return "請輸入語音或文字", None
89
 
90
- response = answer(user_text)
91
- speech_file = text_to_speech(response)
92
- return response, speech_file
93
 
94
- # Gradio UI
95
- with gr.Blocks() as demo:
96
- gr.Markdown("## 🎓 南臺科技大學 AI 語音助理(MiniCPM + Whisper + 向量式 RAG)")
97
 
98
- with gr.Row():
99
- mic = gr.Audio(source="microphone", type="filepath", label="語音輸入")
100
- text_input = gr.Textbox(label="文字輸入", placeholder="請輸入您的問題")
101
 
102
- submit_btn = gr.Button("送出")
 
103
 
104
- output_text = gr.Textbox(label="回答")
105
- output_audio = gr.Audio(label="語音播放", type="filepath")
106
 
107
- submit_btn.click(fn=voice_assistant, inputs=[mic, text_input], outputs=[output_text, output_audio])
 
108
 
109
- if __name__ == "__main__":
110
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - 升級 TinyLlama-1.1B-Chat 版本
2
+
3
  import json
4
+ import os
5
+ import gradio as gr
6
  import faiss
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
  from sentence_transformers import SentenceTransformer
10
+
11
+ # ✅ 檔案與模型設定
12
+ QA_FILE = "qa.json"
13
+ TEXT_FILE = "web_data.txt"
14
+ DOCS_FILE = "docs.json"
15
+ VECTOR_FILE = "faiss_index.faiss"
16
+ EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
17
+ GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
18
+
19
+ # ✅ 自動建構向量資料庫(若不存在)
20
+ if not (os.path.exists(VECTOR_FILE) and os.path.exists(DOCS_FILE)):
21
+ print("⚙️ 未偵測到向量資料庫,開始自動建構...")
22
+ with open(TEXT_FILE, "r", encoding="utf-8") as f:
23
+ content = f.read()
24
+ docs = [chunk.strip() for chunk in content.split("\n\n") if chunk.strip()]
25
+ embedder = SentenceTransformer(EMBED_MODEL)
26
+ embeddings = embedder.encode(docs, show_progress_bar=True)
27
+ index = faiss.IndexFlatL2(embeddings[0].shape[0])
28
+ index.add(embeddings)
29
+ faiss.write_index(index, VECTOR_FILE)
30
+ with open(DOCS_FILE, "w", encoding="utf-8") as f:
31
+ json.dump(docs, f, ensure_ascii=False, indent=2)
32
+ print("✅ 嵌入建構完成,共儲存段落:", len(docs))
33
+
34
+ # ✅ 載入資料與模型
35
+ with open(QA_FILE, "r", encoding="utf-8") as f:
36
+ qa_data = json.load(f)
37
+ with open(DOCS_FILE, "r", encoding="utf-8") as f:
38
+ docs = json.load(f)
39
+ index = faiss.read_index(VECTOR_FILE)
40
+ embedder = SentenceTransformer(EMBED_MODEL)
41
+ tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL, trust_remote_code=True)
42
+ model = AutoModelForCausalLM.from_pretrained(GEN_MODEL, trust_remote_code=True).to("cuda" if torch.cuda.is_available() else "cpu")
43
+ model.eval()
44
+
45
+ # ✅ QA 快速匹配
46
+
47
+ def retrieve_qa_context(user_input):
48
  for item in qa_data:
49
  if item["match"] == "OR":
50
+ if any(k in user_input for k in item["keywords"]):
51
  return item["response"]
52
  elif item["match"] == "AND":
53
+ if all(k in user_input for k in item["keywords"]):
54
  return item["response"]
55
  return None
56
 
57
+ # ✅ 向量檢索 top-k 段落
58
+
59
+ def search_context_faiss(user_input, top_k=3):
60
+ vec = embedder.encode([user_input])
61
+ D, I = index.search(vec, top_k)
62
+ return "\n".join([docs[i] for i in I[0] if i < len(docs)])
63
+
64
+ # ✅ 使用 Few-shot Prompt 生成答案
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ def generate_answer(user_input, context):
67
+ prompt = f"""
68
+ 你是一位了解南臺科技大學的智慧語音助理。請根據以下資料回答問題,僅用一至兩句話,以繁體中文表達,回答需清楚具體,不重複問題,不加入身份說明。
69
 
70
+ [範例格式]
71
+ 問題:學校地址在哪裡?
72
+ 回答:南臺科技大學位於台南市永康區南台街一號。
73
 
74
+ 問題:學校電話是多少?
75
+ 回答:總機電話是 06-2533131,電機工程系分機為 3301。
 
76
 
77
+ 問題:電機工程系辦公室在哪?
78
+ 回答:電機工程系辦公室位於 B 棟 B101。
79
 
80
+ 問題:電機工程系有哪些組別?
81
+ 回答:電機系設有控制組、生醫電子系統組與電能資訊組三個方向。
82
 
83
+ 問題:學生社團活動如何?
84
+ 回答:南臺有超過 80 個學生社團,涵蓋學術、康樂、服務、體育與藝術領域。
85
 
86
+ 問題:圖書館提供哪些服務?
87
+ 回答:圖書館提供借書、自修空間、期刊查詢與電子資源服務。
88
+
89
+ 問題:師資如何?
90
+ 回答:本校師資陣容堅強,擁有 30 多位教授、副教授與助理教授。
91
+
92
+ 問題:悠活館是做什麼的?
93
+ 回答:悠活館是學生休閒與運動中心,設有羽球場、健身房、桌球室等設施。
94
+
95
+ 問題:怎麼到南臺科技大學?
96
+ 回答:可從台南火車站搭乘公車,或經永康交流道開車約 10 分鐘抵達。
97
+
98
+ [資料]
99
+ {context}
100
+
101
+ [問題]
102
+ {user_input}
103
+ """
104
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
105
+ outputs = model.generate(**inputs, max_new_tokens=150)
106
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
107
+ for line in response.splitlines()[::-1]:
108
+ if len(line.strip()) > 10 and not line.startswith("你是"):
109
+ return line.strip()
110
+ return response[-90:]
111
+
112
+ # ✅ 問答主流程
113
+
114
+ def answer(user_input):
115
+ direct = retrieve_qa_context(user_input)
116
+ if direct:
117
+ return direct
118
+ else:
119
+ context = search_context_faiss(user_input)
120
+ return generate_answer(user_input, context)
121
+
122
+ # ✅ Gradio 介面
123
+ interface = gr.Interface(
124
+ fn=answer,
125
+ inputs=gr.Textbox(lines=2, placeholder="請輸入與南臺科技大學相關的問題..."),
126
+ outputs="text",
127
+ title="南臺科技大學 問答機器人(TinyLlama 1.1B)",
128
+ description="支援 QA 關鍵字與語意檢索,自動建立嵌入庫,輸出繁體中文自然回答。",
129
+ theme="default"
130
+ )
131
+
132
+ interface.launch()