|
|
import gradio as gr |
|
|
import requests |
|
|
import os |
|
|
import json |
|
|
import numpy as np |
|
|
import time |
|
|
from dotenv import load_dotenv |
|
|
from rapidfuzz import process |
|
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
import markdown |
|
|
|
|
|
|
|
|
load_dotenv() |
|
|
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") |
|
|
|
|
|
if not OPENROUTER_API_KEY: |
|
|
raise ValueError("❌ Lỗi: API Key không được tìm thấy trong .env!") |
|
|
|
|
|
|
|
|
GDRIVE_JSON_ID = "16f9wAF1Gkvy3Uxv6p2YS-ikINLs6JhRG" |
|
|
|
|
|
def download_from_drive(file_id, file_name): |
|
|
"""Tải file từ Google Drive về Hugging Face Spaces""" |
|
|
url = f"https://drive.google.com/uc?export=download&id={file_id}" |
|
|
if not os.path.exists(file_name): |
|
|
print(f"📥 Đang tải {file_name} từ Google Drive...") |
|
|
response = requests.get(url) |
|
|
with open(file_name, "wb") as f: |
|
|
f.write(response.content) |
|
|
print(f"✅ Tải thành công: {file_name}") |
|
|
else: |
|
|
print(f"✅ File {file_name} đã có sẵn.") |
|
|
|
|
|
|
|
|
download_from_drive(GDRIVE_JSON_ID, "processed_data.json") |
|
|
|
|
|
|
|
|
with open("processed_data.json", "r", encoding="utf-8") as f: |
|
|
data = json.load(f) |
|
|
|
|
|
if not isinstance(data, list): |
|
|
raise ValueError("❌ Lỗi: Dữ liệu JSON không phải là danh sách.") |
|
|
|
|
|
|
|
|
texts = [chunk["text"] for chunk in data] |
|
|
vectorizer = TfidfVectorizer(stop_words="english", max_features=50000) |
|
|
tfidf_matrix = vectorizer.fit_transform(texts) |
|
|
|
|
|
|
|
|
def tfidf_search(query, top_k=5): |
|
|
"""Tìm kiếm tài liệu nhanh bằng TF-IDF + Cosine Similarity""" |
|
|
start_time = time.time() |
|
|
query_vector = vectorizer.transform([query]) |
|
|
similarity_scores = cosine_similarity(query_vector, tfidf_matrix).flatten() |
|
|
|
|
|
top_indices = similarity_scores.argsort()[-top_k:][::-1] |
|
|
results = [data[i] for i in top_indices if similarity_scores[i] > 0.1] |
|
|
|
|
|
print(f"✅ TF-IDF Search Time: {time.time() - start_time:.3f}s") |
|
|
return results |
|
|
|
|
|
|
|
|
def fuzzy_search(query, top_k=5): |
|
|
"""Tìm kiếm gần đúng bằng RapidFuzz""" |
|
|
start_time = time.time() |
|
|
matched_texts = process.extract(query, texts, limit=top_k, score_cutoff=75) |
|
|
results = [data[texts.index(match[0])] for match in matched_texts] |
|
|
|
|
|
print(f"✅ Fuzzy Search Time: {time.time() - start_time:.3f}s") |
|
|
return results |
|
|
|
|
|
|
|
|
def parallel_search(query, top_k=5): |
|
|
"""Chạy TF-IDF và Fuzzy Matching song song để tối ưu tốc độ""" |
|
|
with ThreadPoolExecutor() as executor: |
|
|
future_tfidf = executor.submit(tfidf_search, query, top_k) |
|
|
future_fuzzy = executor.submit(fuzzy_search, query, top_k) |
|
|
|
|
|
tfidf_results = future_tfidf.result() |
|
|
fuzzy_results = future_fuzzy.result() |
|
|
|
|
|
return tfidf_results + fuzzy_results |
|
|
|
|
|
|
|
|
def call_openrouter(prompt): |
|
|
"""Gửi câu hỏi đến OpenRouter API""" |
|
|
headers = { |
|
|
"Authorization": f"Bearer {OPENROUTER_API_KEY}", |
|
|
"Content-Type": "application/json" |
|
|
} |
|
|
|
|
|
payload = { |
|
|
"model": "qwen/qwen3-14b:free", |
|
|
"messages": [{"role": "user", "content": prompt}], |
|
|
"temperature": 0.2, |
|
|
"max_tokens": 30000 |
|
|
} |
|
|
|
|
|
response = requests.post("https://openrouter.ai/api/v1/chat/completions", |
|
|
json=payload, headers=headers) |
|
|
|
|
|
if response.status_code == 200: |
|
|
return response.json()["choices"][0]["message"]["content"] |
|
|
else: |
|
|
return f"❌ Lỗi API: {response.status_code}, {response.text}" |
|
|
|
|
|
|
|
|
def generate_research_report(query): |
|
|
"""Tạo báo cáo nghiên cứu từ dữ liệu JSON""" |
|
|
research_data = parallel_search(query) |
|
|
|
|
|
if not research_data: |
|
|
return "<p style='color: red; font-weight: bold;'>❌ Không tìm thấy dữ liệu nghiên cứu phù hợp.</p>" |
|
|
|
|
|
|
|
|
context = "\n\n".join([ |
|
|
f"**📄 {res['file_name']}** (Chunk {res['chunk_id']}):\n{res['text']}" for res in research_data |
|
|
]) |
|
|
|
|
|
prompt = f""" |
|
|
Bạn là một nhà nghiên cứu chuyên sâu về **đường sắt tốc độ cao** (HSR - High-Speed Rail). |
|
|
Hãy tổng hợp dữ liệu từ tập tin JSON dưới đây và viết một báo cáo nghiên cứu hoàn chỉnh. |
|
|
|
|
|
### **Dữ liệu thu thập được**: |
|
|
{context} |
|
|
|
|
|
Dựa trên dữ liệu này, hãy viết một **báo cáo nghiên cứu hoàn chỉnh** về chủ đề: |
|
|
**{query}** |
|
|
""" |
|
|
|
|
|
|
|
|
raw_markdown = call_openrouter(prompt) |
|
|
|
|
|
return raw_markdown |
|
|
|
|
|
|
|
|
chatbot_ui = gr.Interface( |
|
|
fn=generate_research_report, |
|
|
inputs=gr.Textbox(label="Nhập câu hỏi nghiên cứu"), |
|
|
outputs=gr.Markdown(label="Báo cáo nghiên cứu"), |
|
|
title="HSR RESEARCH AGENT 🚄", |
|
|
description="Nhập câu hỏi nghiên cứu của bạn và nhận báo cáo chi tiết về ngành đường sắt tốc độ cao.", |
|
|
theme="default" |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
chatbot_ui.launch() |