Alexend committed on
Commit
5330037
·
verified ·
1 Parent(s): 204d9ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -157
app.py CHANGED
@@ -1,184 +1,142 @@
1
  import os
2
- import json
3
- import time
4
- import base64
5
- import sqlite3
6
  import tempfile
7
  import requests
8
- from bs4 import BeautifulSoup
9
-
10
  import torch
11
  import whisper
12
  from gtts import gTTS
13
  from pydub import AudioSegment
14
- import numpy as np
15
  from sentence_transformers import SentenceTransformer, util
16
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
17
 
18
- import gradio as gr
19
-
20
# =====================
# Model initialization
# =====================
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Speech-to-text model.
whisper_model = whisper.load_model("base", device=device)

# Sentence encoder used to embed documents and queries.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Causal language model and its tokenizer, used to generate answers.
_GEN_MODEL_NAME = "lmsys/mini-cpm-1b-sft"
tokenizer = AutoTokenizer.from_pretrained(_GEN_MODEL_NAME)
gen_model = AutoModelForCausalLM.from_pretrained(_GEN_MODEL_NAME).to(device)
28
 
29
- # =====================
30
- # 向量資料庫 (SQLite)
31
- # =====================
32
DB_PATH = "vector_db.sqlite"


def init_db():
    """Create the ``documents`` table in the SQLite file at ``DB_PATH``.

    The table holds an auto-incrementing ``id``, the document ``content``
    and its ``embedding`` as a BLOB. Calling this again is harmless because
    the statement uses ``CREATE TABLE IF NOT EXISTS``.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # ``with conn`` commits on success and rolls back on error; the
        # ``finally`` guarantees the connection is closed either way.
        with conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    content TEXT,
                    embedding BLOB
                )
            """)
    finally:
        conn.close()
46
-
47
def add_document(content):
    """Embed ``content`` and store the text with its embedding in the database.

    Args:
        content: Text to index.
    """
    # search_similar() decodes the stored blob as float32, so pin that dtype
    # here instead of relying on whatever the encoder happens to return.
    embedding = np.asarray(embedding_model.encode(content), dtype=np.float32)
    conn = sqlite3.connect(DB_PATH)
    try:
        # Commit on success, roll back on error, and always close.
        with conn:
            conn.execute(
                "INSERT INTO documents (content, embedding) VALUES (?, ?)",
                (content, embedding.tobytes()),
            )
    finally:
        conn.close()
55
-
56
def search_similar(query, top_k=3):
    """Return the stored documents most similar to ``query``.

    Args:
        query: Text to compare against the stored documents.
        top_k: Maximum number of documents to return.

    Returns:
        Up to ``top_k`` document texts ordered by descending cosine
        similarity. Empty when nothing is stored or ``top_k`` is not
        positive.
    """
    # A negative top_k would otherwise slice from the wrong end of the ranking.
    if top_k <= 0:
        return []

    query_vec = embedding_model.encode(query)
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute("SELECT content, embedding FROM documents").fetchall()
    finally:
        conn.close()

    if not rows:
        return []

    contents = [content for content, _ in rows]
    # np.frombuffer yields read-only views; stacking copies them into one
    # writable 2-D array that converts cleanly to a tensor.
    embeddings = np.stack([np.frombuffer(blob, dtype=np.float32) for _, blob in rows])
    cos_scores = util.cos_sim(query_vec, embeddings)[0].cpu().numpy()

    ranked = sorted(zip(contents, cos_scores), key=lambda pair: pair[1], reverse=True)
    return [content for content, _ in ranked[:top_k]]
73
-
74
- # =====================
75
- # 爬蟲模組
76
- # =====================
77
# Sites consulted before falling back to a general web search.
TRUSTED_SITES = [
    "https://www.dgpa.gov.tw",  # Directorate-General of Personnel Administration (official calendar)
    "https://www.cna.com.tw",  # Central News Agency
    # The comma after the next entry was missing, which silently fused it
    # with the following URL into one invalid string.
    "https://www.stust.edu.tw",  # Southern Taiwan University of Science and Technology
    "https://www.moi.gov.tw",
    "https://www.taiwan.net.tw",
    "https://publicholidays.tw",
]
85
 
