Alexend committed on
Commit
f4f1821
·
verified ·
1 Parent(s): 1e96aac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -69
app.py CHANGED
@@ -1,93 +1,110 @@
1
  import gradio as gr
2
- import os
3
  import torch
4
  import json
5
- import base64
6
  import tempfile
7
-
8
  from gtts import gTTS
9
- import whisper
10
- from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from sentence_transformers import SentenceTransformer
12
- import faiss
13
-
14
- # ---------- 模型與資料載入 ---------- #
15
-
16
- # 問答模型(輕量中文 LLM)
17
- LLM_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
18
- tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
19
- model = AutoModelForCausalLM.from_pretrained(LLM_MODEL).eval()
20
-
21
- # 向量模型 + 向量資料庫
22
- embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
23
- index = faiss.read_index("vector_store.index")
24
-
25
- # 讀取文本資料(編號與原始句)
26
- with open("chunks.json", "r", encoding="utf-8") as f:
27
- chunks = json.load(f)
28
-
29
- # 語音辨識模型(Whisper)
30
- asr_model = whisper.load_model("base")
31
-
32
- # ---------- 問答處理 ---------- #
33
-
34
- def generate_answer(query):
35
- embedding = embedder.encode([query])
36
- D, I = index.search(embedding, k=3)
37
- context = "\n".join([chunks[i] for i in I[0]])
38
-
39
- prompt = f"你是一位語音問答助手,請根據下方資訊回答問題。\n\n資訊:\n{context}\n\n問題:{query}\n\n回答:"
40
-
41
- inputs = tokenizer(prompt, return_tensors="pt")
42
- outputs = model.generate(**inputs, max_new_tokens=128)
43
- response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
44
  return response.strip()
45
 
46
- # ---------- 語音處理 ---------- #
47
-
48
- def asr(audio_path):
49
- result = asr_model.transcribe(audio_path, language="zh")
50
- return result["text"]
51
-
52
- def tts(text):
53
- tts = gTTS(text, lang="zh")
54
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
55
- tts.save(fp.name)
56
- return fp.name
57
-
58
- # ---------- Pipeline ---------- #
59
-
60
- def chat_pipeline(audio_input=None, text_input=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if audio_input:
62
- text = asr(audio_input)
 
63
  elif text_input:
64
- text = text_input
65
  else:
66
- return "請輸入問題或語音", None
67
-
68
- answer = generate_answer(text)
69
- audio_out = tts(answer)
70
- return answer, audio_out
71
 
72
- # ---------- Gradio 介面 ---------- #
 
 
73
 
 
74
  with gr.Blocks() as demo:
75
- gr.Markdown("## 🎙️ 南臺科技大學 問答語音助理(TinyLlama + Whisper + RAG)")
76
 
77
  with gr.Row():
78
- audio_input = gr.Audio(source="microphone", type="filepath", label="🎤 語音提問")
79
- text_input = gr.Textbox(label="文字輸入", placeholder="請輸入您的問題")
80
 
81
  submit_btn = gr.Button("送出")
82
 
83
- output_text = gr.Textbox(label="AI 回答")
84
  output_audio = gr.Audio(label="語音播放", type="filepath")
85
 
86
- submit_btn.click(
87
- fn=chat_pipeline,
88
- inputs=[audio_input, text_input],
89
- outputs=[output_text, output_audio]
90
- )
91
 
92
  if __name__ == "__main__":
93
  demo.launch()
 
import gradio as gr
import torch
import json
import tempfile
import faiss
from gtts import gTTS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import numpy as np

# ---------- Models and data loading ---------- #

# Chat LLM (MiniCPM-2B).
# NOTE(review): trust_remote_code executes code shipped in the model repo —
# make sure the source is trusted/pinned.
MODEL_NAME = "openbmb/MiniCPM-2B-sft-bf16"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).eval()

# Speech recognition (Whisper); GPU when available, otherwise CPU.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if torch.cuda.is_available() else -1,
)

# Sentence embedder + FAISS index, plus the document chunks the index refers to.
encoder = SentenceTransformer("shibing624/text2vec-base-chinese")
index = faiss.read_index("vector_store.faiss")
with open("documents.json", "r", encoding="utf-8") as f:
    documents = json.load(f)

# Optional fixed Q&A table; fall back to an empty table when the file is
# absent or malformed. Bug fix: this was a bare `except:`, which also
# swallowed KeyboardInterrupt/SystemExit and hid real errors.
try:
    with open("qa.json", "r", encoding="utf-8") as f:
        qa_data = json.load(f)
except (OSError, json.JSONDecodeError):
    qa_data = []
