jolchmo committed on
Commit
0c5d74c
·
1 Parent(s): 3e90f7a
Files changed (2) hide show
  1. README.md +8 -0
  2. app.py +41 -25
README.md CHANGED
@@ -8,12 +8,20 @@ sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  # 🤖 FinGPT Chatbot
14
 
15
  这是一个基于 **FinGPT/fingpt-mt_llama3-8b_lora** 模型的金融对话助手Spaces应用。
16
 
 
 
 
 
 
 
 
17
  ## 功能特性
18
 
19
  - 💬 实时对话:支持多轮对话,保持上下文
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ hf_oauth: true
12
  ---
13
 
14
  # 🤖 FinGPT Chatbot
15
 
16
  这是一个基于 **FinGPT/fingpt-mt_llama3-8b_lora** 模型的金融对话助手Spaces应用。
17
 
18
+ ## ⚠️ 重要配置
19
+
20
+ 由于使用了Llama 3基础模型,需要在Spaces设置中配置访问权限:
21
+
22
+ 1. 确保你的HF账号已经获得 [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) 的访问权限
23
+ 2. 在Spaces的Settings中添加 `HF_TOKEN` secret(使用你的Hugging Face访问令牌)
24
+
25
  ## 功能特性
26
 
27
  - 💬 实时对话:支持多轮对话,保持上下文
app.py CHANGED
@@ -3,26 +3,41 @@ import spaces
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from peft import PeftModel
 
6
 
7
  # 加载模型和tokenizer
8
  model_name = "meta-llama/Meta-Llama-3-8B"
9
  adapter_name = "FinGPT/fingpt-mt_llama3-8b_lora"
10
 
 
 
 
11
  print("正在加载模型...")
12
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
13
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
14
 
15
- base_model = AutoModelForCausalLM.from_pretrained(
16
- model_name,
17
- torch_dtype=torch.float16,
18
- device_map="auto",
19
- trust_remote_code=True
20
- )
 
21
 
22
- model = PeftModel.from_pretrained(base_model, adapter_name)
23
- model = model.eval()
 
 
 
 
 
 
24
 
25
- print("模型加载完成!")
26
 
27
  @spaces.GPU
28
  def chat(message, history):
@@ -34,16 +49,16 @@ def chat(message, history):
34
  for user_msg, bot_msg in history:
35
  conversation.append(f"User: {user_msg}")
36
  conversation.append(f"Assistant: {bot_msg}")
37
-
38
  conversation.append(f"User: {message}")
39
  conversation.append("Assistant:")
40
-
41
  prompt = "\n".join(conversation)
42
-
43
  # 编码输入
44
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
45
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
46
-
47
  # 生成响应
48
  with torch.no_grad():
49
  outputs = model.generate(
@@ -54,16 +69,17 @@ def chat(message, history):
54
  do_sample=True,
55
  pad_token_id=tokenizer.eos_token_id
56
  )
57
-
58
  # 解码输出
59
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
60
-
61
  # 提取助手的回复
62
  if "Assistant:" in response:
63
  response = response.split("Assistant:")[-1].strip()
64
-
65
  return response
66
 
 
67
  # 创建Gradio Chatbot界面