86
def crawl_trusted_sites(query):
    """Search the landing pages of the trusted sites for ``query``.

    Args:
        query: Text to look for in each page's paragraph text.

    Returns:
        One ``"[可信來源] <site>: <first 200 characters>..."`` string per
        trusted site whose ``<p>`` text contains ``query``. Empty when
        nothing matches or ``query`` is empty.
    """
    # An empty query is a substring of every page and would match all sites.
    if not query:
        return []

    results = []
    headers = {"User-Agent": "Mozilla/5.0"}
    for site in TRUSTED_SITES:
        try:
            resp = requests.get(site, headers=headers, timeout=10)
        except requests.RequestException:
            # Best effort: one unreachable site must not block the others.
            continue
        if resp.status_code != 200:
            continue
        soup = BeautifulSoup(resp.text, "html.parser")
        text = " ".join(p.get_text(strip=True) for p in soup.find_all("p"))
        if query in text:
            results.append(f"[可信來源] {site}: {text[:200]}...")
    return results
101
-
102
def crawl_general_web(query):
    """Run a general web search for ``query``.

    Args:
        query: Search terms.

    Returns:
        A single-element list holding the labelled search URL when the
        search page answers with HTTP 200, otherwise an empty list.
    """
    from urllib.parse import quote_plus

    # Percent-encode the query so spaces, '&', '#' and non-ASCII text cannot
    # truncate or alter the search URL.
    url = f"https://www.google.com/search?q={quote_plus(query)}"
    headers = {"User-Agent": "Mozilla/5.0"}
    try:
        resp = requests.get(url, headers=headers, timeout=10)
    except requests.RequestException:
        return []
    if resp.status_code == 200:
        return [f"[一般搜尋結果] {url}"]
    return []
113
-
114
- # =====================
115
- # 問答邏輯
116
- # =====================
117
def generate_answer(query, context=""):
    """Generate an answer with the language model.

    The prompt is ``query``, a single space, then ``context``.

    Args:
        query: The user's question.
        context: Optional supporting text appended to the prompt.

    Returns:
        The decoded first generated sequence with special tokens removed.
    """
    prompt = " ".join((query, context))
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    generated = gen_model.generate(**encoded, max_new_tokens=128)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
121
 
122
def qa_pipeline(query):
    """Route ``query`` to the language model or to the crawlers.

    Queries mentioning the school ("南台" / "南臺") are answered by the
    language model, using similar stored documents as context. Other queries
    go to the trusted sites first and then to a general web search.

    Args:
        query: The user's question.

    Returns:
        The answer text, or an apology string when nothing was found.
    """
    started = time.time()

    def elapsed():
        return time.time() - started

    # 1. School-related question -> language model.
    if any(keyword in query for keyword in ("南台", "南臺")):
        context = " ".join(search_similar(query))
        if elapsed() > 120:
            context += " ".join(crawl_general_web(query))
        return generate_answer(query, context)

    # 2. General question -> trusted sites first (within 60 seconds).
    if elapsed() <= 60:
        trusted = crawl_trusted_sites(query)
        if trusted:
            return "\n".join(trusted)

    # 3. Otherwise -> general web search.
    fallback = crawl_general_web(query)
    return "\n".join(fallback) if fallback else "抱歉，目前無法找到相關資訊。"
145
-
146
- # =====================
 
 
 
147
  # Gradio 介面
148
- # =====================
149
def chatbot_interface(audio=None, text=None):
    """Answer a spoken or typed question and synthesise the answer as speech.

    Args:
        audio: Path to the recorded audio file, or None. The audio input
            component is declared with ``type="filepath"``, so this is a
            path string rather than raw bytes.
        text: Typed question, or None. Takes precedence over ``audio`` when
            it is not blank.

    Returns:
        ``(answer_text, mp3_path)``. ``mp3_path`` is None when no usable
        input was provided.
    """
    query = text.strip() if text else ""
    if not query and audio:
        # Transcribe the file in place; writing the path string into a binary
        # temp file (as before) raised TypeError.
        query = whisper_model.transcribe(audio)["text"].strip()

    if not query:
        # Two values, because the click handler is wired to two outputs.
        return "請提供語音或文字輸入。", None

    answer = qa_pipeline(query)

    # Text-to-speech. Reserve a path and close the handle before gTTS writes
    # to it, so the file is not held open twice.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tts_file:
        tts_path = tts_file.name
    gTTS(answer, lang="zh").save(tts_path)
    return answer, tts_path
