devbernie's picture
Update Qwen model
9e9175c verified
import gradio as gr
import requests
import os
import json
import numpy as np
import time
from dotenv import load_dotenv
from rapidfuzz import process # 🔹 Thay thế FuzzyWuzzy bằng RapidFuzz (nhanh hơn)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from concurrent.futures import ThreadPoolExecutor # 🔹 Xử lý song song để giảm thời gian
import markdown
# --- Load API Key từ .env ---
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
raise ValueError("❌ Lỗi: API Key không được tìm thấy trong .env!")
# --- Google Drive File ID ---
GDRIVE_JSON_ID = "16f9wAF1Gkvy3Uxv6p2YS-ikINLs6JhRG"
def download_from_drive(file_id, file_name):
"""Tải file từ Google Drive về Hugging Face Spaces"""
url = f"https://drive.google.com/uc?export=download&id={file_id}"
if not os.path.exists(file_name):
print(f"📥 Đang tải {file_name} từ Google Drive...")
response = requests.get(url)
with open(file_name, "wb") as f:
f.write(response.content)
print(f"✅ Tải thành công: {file_name}")
else:
print(f"✅ File {file_name} đã có sẵn.")
# --- Tải dữ liệu JSON ---
download_from_drive(GDRIVE_JSON_ID, "processed_data.json")
# --- Đọc dữ liệu JSON ---
with open("processed_data.json", "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("❌ Lỗi: Dữ liệu JSON không phải là danh sách.")
# --- Chuẩn bị dữ liệu cho TF-IDF ---
texts = [chunk["text"] for chunk in data] # 🔹 Chỉ lấy nội dung văn bản
vectorizer = TfidfVectorizer(stop_words="english", max_features=50000) # 🔹 Giới hạn số từ đặc trưng
tfidf_matrix = vectorizer.fit_transform(texts) # 🔹 Chỉ số hóa một lần duy nhất
# --- Tìm kiếm bằng TF-IDF + Cosine Similarity ---
def tfidf_search(query, top_k=5):
"""Tìm kiếm tài liệu nhanh bằng TF-IDF + Cosine Similarity"""
start_time = time.time()
query_vector = vectorizer.transform([query])
similarity_scores = cosine_similarity(query_vector, tfidf_matrix).flatten()
top_indices = similarity_scores.argsort()[-top_k:][::-1]
results = [data[i] for i in top_indices if similarity_scores[i] > 0.1]
print(f"✅ TF-IDF Search Time: {time.time() - start_time:.3f}s")
return results
# --- Tìm kiếm gần đúng bằng RapidFuzz ---
def fuzzy_search(query, top_k=5):
"""Tìm kiếm gần đúng bằng RapidFuzz"""
start_time = time.time()
matched_texts = process.extract(query, texts, limit=top_k, score_cutoff=75) # 🔹 Chỉ lấy kết quả >75%
results = [data[texts.index(match[0])] for match in matched_texts]
print(f"✅ Fuzzy Search Time: {time.time() - start_time:.3f}s")
return results
# --- Tìm kiếm song song để giảm thời gian ---
def parallel_search(query, top_k=5):
"""Chạy TF-IDF và Fuzzy Matching song song để tối ưu tốc độ"""
with ThreadPoolExecutor() as executor:
future_tfidf = executor.submit(tfidf_search, query, top_k)
future_fuzzy = executor.submit(fuzzy_search, query, top_k)
tfidf_results = future_tfidf.result()
fuzzy_results = future_fuzzy.result()
return tfidf_results + fuzzy_results # 🔹 Kết hợp hai phương pháp tìm kiếm
# --- Gọi OpenRouter API ---
def call_openrouter(prompt):
"""Gửi câu hỏi đến OpenRouter API"""
headers = {
"Authorization": f"Bearer {OPENROUTER_API_KEY}",
"Content-Type": "application/json"
}
payload = {
"model": "qwen/qwen3-14b:free",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.2,
"max_tokens": 30000
}
response = requests.post("https://openrouter.ai/api/v1/chat/completions",
json=payload, headers=headers)
if response.status_code == 200:
return response.json()["choices"][0]["message"]["content"]
else:
return f"❌ Lỗi API: {response.status_code}, {response.text}"
# --- Hàm chính để tạo báo cáo nghiên cứu ---
def generate_research_report(query):
"""Tạo báo cáo nghiên cứu từ dữ liệu JSON"""
research_data = parallel_search(query)
if not research_data:
return "<p style='color: red; font-weight: bold;'>❌ Không tìm thấy dữ liệu nghiên cứu phù hợp.</p>"
# Tổng hợp nội dung từ các chunk tìm được
context = "\n\n".join([
f"**📄 {res['file_name']}** (Chunk {res['chunk_id']}):\n{res['text']}" for res in research_data
])
prompt = f"""
Bạn là một nhà nghiên cứu chuyên sâu về **đường sắt tốc độ cao** (HSR - High-Speed Rail).
Hãy tổng hợp dữ liệu từ tập tin JSON dưới đây và viết một báo cáo nghiên cứu hoàn chỉnh.
### **Dữ liệu thu thập được**:
{context}
Dựa trên dữ liệu này, hãy viết một **báo cáo nghiên cứu hoàn chỉnh** về chủ đề:
**{query}**
"""
# Gọi OpenRouter API để tạo báo cáo
raw_markdown = call_openrouter(prompt)
return raw_markdown # 🔹 Xuất ra Markdown
# --- Giao diện Gradio ---
chatbot_ui = gr.Interface(
fn=generate_research_report,
inputs=gr.Textbox(label="Nhập câu hỏi nghiên cứu"),
outputs=gr.Markdown(label="Báo cáo nghiên cứu"),
title="HSR RESEARCH AGENT 🚄",
description="Nhập câu hỏi nghiên cứu của bạn và nhận báo cáo chi tiết về ngành đường sắt tốc độ cao.",
theme="default"
)
# --- Chạy ứng dụng ---
if __name__ == "__main__":
chatbot_ui.launch()