File size: 4,068 Bytes
dfda82a
559baca
9a8bf54
402892d
18f4f0a
 
4f82669
ed49233
402892d
9f3b645
ed49233
9f3b645
 
5dcf751
9f3b645
 
27f41b4
18f4f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
9f3b645
 
18f4f0a
 
 
 
9f3b645
 
64e9ea2
18f4f0a
5be7dab
402892d
9f3b645
27f41b4
18f4f0a
64e9ea2
9f3b645
18f4f0a
9f3b645
 
 
 
 
 
 
18f4f0a
64e9ea2
18f4f0a
9f3b645
 
18f4f0a
 
 
9f3b645
 
18f4f0a
64e9ea2
402892d
 
18f4f0a
9f3b645
18f4f0a
402892d
27f41b4
44dbc68
9f3b645
 
 
 
 
 
402892d
5be7dab
9f3b645
402892d
9f3b645
27f41b4
18f4f0a
9f3b645
3338766
9f3b645
3338766
 
 
9f3b645
 
3338766
 
 
5dc4e7f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json
import logging
import os
from pathlib import Path

import fitz  # PyMuPDF
import gradio as gr
import requests

# --- 核心配置 ---
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
MODEL_ID = "google/gemini-2.0-flash-001" 

# --- 你的专属 HTML 声明 (保持不变) ---
INFO_HTML = """<div style="text-align: left; border-left: 4px solid #2196F3; padding-left: 15px; margin-bottom: 20px;">
    <h3>MG TaxAI | 跨境财税合规实验室 (Beta)</h3>
    <p>本系统依托 <b>MG 核心智库</b> 构建...</p>
</div>"""

# --- Local knowledge-base retrieval (keyword match on PDF file names) ---
def get_knowledge_context(query, max_chars=6000, max_pages=3):
    """Collect reference text from local PDFs whose file names match the query.

    The query is split on whitespace. Every token longer than one character
    is matched case-insensitively against the names of ``*.pdf`` files found
    recursively under ``Treaties`` and ``InvestmentGuide``. Both paths are
    relative to the current working directory, and missing folders are
    skipped. For each matching PDF, the text of its first ``max_pages`` pages
    is extracted.

    Args:
        query: Free-text user question.
        max_chars: Maximum length of the returned string.
        max_pages: Number of leading pages read from each matching PDF.

    Returns:
        The excerpts joined by blank lines. Each excerpt is prefixed with its
        source file name. The result is truncated to ``max_chars`` characters.
        An empty string is returned when nothing matches.
    """
    base_dirs = ["Treaties", "InvestmentGuide"]
    # Lower-case once here instead of once per file inside the loop.
    keywords = [word.lower() for word in query.split() if len(word) > 1]
    if not keywords or max_chars <= 0:
        return ""

    roots = [Path(folder) for folder in base_dirs]
    # Sorted so the (truncated) context is deterministic across filesystems.
    candidates = (
        pdf_file
        for root in roots
        if root.exists()
        for pdf_file in sorted(root.rglob("*.pdf"))
    )

    context_chunks = []
    joined_len = 0  # length "\n\n".join(context_chunks) currently has
    for pdf_file in candidates:
        file_name = pdf_file.name.lower()
        if not any(kw in file_name for kw in keywords):
            continue
        try:
            with fitz.open(pdf_file) as doc:
                page_total = min(max_pages, doc.page_count)
                text = "".join(doc[i].get_text() for i in range(page_total))
        except Exception:
            # Best effort: one unreadable/corrupt PDF must not break the chat,
            # but leave a trace instead of swallowing the error silently.
            logging.getLogger(__name__).warning(
                "Skipping unreadable PDF: %s", pdf_file, exc_info=True
            )
            continue

        chunk = f"来自文件 [{pdf_file.name}]:\n{text}"
        joined_len += (2 if context_chunks else 0) + len(chunk)
        context_chunks.append(chunk)
        # Anything past max_chars is cut off below, so stop parsing PDFs once
        # the joined text is already long enough.
        if joined_len >= max_chars:
            break

    # Join only after every folder has been scanned, then cap the length.
    return "\n\n".join(context_chunks)[:max_chars]

