# Source: 20250920 / app.py
# Last commit: ed83a8d "Update app.py" (author: Alexend, verified)
import logging
import os
import tempfile
import time

import gradio as gr
import requests
import torch
import whisper
from bs4 import BeautifulSoup
from duckduckgo_search import DDGS
from gtts import gTTS
from pydub import AudioSegment
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
# ------------------------
# Model initialization: every model is loaded once, at import time
# ------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Speech-to-text model (Whisper "base" checkpoint), placed on `device`.
whisper_model = whisper.load_model("base", device=device)
# Sentence embedder. NOTE(review): not referenced elsewhere in the visible code.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# GPT-2 tokenizer/model pair used by ai_answer for text generation.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
lm_model = AutoModelForCausalLM.from_pretrained("gpt2")
lm_model = lm_model.to(device)
# ------------------------
# Trusted websites (including the Executive Yuan DGPA holiday calendar)
# ------------------------
# NOTE(review): this list is not referenced anywhere in the visible code.
TRUSTED_SITES = [
    "https://www.dgpa.gov.tw",  # Executive Yuan, Directorate-General of Personnel Administration (calendar)
    "https://www.cna.com.tw",  # Central News Agency
    "https://www.stust.edu.tw",  # Southern Taiwan University of Science and Technology
    "https://www.moi.gov.tw",
    "https://www.taiwan.net.tw",
    "https://publicholidays.tw",
    "https://zh.wikipedia.org",  # Chinese-language Wikipedia
]
# ------------------------
# Text-to-speech
# ------------------------
def text_to_speech(text):
    """Synthesize *text* as Chinese speech with gTTS and return the MP3 path.

    The file is created with ``delete=False`` so it outlives this call (the
    path is handed to a Gradio Audio output); nothing here deletes it later.

    Args:
        text: text to speak.

    Returns:
        Path of the generated ``.mp3`` file, or None when *text* is empty,
        None, or whitespace-only (gTTS cannot synthesize empty input).

    Raises:
        Whatever gTTS raises on synthesis failure; the half-created temp file
        is removed first.
    """
    if not text or not text.strip():
        return None
    # Close the temp file before gTTS re-opens it by name: re-opening a
    # still-open NamedTemporaryFile fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
        path = fp.name
    try:
        gTTS(text=text, lang="zh").save(path)
    except Exception:
        os.remove(path)  # don't leave an empty temp file behind
        raise
    return path
# ------------------------
# Scraper - Executive Yuan DGPA holiday calendar
# ------------------------
def crawl_official_calendar(query, url="https://www.dgpa.gov.tw/holidaycalendar", timeout=10):
    """Return the text of the first ``<tr>`` on *url* whose text contains *query*.

    This is a best-effort lookup. Failures are logged rather than raised.

    Args:
        query: substring to look for in each table row (surrounding
            whitespace is ignored).
        url: page to scrape; defaults to the DGPA holiday calendar.
        timeout: seconds passed to ``requests.get``.

    Returns:
        The stripped row text, or None when *query* is empty/blank, no row
        matches, or fetching/parsing the page fails.
    """
    query = (query or "").strip()
    if not query:
        # "" is a substring of every string, so an empty query would
        # "match" the first row on the page.
        return None
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()
        # Give BeautifulSoup the raw bytes so it can detect the encoding from
        # the page's <meta charset>; resp.text falls back to ISO-8859-1 when
        # the Content-Type header has no charset, which garbles Chinese text.
        soup = BeautifulSoup(resp.content, "html.parser")
        for row in soup.find_all("tr"):
            row_text = row.text
            if query in row_text:
                return row_text.strip()
    except Exception:
        # Deliberately best-effort: any fetch/parse failure means "no answer".
        logging.getLogger(__name__).warning(
            "Calendar lookup failed for %r", query, exc_info=True
        )
    return None
# ------------------------
# Web search (DuckDuckGo)
# ------------------------
def web_search(query, max_results=3):
    """Search DuckDuckGo and return the top hits as "title - body" lines.

    Args:
        query: search string.
        max_results: maximum number of hits to include (default 3).

    Returns:
        One newline-terminated "title - body" line per hit.
        "查無資料。" when there are no hits.
        "網路搜尋失敗。" when the search raises (the error is logged).
    """
    try:
        # Depending on the installed duckduckgo_search version this may be a
        # list, a lazy iterator, or None; normalise None and consume the
        # results inside the try so iteration errors are handled too.
        results = DDGS().text(query, max_results=max_results) or []
        lines = [f"{r.get('title', '')} - {r.get('body', '')}\n" for r in results]
    except Exception:
        logging.getLogger(__name__).warning(
            "Web search failed for %r", query, exc_info=True
        )
        return "網路搜尋失敗。"
    return "".join(lines) or "查無資料。"
