|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
from peft import PeftModel |
|
|
import os |
|
|
from datasets import load_dataset |
|
|
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizerFast |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face auth token for gated repos (e.g. meta-llama models).
# Check the two env var names commonly set on Spaces; None if neither is set.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Populated by the model-loading section below; chat() guards on None.
model = tokenizer = None
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model loading (runs once at import/startup; failures abort the app).
# ---------------------------------------------------------------------------
try:
    base_model = "meta-llama/Meta-Llama-3-8B"
    peft_model = "FinGPT/fingpt-mt_llama3-8b_lora"

    # Meta-Llama-3 is a gated repo: pass the HF token explicitly. Previously
    # hf_token was read from the environment but never used, so loading
    # failed for accounts relying on the env var for gated access.
    tokenizer = LlamaTokenizerFast.from_pretrained(
        base_model, trust_remote_code=True, token=hf_token
    )
    # Llama 3 ships with no pad token; reuse EOS so padding during
    # tokenization/generation works.
    tokenizer.pad_token = tokenizer.eos_token

    # Load in fp16: an 8B model in the fp32 default (~32 GB) does not fit on
    # a typical single Spaces GPU, while fp16 (~16 GB) does.
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        trust_remote_code=True,
        device_map="cuda:0",
        torch_dtype=torch.float16,
        token=hf_token,
    )
    # Attach the FinGPT LoRA adapter on top of the base weights.
    model = PeftModel.from_pretrained(model, peft_model)
    model = model.eval()
except Exception as e:
    print("\n" + "=" * 50)
    print("❌ 模型加载失败!")
    print(f"错误信息: {e}")
    raise
|
|
|
|
|
|
|
|
@spaces.GPU
def chat(message, history):
    """Generate a model reply for the latest user message.

    Args:
        message: The new user message (str).
        history: Previous turns as a list of [user, bot] pairs
            (Gradio "tuples" chatbot format).

    Returns:
        The assistant's reply string, or an error string when the model is
        not loaded or generation raises.
    """
    if model is None or tokenizer is None:
        return "❌ 模型未正确加载,请检查Spaces日志获取详细错误信息。"

    try:
        # Flatten the chat history into a plain "User:/Assistant:" transcript,
        # ending with an open "Assistant:" turn for the model to complete.
        conversation = []
        for user_msg, bot_msg in history:
            conversation.append(f"User: {user_msg}")
            conversation.append(f"Assistant: {bot_msg}")
        conversation.append(f"User: {message}")
        conversation.append("Assistant:")
        prompt = "\n".join(conversation)

        # Truncate long histories to the model's usable context.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode ONLY the newly generated tokens. The previous approach
        # decoded the full sequence and split on "Assistant:", which returned
        # the wrong span whenever the user's own message contained that
        # marker.
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        # A raw (non-chat-templated) LM may keep writing the transcript and
        # start a fabricated "User:" turn — cut the reply off there.
        response = response.split("User:")[0].strip()
        return response
    except Exception as e:
        return f"❌ 生成回复时出错: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # NOTE: the banner previously advertised
    # "FinGPT/fingpt-forecaster_dow30_llama2-7b_lora", but the app actually
    # loads "FinGPT/fingpt-mt_llama3-8b_lora" — keep this text in sync with
    # the peft_model constant above.
    gr.Markdown(
        """
        # 🤖 FinGPT Chatbot

        这是一个基于 **FinGPT/fingpt-mt_llama3-8b_lora** 模型的金融对话助手。

        您可以询问关于金融市场、投资、经济分析等问题。
        """
    )

    # Conversation display (list of [user, bot] pairs).
    chatbot = gr.Chatbot(
        label="聊天记录",
        height=500,
        bubble_full_width=False
    )

    # Input row: textbox plus an explicit send button.
    with gr.Row():
        msg = gr.Textbox(
            label="输入您的消息",
            placeholder="在这里输入您的问题...",
            scale=4
        )
        submit = gr.Button("发送", scale=1, variant="primary")

    clear = gr.Button("清空对话历史")

    # One-click sample questions that populate the textbox.
    gr.Examples(
        examples=[
            "什么是量化宽松政策?",
            "如何评估一只股票的价值?",
            "请解释一下技术分析中的MACD指标",
            "当前市场环境下应该如何配置资产?"
        ],
        inputs=msg
    )
|
|
|
|
|
|
|
|
def user_message(user_msg, history): |
|
|
return "", history + [[user_msg, None]] |
|
|
|
|
|
def bot_message(history): |
|
|
user_msg = history[-1][0] |
|
|
bot_response = chat(user_msg, history[:-1]) |
|
|
history[-1][1] = bot_response |
|
|
return history |
|
|
|
|
|
msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then( |
|
|
bot_message, chatbot, chatbot |
|
|
) |
|
|
submit.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then( |
|
|
bot_message, chatbot, chatbot |
|
|
) |
|
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()
|
|
|