32
# Fixed-QA keyword matching (optional fast path before the LLM).
def match_qa(user_input, qa_items=None):
    """Return the canned response whose keyword rule matches *user_input*, else None.

    Each rule is a dict with keys ``"keywords"`` (list of strings),
    ``"match"`` (``"OR"``: any keyword suffices, ``"AND"``: all keywords
    required) and ``"response"``. Spaces are stripped from both the input and
    the keywords before substring matching.

    Args:
        user_input: raw user question.
        qa_items: rule list to search; defaults to the module-level ``qa_data``.

    Bug fix: rules missing a key used to raise KeyError; malformed rules are
    now skipped via ``dict.get``.
    """
    rules = qa_data if qa_items is None else qa_items
    cleaned = user_input.replace(" ", "")
    for rule in rules:
        keywords = [k.replace(" ", "") for k in rule.get("keywords", [])]
        mode = rule.get("match")
        if mode == "OR" and any(k in cleaned for k in keywords):
            return rule.get("response")
        if mode == "AND" and all(k in cleaned for k in keywords):
            return rule.get("response")
    return None
43
+
44
# Text generation via the chat LLM.
def generate_answer(text):
    """Run the chat LLM on *text* and return the decoded reply, whitespace-trimmed.

    The prompt is wrapped in the model's chat template as a single user turn;
    generation is capped at 200 new tokens and runs under ``torch.no_grad()``.
    """
    chat = [{"role": "user", "content": text}]
    prompt_ids = tokenizer.apply_chat_template(chat, return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(prompt_ids, max_new_tokens=200)
    # Decode only the tokens produced after the prompt (skip the echoed input).
    reply_ids = generated[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
52
 
53
# Vector-store lookup.
def search_vector_db(query, top_k=1):
    """Embed *query* and return up to *top_k* matching document chunks.

    Neighbor ids at or beyond ``len(documents)`` are dropped so a short
    document list cannot cause an IndexError.
    """
    query_vec = encoder.encode([query])
    _, neighbor_ids = index.search(np.array(query_vec), top_k)
    hits = []
    for doc_id in neighbor_ids[0]:
        if doc_id < len(documents):
            hits.append(documents[doc_id])
    return hits
59
+
60
# Answering logic: canned QA first, then retrieval-augmented generation.
def answer(text):
    """Answer *text*: try the fixed QA table, then RAG over the vector store."""
    # 1) Fixed QA table — cheapest path, no model call.
    canned = match_qa(text)
    if canned:
        return canned

    # 2) RAG: prepend retrieved chunks to the prompt and ask the LLM.
    context = "\n".join(search_vector_db(text))
    rag_prompt = (
        f"以下是一些關於南臺科技大學的資料:\n{context}\n\n"
        f"根據上面的資料,請用中文簡短回答這個問題:{text}"
    )
    return generate_answer(rag_prompt)
72
+
73
# Text-to-speech via gTTS.
def text_to_speech(text):
    """Synthesize *text* to Mandarin speech and return the path of an .mp3 file.

    The file is created with ``delete=False``; the caller (Gradio) owns its
    lifetime.

    Bug fix: the previous version never closed the NamedTemporaryFile handle,
    leaking a file descriptor per call; closing it before gTTS writes also
    avoids open-file sharing issues on Windows.
    """
    speech = gTTS(text, lang='zh')
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
        out_path = tmp.name
    speech.save(out_path)
    return out_path
79
+
80
# Main pipeline wired to the Gradio submit button.
def voice_assistant(audio_input=None, text_input=None):
    """Turn an audio clip or a text question into ``(answer_text, mp3_path)``.

    Audio takes precedence over text; with neither input, a prompt message and
    ``None`` (no audio) are returned.
    """
    if audio_input:
        transcription = asr(audio_input)
        user_text = transcription["text"]
    elif text_input:
        user_text = text_input
    else:
        return "請輸入語音或文字", None

    reply = answer(user_text)
    return reply, text_to_speech(reply)
93
 
94
# ---------- Gradio UI ---------- #
with gr.Blocks() as demo:
    gr.Markdown("## 🎓 南臺科技大學 AI 語音助理(MiniCPM + Whisper + 向量式 RAG)")

    with gr.Row():
        # NOTE(review): `source=` is the Gradio 3.x keyword; Gradio 4 renamed
        # it to `sources=[...]` — confirm the pinned gradio version.
        mic = gr.Audio(source="microphone", type="filepath", label="語音輸入")
        text_input = gr.Textbox(label="文字輸入", placeholder="請輸入您的問題")

    submit_btn = gr.Button("送出")

    output_text = gr.Textbox(label="回答")
    output_audio = gr.Audio(label="語音播放", type="filepath")

    submit_btn.click(
        fn=voice_assistant,
        inputs=[mic, text_input],
        outputs=[output_text, output_audio],
    )

if __name__ == "__main__":
    demo.launch()