MGGroup committed on
Commit
559baca
·
verified ·
1 Parent(s): 8567504

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -52
app.py CHANGED
@@ -1,22 +1,14 @@
1
  import gradio as gr
 
2
  import os
3
  import fitz
4
  import re
5
- import torch
6
- from transformers import pipeline
7
 
8
- # --- 核心配置:不再使用 API,直接本地加载 ---
9
- # 免费 Space 内存有限,只能跑 1.5B 规模的模型但逻辑足够处理税务检索
10
- MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
11
-
12
- print("正在初始化本地模型,请稍候...")
13
- pipe = pipeline(
14
- "text-generation",
15
- model=MODEL_ID,
16
- torch_dtype=torch.float32,
17
- device_map="auto"
18
- )
19
- print("模型加载完成!")
20
 
21
  def find_treaty_context(query):
22
  base_dir = "./treaties"
@@ -43,72 +35,69 @@ def find_treaty_context(query):
43
  break
44
 
45
  if not selected_folder:
46
- return f"❓ 未识别到国家文件夹 (库内已加载 {len(country_folders)} 个国家)", ""
47
 
48
  folder_path = os.path.join(base_dir, selected_folder)
49
  context_list = []
50
  try:
51
  pdf_files = [f for f in os.listdir(folder_path) if f.lower().endswith(".pdf")]
52
- search_terms = ["利息", "股息", "预提", "所得税", "Tax", "WHT", "Rate", "Article 10", "Article 11"]
53
  for pdf_file in pdf_files:
54
  doc = fitz.open(os.path.join(folder_path, pdf_file))
55
- text = f"\n--- 来源文件: {pdf_file} ---\n"
56
- target_pages = []
57
- for page_num in range(min(50, len(doc))):
58
- page_text = doc[page_num].get_text()
59
- if any(term in page_text for term in search_terms):
60
- target_pages.append(page_num)
61
- if not target_pages: target_pages = list(range(min(5, len(doc))))
62
- for p in sorted(list(set(target_pages)))[:5]:
63
- text += doc[p].get_text()
64
  context_list.append(text)
65
  doc.close()
66
  except Exception as e:
67
- return f"❌ 深度检索失败: {str(e)}", ""
68
 
69
- return f"✅ 已自动识别并检索【{selected_folder}】资料", "\n".join(context_list)[:3000]
70
 
71
  def respond(message, history, system_message, max_tokens, temperature, top_p):
72
- # 直接用 os.environ 获取,不调用 huggingface_hub 的内部方法
73
- token = os.environ.get("HF_TOKEN")
74
- if not token:
75
  yield "❌ [配置错误] 请在 Secrets 中添加 HF_TOKEN。"
76
  return
77
 
78
- if treaty_knowledge:
79
- prompt = f"<|im_start|>system\n{system_message}\n资料库:{treaty_knowledge}<|im_end|>\n"
80
- else:
81
- prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n"
82
 
 
 
83
  for user_msg, assistant_msg in history:
84
- prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
85
-
86
- prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
 
 
 
 
 
 
 
 
 
 
87
 
88
  try:
89
- outputs = pipe(
90
- prompt,
91
- max_new_tokens=max_tokens,
92
- do_sample=True,
93
- temperature=0.01,
94
- top_p=top_p
95
- )
96
- # 提取回复内容
97
- full_text = outputs[0]['generated_text']
98
- reply = full_text.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
99
- yield f"🔍 **[检索状态]**: {status_msg}\n\n---\n\n{reply}"
100
  except Exception as e:
101
- yield f"⚠️ 本地推理异常: {str(e)}"
102
 
103
- description_text = """<div style="text-align: left;"><h3>MG TaxAI (本地版)</h3><p>此版本运行于 Space 本地 CPU,不受 API 额度限制。</p></div>"""
104
 