170
-
171
# NOTE(review): the nesting below is reconstructed; the original indentation
# was not recoverable. Confirm which components belong inside the Row.
with gr.Blocks() as demo:
    gr.Markdown("# 🎤 南臺科技大學 智慧語音助理")
    with gr.Row():
        audio_input = gr.Audio(
            source="microphone",
            type="filepath",
            label="語音輸入 (可選)",
        )
        text_input = gr.Textbox(
            label="文字輸入 (可選)",
            placeholder="請輸入問題...",
        )
    submit_btn = gr.Button("送出")
    output_text = gr.Textbox(label="回答")
    output_audio = gr.Audio(label="語音回答", type="filepath")

    # Both inputs are optional; the handler decides which one to use.
    submit_btn.click(
        fn=chatbot_interface,
        inputs=[audio_input, text_input],
        outputs=[output_text, output_audio],
    )


if __name__ == "__main__":
    # Make sure the documents table exists before serving requests.
    init_db()
    demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
 
 
 
 
2
  import tempfile
3
  import requests
4
+ import gradio as gr
 
5
  import torch
6
  import whisper
7
  from gtts import gTTS
8
  from pydub import AudioSegment
 
9
  from sentence_transformers import SentenceTransformer, util
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ from bs4 import BeautifulSoup
12
+ from duckduckgo_search import DDGS
13
+ import time
14
 
15
# ------------------------
# Model initialization
# ------------------------
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Speech-to-text model.
whisper_model = whisper.load_model("base", device=device)

# Sentence encoder.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Causal language model and its tokenizer, used to generate answers.
_LM_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(_LM_NAME)
lm_model = AutoModelForCausalLM.from_pretrained(_LM_NAME).to(device)
23
 
24
+ # ------------------------
25
+ # 信任網站(含行政院人事行政總處行事曆)
26
+ # ------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
TRUSTED_SITES = [
    # Directorate-General of Personnel Administration (official calendar)
    "https://www.dgpa.gov.tw",
    # Chinese Wikipedia
    "https://zh.wikipedia.org",
]
31
 
32
+ # ------------------------
33
+ # 文字轉語音
34
+ # ------------------------
35
def text_to_speech(text):
    """Synthesise ``text`` as speech and return the path of the MP3 file.

    The file is created with ``delete=False``, so it stays on disk after this
    function returns; removing it is left to the caller.

    Args:
        text: Text to speak.

    Returns:
        Path of the generated ``.mp3`` file.
    """
    # Reserve a unique path and close the handle first: gTTS opens the path
    # itself to write, which can fail on platforms that forbid opening a
    # temp file a second time while it is still open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
        mp3_path = fp.name
    try:
        gTTS(text=text, lang="zh").save(mp3_path)
    except Exception:
        # Do not leave an empty placeholder file behind on failure.
        os.remove(mp3_path)
        raise
    return mp3_path
40
+
41
+ # ------------------------
42
+ # 爬蟲 - 行政院人事行政總處 行事曆
43
+ # ------------------------
44
def crawl_official_calendar(query):
    """Look up ``query`` in the official holiday calendar page.

    Args:
        query: Text to search for in the page's table rows.

    Returns:
        The stripped text of the first ``<tr>`` containing ``query``, or None
        when no row matches, ``query`` is blank, or the page cannot be
        fetched.
    """
    # A blank query is a substring of every row and would return the first one.
    if not query or not query.strip():
        return None

    url = "https://www.dgpa.gov.tw/holidaycalendar"
    try:
        resp = requests.get(url, timeout=10)
        # Do not scrape an error page as if it were the calendar.
        resp.raise_for_status()
    except requests.RequestException:
        return None

    soup = BeautifulSoup(resp.text, "html.parser")
    for row in soup.find_all("tr"):
        if query in row.text:
            return row.text.strip()
    return None
