Alexend commited on
Commit
4ce0569
·
verified ·
1 Parent(s): 2b729ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -87
app.py CHANGED
@@ -1,111 +1,93 @@
1
  import gradio as gr
 
2
  import torch
3
  import json
 
4
  import tempfile
5
- import faiss
6
  from gtts import gTTS
7
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
8
  from sentence_transformers import SentenceTransformer
9
- import numpy as np
10
-
11
- # 模型
12
- MODEL_NAME = "openbmb/MiniCPM-2B-sft-bf16"
13
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
14
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).eval()
15
-
16
- # 語音辨識 Whisper
17
- asr = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=0 if torch.cuda.is_available() else -1)
18
-
19
- # 向量模型
20
- encoder = SentenceTransformer("shibing624/text2vec-base-chinese")
21
- index = faiss.read_index("vector_store.faiss")
22
- with open("documents.json", "r", encoding="utf-8") as f:
23
- documents = json.load(f)
24
-
25
- # QA固定問答(可選)
26
- try:
27
- with open("qa.json", "r", encoding="utf-8") as f:
28
- qa_data = json.load(f)
29
- except:
30
- qa_data = []
31
-
32
- # QA match(選擇性)
33
- def match_qa(user_input):
34
- cleaned_input = user_input.replace(" ", "")
35
- for item in qa_data:
36
- if item["match"] == "OR":
37
- if any(k.replace(" ", "") in cleaned_input for k in item["keywords"]):
38
- return item["response"]
39
- elif item["match"] == "AND":
40
- if all(k.replace(" ", "") in cleaned_input for k in item["keywords"]):
41
- return item["response"]
42
- return None
43
-
44
- # 文字生成
45
- def generate_answer(text):
46
- messages = [{"role": "user", "content": text}]
47
- input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
48
- with torch.no_grad():
49
- outputs = model.generate(input_ids, max_new_tokens=200)
50
- response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
51
  return response.strip()
52
 
53
- # 向量比對
54
- def search_vector_db(query, top_k=1):
55
- q_vec = encoder.encode([query])
56
- D, I = index.search(np.array(q_vec), top_k)
57
- results = [documents[i] for i in I[0] if i < len(documents)]
58
- return results
59
-
60
- # 回答邏輯整合
61
- def answer(text):
62
- # 1. QA 固定資料庫
63
- fixed = match_qa(text)
64
- if fixed:
65
- return fixed
66
-
67
- # 2. RAG 取資料輔助
68
- related_docs = search_vector_db(text)
69
- context = "\n".join(related_docs)
70
- prompt = f"以下是一些關於南臺科技大學的資料:\n{context}\n\n根據上面的資料,請用中文簡短回答這個問題:{text}"
71
- return generate_answer(prompt)
72
-
73
- # TTS
74
- def text_to_speech(text):
75
- tts = gTTS(text, lang='zh')
76
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
77
- tts.save(tmp.name)
78
- return tmp.name
79
-
80
- # 主流程
81
- def voice_assistant(audio_input=None, text_input=None):
82
  if audio_input:
83
- result = asr(audio_input)
84
- user_text = result["text"]
85
  elif text_input:
86
- user_text = text_input
87
  else:
88
- return "請輸入語音或文字", None
89
 
90
- response = answer(user_text)
91
- speech_file = text_to_speech(response)
92
- return response, speech_file
 
 
93
 
94
- # Gradio UI
95
  with gr.Blocks() as demo:
96
- gr.Markdown("## 🎓 南臺科技大學 AI 語音助理(MiniCPM + Whisper + 向量式 RAG)")
97
 
98
  with gr.Row():
99
- mic = gr.Audio(source="microphone", type="filepath", label="語音輸入")
100
- text_input = gr.Textbox(label="文字輸入", placeholder="請輸入您的問題")
101
 
102
  submit_btn = gr.Button("送出")
103
 
104
- output_text = gr.Textbox(label="回答")
105
  output_audio = gr.Audio(label="語音播放", type="filepath")
106
 
107
- submit_btn.click(fn=voice_assistant, inputs=[mic, text_input], outputs=[output_text, output_audio])
 
 
 
 
108
 
109
  if __name__ == "__main__":
110
  demo.launch()
111
-
 
1
  import gradio as gr
2
+ import os
3
  import torch
4
  import json
5
+ import base64
6
  import tempfile
7
+
8
  from gtts import gTTS
9
+ import whisper
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from sentence_transformers import SentenceTransformer
12
+ import faiss
13
+
14
+ # ---------- 模型與資料載入 ---------- #
15
+
16
+ # 問答模型(輕量中文 LLM)
17
+ LLM_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
18
+ tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
19
+ model = AutoModelForCausalLM.from_pretrained(LLM_MODEL).eval()
20
+
21
+ # 向量模型 + 向量資料庫
22
+ embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
23
+ index = faiss.read_index("vector_store.index")
24
+
25
+ # 讀取文本資料(編號與原始句)
26
+ with open("chunks.json", "r", encoding="utf-8") as f:
27
+ chunks = json.load(f)
28
+
29
+ # 語音辨識模型(Whisper)
30
+ asr_model = whisper.load_model("base")
31
+
32
+ # ---------- 問答處理 ---------- #
33
+
34
+ def generate_answer(query):
35
+ embedding = embedder.encode([query])
36
+ D, I = index.search(embedding, k=3)
37
+ context = "\n".join([chunks[i] for i in I[0]])
38
+
39
+ prompt = f"你是一位語音問答助手,請根據下方資訊回答問題。\n\n資訊:\n{context}\n\n問題:{query}\n\n回答:"
40
+
41
+ inputs = tokenizer(prompt, return_tensors="pt")
42
+ outputs = model.generate(**inputs, max_new_tokens=128)
43
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
44
  return response.strip()
45
 
46
+ # ---------- 語音處理 ---------- #
47
+
48
+ def asr(audio_path):
49
+ result = asr_model.transcribe(audio_path, language="zh")
50
+ return result["text"]
51
+
52
+ def tts(text):
53
+ tts = gTTS(text, lang="zh")
54
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
55
+ tts.save(fp.name)
56
+ return fp.name
57
+
58
+ # ---------- Pipeline ---------- #
59
+
60
+ def chat_pipeline(audio_input=None, text_input=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if audio_input:
62
+ text = asr(audio_input)
 
63
  elif text_input:
64
+ text = text_input
65
  else:
66
+ return "請輸入問題或語音", None
67
 
68
+ answer = generate_answer(text)
69
+ audio_out = tts(answer)
70
+ return answer, audio_out
71
+
72
+ # ---------- Gradio 介面 ---------- #
73
 
 
74
  with gr.Blocks() as demo:
75
+ gr.Markdown("## 🎙️ 南臺科技大學 問答語音助理(TinyLlama + Whisper + RAG)")
76
 
77
  with gr.Row():
78
+ audio_input = gr.Audio(source="microphone", type="filepath", label="🎤 語音提問")
79
+ text_input = gr.Textbox(label="文字輸入", placeholder="請輸入您的問題")
80
 
81
  submit_btn = gr.Button("送出")
82
 
83
+ output_text = gr.Textbox(label="AI 回答")
84
  output_audio = gr.Audio(label="語音播放", type="filepath")
85
 
86
+ submit_btn.click(
87
+ fn=chat_pipeline,
88
+ inputs=[audio_input, text_input],
89
+ outputs=[output_text, output_audio]
90
+ )
91
 
92
  if __name__ == "__main__":
93
  demo.launch()