"""FinGPT chatbot demo: a Gradio Space serving a LoRA-adapted Llama-3 model."""
import gradio as gr
import spaces
import torch
from peft import PeftModel
import os
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizerFast
# Hugging Face access token (Spaces injects one automatically). Required
# because meta-llama/Meta-Llama-3-8B is a gated repository.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

model = None
tokenizer = None

try:
    base_model = "meta-llama/Meta-Llama-3-8B"
    peft_model = "FinGPT/fingpt-mt_llama3-8b_lora"

    # BUG FIX: hf_token was read above but never passed to from_pretrained,
    # so loading the gated base model would fail with an authentication error.
    tokenizer = LlamaTokenizerFast.from_pretrained(
        base_model, trust_remote_code=True, token=hf_token
    )
    # Llama has no dedicated pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    model = LlamaForCausalLM.from_pretrained(
        base_model, trust_remote_code=True, device_map="cuda:0", token=hf_token
    )
    # Attach the FinGPT LoRA adapter on top of the base weights.
    model = PeftModel.from_pretrained(model, peft_model)
    model = model.eval()  # inference mode
except Exception as e:
    print("\n" + "=" * 50)
    print("❌ 模型加载失败!")
    print(f"错误信息: {e}")
    raise
@spaces.GPU
def chat(message, history):
    """Generate a model reply for *message* given the prior chat *history*.

    Args:
        message: The user's latest input string.
        history: List of ``(user_msg, bot_msg)`` pairs from previous turns.

    Returns:
        The assistant's reply string, or an error message string on failure.
    """
    if model is None or tokenizer is None:
        return "❌ 模型未正确加载,请检查Spaces日志获取详细错误信息。"
    try:
        # Flatten the turn history into a "User:/Assistant:" transcript prompt.
        conversation = []
        for user_msg, bot_msg in history:
            conversation.append(f"User: {user_msg}")
            conversation.append(f"Assistant: {bot_msg}")
        conversation.append(f"User: {message}")
        conversation.append("Assistant:")
        prompt = "\n".join(conversation)

        # Tokenize, truncating long transcripts to the model's context budget.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # BUG FIX: decode only the newly generated tokens instead of decoding
        # the whole sequence and splitting on "Assistant:". The old approach
        # broke whenever the user's message or the reply itself contained the
        # literal marker "Assistant:".
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        return f"❌ 生成回复时出错: {str(e)}"
# Build the Gradio chat UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # BUG FIX: the banner previously advertised
    # "FinGPT/fingpt-forecaster_dow30_llama2-7b_lora", but the code actually
    # loads "FinGPT/fingpt-mt_llama3-8b_lora"; show the correct model name.
    gr.Markdown(
        """
    # 🤖 FinGPT Chatbot
    这是一个基于 **FinGPT/fingpt-mt_llama3-8b_lora** 模型的金融对话助手。
    您可以询问关于金融市场、投资、经济分析等问题。
    """
    )
    chatbot = gr.Chatbot(
        label="聊天记录",
        height=500,
        bubble_full_width=False
    )
    with gr.Row():
        msg = gr.Textbox(
            label="输入您的消息",
            placeholder="在这里输入您的问题...",
            scale=4
        )
        submit = gr.Button("发送", scale=1, variant="primary")
    clear = gr.Button("清空对话历史")
    # One-click sample questions to seed the input box.
    gr.Examples(
        examples=[
            "什么是量化宽松政策?",
            "如何评估一只股票的价值?",
            "请解释一下技术分析中的MACD指标",
            "当前市场环境下应该如何配置资产?"
        ],
        inputs=msg
    )

    # Event wiring: append the user turn immediately, then fill in the bot turn.
    def user_message(user_msg, history):
        """Clear the textbox and append the user's turn (bot slot pending)."""
        return "", history + [[user_msg, None]]

    def bot_message(history):
        """Generate the bot reply for the last user turn and fill it in."""
        user_msg = history[-1][0]
        bot_response = chat(user_msg, history[:-1])
        history[-1][1] = bot_response
        return history

    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_message, chatbot, chatbot
    )
    submit.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_message, chatbot, chatbot
    )
    # Reset the chatbot component to empty.
    clear.click(lambda: None, None, chatbot, queue=False)
# Script entry point: start the Gradio server (blocking call).
if __name__ == "__main__":
    demo.launch()