56
+
57
+ # ------------------------
58
+ # 網路搜尋 (DuckDuckGo)
59
+ # ------------------------
60
def web_search(query):
    """Search the web with DuckDuckGo and summarise the top results.

    Args:
        query: Search terms.

    Returns:
        One ``"<title> - <body>"`` line per result (at most three requested),
        a "no data" message when there are no results, or a failure message
        when the search raises.
    """
    try:
        hits = DDGS().text(query, max_results=3)
        lines = [f"{hit['title']} - {hit['body']}\n" for hit in hits]
    except Exception:
        return "網路搜尋失敗。"
    answer = "".join(lines)
    return answer or "查無資料。"
70
+
71
+ # ------------------------
72
+ # AI 回答
73
+ # ------------------------
74
def ai_answer(question):
    """Generate a free-form answer to ``question`` with the language model.

    Args:
        question: Prompt text.

    Returns:
        The decoded first generated sequence with special tokens removed.
    """
    inputs = tokenizer(question, return_tensors="pt").to(device)
    outputs = lm_model.generate(
        # Pass the whole encoding so the attention mask travels with the ids.
        **inputs,
        # Budget newly generated tokens only: max_length also counts the
        # prompt, so a long question left little or no room for an answer.
        max_new_tokens=150,
        do_sample=True,
        top_k=50,
        # The tokenizer may define no pad token (GPT-2 does not); use EOS
        # for padding so generate() does not have to guess.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
78
 
79
+ # ------------------------
80
+ # 主邏輯
81
+ # ------------------------
82
def qa_system(audio, text_input):
    """Answer a typed or spoken question and read the answer aloud.

    Args:
        audio: Path to the recorded or uploaded audio file, or None. The
            audio input component is declared with ``type="filepath"``.
        text_input: Typed question. Takes precedence over ``audio`` when it
            is not blank.

    Returns:
        ``(query, answer, audio_file)``, in the order of the recognised-text,
        answer and audio outputs. ``audio_file`` is None when there is
        nothing to speak.
    """
    query = text_input.strip() if text_input else ""  # text takes priority
    if not query and audio:
        # ``audio`` is a path string, so it has no ``.export``; transcribe
        # the file directly instead of copying it to a temp WAV first.
        query = whisper_model.transcribe(audio)["text"].strip()
    if not query:
        # Keep the tuple in output order so the message lands in the answer
        # box and the audio output receives None.
        return "", "請提供語音或文字輸入。", None

    if "南台" in query or "南臺" in query:
        # School-related question: answer with the language model.
        start_time = time.time()
        answer = ai_answer(query)
        if time.time() - start_time > 120:  # slow answer: add web results
            answer += f"\n（補充搜尋結果）{web_search(query)}"
    else:
        # Holiday or general question: try the trusted site first.
        start_time = time.time()
        answer = crawl_official_calendar(query)
        if not answer:
            # NOTE(review): the web search only runs when the crawl itself
            # took more than 60 s; confirm this is the intended fallback rule.
            if time.time() - start_time > 60:
                answer = web_search(query)
            else:
                answer = "查無相關資料（可信網站）。"

    # Generate the spoken reply; skip synthesis when there is no text to speak.
    audio_file = text_to_speech(answer) if answer.strip() else None
    return query, answer, audio_file
117
+
118
+ # ------------------------
119
  # Gradio 介面
120
+ # ------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
# NOTE(review): the nesting below is reconstructed; the original indentation
# was not recoverable. Confirm the Row/Column membership of each component.
with gr.Blocks() as demo:
    gr.Markdown("# 南臺科技大學語音文字問答系統")
    gr.Markdown("你可以輸入文字或錄音提問，系統會找出答案，並用語音回覆。")

    with gr.Row():
        # Input column: typed question, audio, and the submit button.
        with gr.Column():
            text_input = gr.Textbox(
                label="文字輸入（可選）",
                placeholder="請輸入你的問題...",
            )
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="請上傳語音檔或錄音",
            )
            submit_btn = gr.Button("Submit")
        # Output column: recognised text, answer, and spoken reply.
        with gr.Column():
            text_output = gr.Textbox(label="語音辨識文字")
            answer_output = gr.Textbox(label="AI 回答")
            audio_output = gr.Audio(label="語音回覆", type="filepath")

    input_components = [audio_input, text_input]
    output_components = [text_output, answer_output, audio_output]
    submit_btn.click(qa_system, inputs=input_components, outputs=output_components)


if __name__ == "__main__":
    demo.launch()