105
  demo = gr.ChatInterface(
106
  fn=respond,
107
- title="跨境财税合规实验室 (Beta)",
108
  description=description_text,
109
  additional_inputs=[
110
  gr.Textbox(value="你代表 MG Consult,是国际税收专家。", label="系统指令"),
111
- gr.Slider(256, 2048, 1024, label="最大字数"),
112
  gr.Slider(0, 1, 0.01, label="严谨度"),
113
  gr.Slider(0, 1, 0.95, label="采样率"),
114
  ],
 
1
  import gradio as gr
2
+ import requests
3
  import os
4
  import fitz
5
  import re
 
 
6
 
7
+ # --- 核心路由配置 ---
8
+ # 换成 Mistral-7B v0.3 版本,这个接口目前在 HF 免费网关中比 Qwen 稳一点
9
+ API_URL = "https://router.huggingface.co/v1/chat/completions"
10
+ MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
11
+ HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
 
12
 
13
  def find_treaty_context(query):
14
  base_dir = "./treaties"
 
35
  break
36
 
37
  if not selected_folder:
38
+ return f"❓ 未识别到国家 (库内已加载 {len(country_folders)} 个)", ""
39
 
40
  folder_path = os.path.join(base_dir, selected_folder)
41
  context_list = []
42
  try:
43
  pdf_files = [f for f in os.listdir(folder_path) if f.lower().endswith(".pdf")]
 
44
  for pdf_file in pdf_files:
45
  doc = fitz.open(os.path.join(folder_path, pdf_file))
46
+ text = f"\n--- 来源: {pdf_file} ---\n"
47
+ # 简化扫描,只取前 15 页防止文本过长导致 402
48
+ for page_num in range(min(15, len(doc))):
49
+ text += doc[page_num].get_text()
 
 
 
 
 
50
  context_list.append(text)
51
  doc.close()
52
  except Exception as e:
53
+ return f"❌ 检索失败: {str(e)}", ""
54
 
55
+ return f"✅ 已关联【{selected_folder}】资料", "\n".join(context_list)[:3500]
56
 
57
  def respond(message, history, system_message, max_tokens, temperature, top_p):
58
+ if not HF_TOKEN:
 
 
59
  yield "❌ [配置错误] 请在 Secrets 中添加 HF_TOKEN。"
60
  return
61
 
62
+ status_msg, treaty_knowledge = find_treaty_context(message)
 
 
 
63
 
64
+ # 构造请求消息
65
+ messages = [{"role": "system", "content": f"{system_message}\n\n参考资料:\n{treaty_knowledge}"}]
66
  for user_msg, assistant_msg in history:
67
+ if user_msg: messages.append({"role": "user", "content": user_msg})
68
+ if assistant_msg: messages.append({"role": "assistant", "content": assistant_msg})
69
+ messages.append({"role": "user", "content": message})
70
+
71
+ headers = {"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"}
72
+ payload = {
73
+ "model": MODEL_ID,
74
+ "messages": messages,
75
+ "max_tokens": max_tokens,
76
+ "temperature": 0.01,
77
+ "top_p": top_p,
78
+ "stream": False
79
+ }
80
 
81
  try:
82
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
83
+ if response.status_code == 200:
84
+ result = response.json()
85
+ reply = result['choices'][0]['message']['content']
86
+ yield f"🔍 **[状态]**: {status_msg}\n\n---\n\n{reply}"
87
+ else:
88
+ yield f"⚠️ 接口响应({response.status_code}): {response.text}"
 
 
 
 
89
  except Exception as e:
90
+ yield f"⚠️ 系统错误: {str(e)}"
91
 
92
+ description_text = """<div style="text-align: left;"><h3>MG TaxAI | 跨境财税合规实验室 (Beta)</h3></div>"""
93
 
94
  demo = gr.ChatInterface(
95
  fn=respond,
96
+ title="跨境财税合规实验室",
97
  description=description_text,
98
  additional_inputs=[
99
  gr.Textbox(value="你代表 MG Consult,是国际税收专家。", label="系统指令"),
100
+ gr.Slider(256, 4096, 2048, label="最大字数"),
101
  gr.Slider(0, 1, 0.01, label="严谨度"),
102
  gr.Slider(0, 1, 0.95, label="采样率"),
103
  ],