# --- OpenRouter chat-completion call ---
def _history_to_messages(history):
    """Convert Gradio chat history into OpenAI-style ``role``/``content`` dicts.

    Accepts both the pair format (``[[user, assistant], ...]``) and the
    message-dict format (``{"role": ..., "content": ...}``). Turns whose
    content is not a non-empty string are skipped, so the API never receives
    a null message. Examples are a pending assistant reply (``None``) or a
    file attachment.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        else:
            user_msg, assistant_msg = turn
            pairs = [("user", user_msg), ("assistant", assistant_msg)]
        for role, content in pairs:
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
    return messages


def ask_ai(message, history):
    """Answer one chat turn by sending the conversation to OpenRouter.

    The latest user message is augmented with excerpts from the local PDF
    knowledge base (see ``get_knowledge_context``). Prior turns are forwarded
    as plain chat history.

    Args:
        message: The user's latest question.
        history: Prior turns as supplied by ``gr.ChatInterface``, either
            ``[user, assistant]`` pairs or role/content message dicts.

    Returns:
        The model's reply text. If the API key is missing, the network call
        fails, the status is not 200, or the response body is malformed, a
        user-facing warning/error string is returned instead of raising.
    """
    if not OPENROUTER_API_KEY:
        return "⚠️ 未检测到 API Key,请在 Space 的 Settings -> Secrets 中添加。"

    local_context = get_knowledge_context(message)

    # System prompt: domain focus, source-naming restriction and output style.
    system_instruction = """
    你是一位资深的 MG Consulting 国际税务专家级 AI。
    
    【核心准则】:
    1. 专业性:优先引用参考知识库。若背景不足,基于 2025-2026 最新全球财税准则回答。
    2. 避险:在讨论行业趋势时,使用“大型咨询机构”或“核心智库”等统称,**严禁提及具体的国际会计师事务所名称**。
    3. 风格:直接进入分析,不进行冗长的自我介绍,使用 Markdown 格式(标题、列表、粗体)。
    4. 深度:分析需涵盖税种差异(Income Tax, VAT, Withholding Tax)及双边协定(DTA)影响。
    """

    messages = [{"role": "system", "content": system_instruction}]
    messages.extend(_history_to_messages(history))

    # Only the current turn carries the retrieved context; earlier turns are
    # sent as the user originally typed them.
    current_input = f"【参考知识库】:\n{local_context}\n\n【用户咨询】:\n{message}"
    messages.append({"role": "user", "content": current_input})

    payload = {
        "model": MODEL_ID,
        "messages": messages,
        "temperature": 0.2,
        "top_p": 0.9,
    }

    try:
        # json= serialises the payload and sets the JSON Content-Type header.
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={"Authorization": f"Bearer {OPENROUTER_API_KEY}"},
            json=payload,
            timeout=60,
        )
    except requests.Timeout as e:
        return f"💥 系统连接超时: {str(e)}"
    except requests.RequestException as e:
        return f"💥 网络请求失败: {e}"

    if response.status_code != 200:
        return f"❌ 接口响应异常 ({response.status_code})"

    try:
        return response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # A 200 response can still carry an error object instead of "choices";
        # show a short excerpt instead of mislabelling it as a timeout.
        return f"❌ 接口返回格式异常: {response.text[:200]}"

# --- UI layout ---
# One Blocks page: a static HTML banner followed by a chat panel wired to ask_ai.
with gr.Blocks(fill_height=True, title="MG TaxAI Lab") as demo:
    gr.HTML(value=INFO_HTML)
    # The *_btn keyword arguments relabel the chat's clear / undo / retry buttons.
    gr.ChatInterface(
        ask_ai,
        clear_btn="🗑️ 清空",
        undo_btn="↩️ 撤回",
        retry_btn="🔄 重新生成",
        fill_height=True,
    )

if __name__ == "__main__":
    # Listen on every network interface, port 7860. share is intentionally not
    # enabled: per the original note, share=True is not supported inside HF.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,  # report Python exceptions to the browser UI
    }
    demo.launch(**launch_kwargs)