68
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
69
  gr.Markdown(
@@ -75,13 +91,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
75
  您可以询问关于金融市场、投资、经济分析等问题。
76
  """
77
  )
78
-
79
  chatbot = gr.Chatbot(
80
  label="聊天记录",
81
  height=500,
82
  bubble_full_width=False
83
  )
84
-
85
  with gr.Row():
86
  msg = gr.Textbox(
87
  label="输入您的消息",
@@ -89,9 +105,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
89
  scale=4
90
  )
91
  submit = gr.Button("发送", scale=1, variant="primary")
92
-
93
  clear = gr.Button("清空对话历史")
94
-
95
  gr.Examples(
96
  examples=[
97
  "什么是量化宽松政策?",
@@ -101,17 +117,17 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
101
  ],
102
  inputs=msg
103
  )
104
-
105
  # 事件处理
106
  def user_message(user_msg, history):
107
  return "", history + [[user_msg, None]]
108
-
109
  def bot_message(history):
110
  user_msg = history[-1][0]
111
  bot_response = chat(user_msg, history[:-1])
112
  history[-1][1] = bot_response
113
  return history
114
-
115
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
116
  bot_message, chatbot, chatbot
117
  )
 
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from peft import PeftModel
6
+ import os
7
 
8
  # 加载模型和tokenizer
9
  model_name = "meta-llama/Meta-Llama-3-8B"
10
  adapter_name = "FinGPT/fingpt-mt_llama3-8b_lora"
11
 
12
+ # 获取HF token(需先在Spaces的Settings中添加HF_TOKEN secret,运行时通过环境变量提供)
13
+ hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
14
+
15
  print("正在加载模型...")
16
+ try:
17
+ tokenizer = AutoTokenizer.from_pretrained(
18
+ model_name,
19
+ trust_remote_code=True,
20
+ token=hf_token
21
+ )
22
+ tokenizer.pad_token = tokenizer.eos_token
23
 
24
+ base_model = AutoModelForCausalLM.from_pretrained(
25
+ model_name,
26
+ torch_dtype=torch.float16,
27
+ device_map="auto",
28
+ trust_remote_code=True,
29
+ token=hf_token
30
+ )
31
 
32
+ model = PeftModel.from_pretrained(base_model, adapter_name)
33
+ model = model.eval()
34
+
35
+ print("模型加载完成!")
36
+ except Exception as e:
37
+ print(f"模型加载错误: {e}")
38
+ print("请确保在Spaces设置中添加了HF_TOKEN")
39
+ raise
40
 
 
41
 
42
  @spaces.GPU
43
  def chat(message, history):
 
49
  for user_msg, bot_msg in history:
50
  conversation.append(f"User: {user_msg}")
51
  conversation.append(f"Assistant: {bot_msg}")
52
+
53
  conversation.append(f"User: {message}")
54
  conversation.append("Assistant:")
55
+
56
  prompt = "\n".join(conversation)
57
+
58
  # 编码输入
59
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
60
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
61
+
62
  # 生成响应
63
  with torch.no_grad():
64
  outputs = model.generate(
 
69
  do_sample=True,
70
  pad_token_id=tokenizer.eos_token_id
71
  )
72
+
73
  # 解码输出
74
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
75
+
76
  # 提取助手的回复
77
  if "Assistant:" in response:
78
  response = response.split("Assistant:")[-1].strip()
79
+
80
  return response
81
 
82
+
83
  # 创建Gradio Chatbot界面
84
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
85
  gr.Markdown(
 
91
  您可以询问关于金融市场、投资、经济分析等问题。
92
  """
93
  )
94
+
95
  chatbot = gr.Chatbot(
96
  label="聊天记录",
97
  height=500,
98
  bubble_full_width=False
99
  )
100
+
101
  with gr.Row():
102
  msg = gr.Textbox(
103
  label="输入您的消息",
 
105
  scale=4
106
  )
107
  submit = gr.Button("发送", scale=1, variant="primary")
108
+
109
  clear = gr.Button("清空对话历史")
110
+
111
  gr.Examples(
112
  examples=[
113
  "什么是量化宽松政策?",
 
117
  ],
118
  inputs=msg
119
  )
120
+
121
  # 事件处理
122
  def user_message(user_msg, history):
123
  return "", history + [[user_msg, None]]
124
+
125
  def bot_message(history):
126
  user_msg = history[-1][0]
127
  bot_response = chat(user_msg, history[:-1])
128
  history[-1][1] = bot_response
129
  return history
130
+
131
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
132
  bot_message, chatbot, chatbot
133
  )