# ------------------------
# AI answer
# ------------------------
def ai_answer(question, max_new_tokens=128):
    """Generate a sampled continuation of *question* with the GPT-2 model.

    The decoded sequence starts with the (possibly truncated) prompt and is
    followed by up to *max_new_tokens* generated tokens. Output is
    non-deterministic because sampling is enabled.

    Args:
        question: prompt text (expected non-empty).
        max_new_tokens: generation budget, independent of prompt length.

    Returns:
        The decoded text with special tokens removed.
    """
    # Reserve room in the model's context window for the generated tokens;
    # over-long prompts are truncated instead of overflowing the model.
    prompt_budget = max(1, tokenizer.model_max_length - max_new_tokens)
    inputs = tokenizer(
        question, return_tensors="pt", truncation=True, max_length=prompt_budget
    ).to(device)
    outputs = lm_model.generate(
        **inputs,  # input_ids plus attention_mask
        # The previous max_length=150 counted the prompt as well, so long
        # prompts left no room for any new tokens.
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        # GPT-2's tokenizer has no pad token; pass EOS explicitly.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ------------------------
# Main logic
# ------------------------
def _resolve_query(audio, text_input):
    """Return typed text if non-blank, else the Whisper transcript of *audio*, else ""."""
    if text_input and text_input.strip():  # typed text takes priority
        return text_input.strip()
    if audio:
        # gr.Audio(type="filepath") delivers a path string, not a pydub
        # AudioSegment; Whisper loads the file from the path directly, so no
        # intermediate (and previously leaked) WAV temp file is needed.
        result = whisper_model.transcribe(audio)
        return result["text"].strip()
    return ""


def qa_system(audio, text_input):
    """Answer a question given as text or as a recorded/uploaded audio file.

    Args:
        audio: path to an audio file (from ``gr.Audio(type="filepath")``) or None.
        text_input: typed question; used in preference to *audio* when non-blank.

    Returns:
        ``(query, answer, audio_path)``, matching the Gradio outputs
        [recognized text, answer, spoken reply]. When no usable input is
        given, returns ``("", "請提供語音或文字輸入。", None)``.
    """
    query = _resolve_query(audio, text_input)
    if not query:
        # The prompt goes to the answer box and the audio output gets None.
        return "", "請提供語音或文字輸入。", None

    if "南台" in query or "南臺" in query:
        # Questions about the university are answered by the language model.
        start_time = time.monotonic()
        answer = ai_answer(query)
        elapsed = time.monotonic() - start_time
        # NOTE(review): this measures a call that has already finished, so it
        # is not a timeout. It only appends search results after a slow
        # generation. Confirm this is the intended rule.
        if elapsed > 120:
            web_ans = web_search(query)
            answer += f"\n(補充搜尋結果){web_ans}"
    else:
        # Holidays / general questions: try the official calendar first.
        start_time = time.monotonic()
        answer = crawl_official_calendar(query)
        if not answer:
            elapsed = time.monotonic() - start_time
            # NOTE(review): the calendar scraper uses a 10 s request timeout,
            # so this web-search fallback is rarely taken. Confirm intent.
            if elapsed > 60:
                answer = web_search(query)
            else:
                answer = "查無相關資料(可信網站)。"

    audio_file = text_to_speech(answer)
    return query, answer, audio_file
# ------------------------
# Gradio interface
# ------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 南臺科技大學語音文字問答系統")
    gr.Markdown("你可以輸入文字或錄音提問,系統會找出答案,並用語音回覆。")
    with gr.Row():
        # Left column: the two ways to ask a question, plus the trigger button.
        with gr.Column():
            text_input = gr.Textbox(
                label="文字輸入(可選)",
                placeholder="請輸入你的問題...",
            )
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="請上傳語音檔或錄音",
            )
            submit_btn = gr.Button("Submit")
        # Right column: the three values returned by qa_system.
        with gr.Column():
            text_output = gr.Textbox(label="語音辨識文字")
            answer_output = gr.Textbox(label="AI 回答")
            audio_output = gr.Audio(label="語音回覆", type="filepath")
    # Component order must match qa_system(audio, text_input) and its
    # (query, answer, audio_file) return tuple.
    submit_btn.click(
        fn=qa_system,
        inputs=[audio_input, text_input],
        outputs=[text_output, answer_output, audio_output],
    )

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()