# FinGPT / app.py
# Source: Hugging Face Space (author: jolchmo), commit 237271f ("fix"), 4.85 kB
import gradio as gr
import spaces
import torch
from peft import PeftModel
import os
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizerFast
# Hugging Face token (Spaces injects one of these env vars automatically).
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

model = None
tokenizer = None

try:
    # Base LLM plus the FinGPT multi-task LoRA adapter stacked on top.
    base_model = "meta-llama/Meta-Llama-3-8B"
    peft_model = "FinGPT/fingpt-mt_llama3-8b_lora"

    tokenizer = LlamaTokenizerFast.from_pretrained(base_model, trust_remote_code=True)
    # Llama ships without a pad token; reuse EOS so padding/truncation works.
    tokenizer.pad_token = tokenizer.eos_token

    model = LlamaForCausalLM.from_pretrained(base_model, trust_remote_code=True, device_map="cuda:0")
    model = PeftModel.from_pretrained(model, peft_model)
    model = model.eval()
except Exception as e:
    # Make load failures stand out in the Spaces log, then re-raise so the
    # Space fails fast instead of serving a half-broken app.
    print("\n" + "=" * 50)
    print("❌ 模型加载失败!")
    print(f"错误信息: {e}")
    raise
@spaces.GPU
def chat(message, history):
    """Generate a model reply for *message* given the conversation *history*.

    *history* is a list of ``[user, assistant]`` string pairs (Gradio chatbot
    format). Returns the assistant's reply, or an error string on failure.
    """
    if model is None or tokenizer is None:
        return "❌ 模型未正确加载,请检查Spaces日志获取详细错误信息。"
    try:
        # Flatten the history into a plain "User:/Assistant:" transcript,
        # ending with an open "Assistant:" turn for the model to complete.
        turns = []
        for past_user, past_bot in history:
            turns.append(f"User: {past_user}")
            turns.append(f"Assistant: {past_bot}")
        turns.append(f"User: {message}")
        turns.append("Assistant:")
        prompt = "\n".join(turns)

        # Tokenize (truncating long transcripts) and move tensors to the model device.
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

        # Sampled generation; pad with EOS since Llama has no pad token.
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        # The decoded text echoes the prompt; keep only what follows the
        # final "Assistant:" marker.
        if "Assistant:" in decoded:
            decoded = decoded.split("Assistant:")[-1].strip()
        return decoded
    except Exception as e:
        return f"❌ 生成回复时出错: {str(e)}"
# Build the Gradio chatbot UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # NOTE(fix): the description previously advertised
    # "fingpt-forecaster_dow30_llama2-7b_lora", but the app actually loads
    # the FinGPT/fingpt-mt_llama3-8b_lora adapter (see the loading code above).
    gr.Markdown(
        """
# 🤖 FinGPT Chatbot
这是一个基于 **FinGPT/fingpt-mt_llama3-8b_lora** 模型的金融对话助手。
您可以询问关于金融市场、投资、经济分析等问题。
"""
    )
    chatbot = gr.Chatbot(
        label="聊天记录",
        height=500,
        bubble_full_width=False
    )
    with gr.Row():
        msg = gr.Textbox(
            label="输入您的消息",
            placeholder="在这里输入您的问题...",
            scale=4
        )
        submit = gr.Button("发送", scale=1, variant="primary")
    clear = gr.Button("清空对话历史")
    gr.Examples(
        examples=[
            "什么是量化宽松政策?",
            "如何评估一只股票的价值?",
            "请解释一下技术分析中的MACD指标",
            "当前市场环境下应该如何配置资产?"
        ],
        inputs=msg
    )

    # --- Event handlers ---
    def user_message(user_msg, history):
        """Append the user's turn (bot reply pending) and clear the textbox."""
        return "", history + [[user_msg, None]]

    def bot_message(history):
        """Fill in the bot reply for the last (still-pending) history entry."""
        user_msg = history[-1][0]
        # Pass history *without* the pending pair so the prompt contains no
        # "Assistant: None" turn.
        bot_response = chat(user_msg, history[:-1])
        history[-1][1] = bot_response
        return history

    # Both Enter-in-textbox and the send button trigger the same two-step flow:
    # echo the user turn immediately, then stream in the bot reply.
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_message, chatbot, chatbot
    )
    submit.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_message, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()