# ==============================================================================
# app.py - GRADIO WEB UI FOR HF SPACES
# ==============================================================================
import gradio as gr
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
print("--- Setting up Model and Tokenizer for Gradio Space ---")
# --- Configuration ---
base_model_name = "Qwen/Qwen-1_8B-Chat"
adapter_model_name = "jinv2/qwen-1.8b-chat-lora-stock-quant-edu" # YOUR HF MODEL ID
# Device will be determined by the Space hardware (CPU or GPU)
# device_map="auto" will handle this.
print(f"Base model: {base_model_name}")
print(f"Adapter model: {adapter_model_name}")
# --- Load Tokenizer ---
print(f"Loading tokenizer for {base_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
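# 151643 is the id of Qwen's "<|endoftext|>" token; the Qwen chat models use it
# for both padding and end-of-sequence, but the tokenizer does not always expose
# it as pad/eos, so set it explicitly below.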
qwen_pad_eos_token_id = 151643
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = qwen_pad_eos_token_id
if tokenizer.eos_token_id is None:
    tokenizer.eos_token_id = qwen_pad_eos_token_id
print("Tokenizer configured.")
# --- Load Base Model ---
print(f"Loading base model {base_model_name}...")
# torch_dtype="auto" loads the checkpoint in the dtype recorded in its
# config/weights; on a CPU Space float16 can be slow or unsupported, while on
# a GPU Space device_map="auto" places the model on the GPU.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
print("Base model loaded.")
# --- Load LoRA Adapter ---
print(f"Loading LoRA adapter from {adapter_model_name}...")
try:
    # device_map="auto" has already placed the base model, so the adapter can
    # be attached directly without moving the model first.
    model = PeftModel.from_pretrained(base_model, adapter_model_name)
    print(f"LoRA adapter '{adapter_model_name}' loaded successfully.")
except Exception as e:
    print(f"Error loading LoRA adapter: {e}. Falling back to the base model.")
    model = base_model
model = model.eval()
print("Model is ready for inference.")
# --- Define the prediction function for Gradio ---
def predict(message, chat_history_tuples):
    """Run one chat turn and return (cleared_textbox, updated_display_history)."""
    print(f"\nUser message: {message}")
    # Convert Gradio's chat history (a list of [user, bot] pairs) into the
    # list of (query, response) tuples that Qwen's model.chat expects.
    model_chat_history = []
    if chat_history_tuples:
        for user_turn, bot_turn in chat_history_tuples:
            if user_turn is not None and bot_turn is not None:
                model_chat_history.append((user_turn, bot_turn))
            elif user_turn is not None:  # bot_turn can be None for an in-flight turn
                model_chat_history.append((user_turn, ""))
    # System prompt (Chinese): "You are a professional stock-quant education
    # assistant. Please answer in plain, easy-to-understand language, and try
    # to stay friendly and patient."
    system_prompt = "你是一个专业的股票量化投教助手,请用通俗易懂的大白话回答问题,并尽量保持友好和耐心。"
    print(f"Model history passed to model.chat: {model_chat_history}")
    # Qwen's model.chat returns (response, history), where history is the
    # updated list of (query, response) tuples.
    response_text, updated_model_history = model.chat(
        tokenizer,
        message,
        history=model_chat_history,
        system=system_prompt,
    )
    print(f"Model response: {response_text}")
    # gr.Chatbot expects a list of [user_msg, bot_msg] pairs; rebuild it from
    # the updated model history.
    gradio_display_history = [[q, r] for q, r in updated_model_history]
    # Return "" to clear the textbox, plus the updated display history.
    return "", gradio_display_history
# --- Create Gradio Interface ---
print("\n--- Building Gradio Interface ---")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 股票量化投教助手 (Stock Quant Educational Assistant) 📈💡
        微调自 **{base_model_name}**,使用 LoRA 适配器 **{adapter_model_name}**。
        (Fine-tuned from **{base_model_name}** with the LoRA adapter **{adapter_model_name}**.)
        由 **天算AI科技研发实验室 (Natural Algorithm AI R&D Lab) - jinv2** 开发。
        *这是一个实验性模型,知识有限,请勿用于真实投资决策。*
        *(An experimental model with limited knowledge; do not use it for real investment decisions.)*
        """
    )
    # avatar_images is optional; the original Kaggle paths ("/kaggle/input/...")
    # do not exist on a Space. Upload avatar files to the Space repo and point
    # avatar_images at them, or omit the argument entirely.
    chatbot = gr.Chatbot(label="聊天窗口 (Chat Window)", height=550)
    msg = gr.Textbox(label="你的问题 (Your question):", placeholder="例如:什么是移动平均线? (e.g., what is a moving average?)", lines=2)
    with gr.Row():
        submit_button = gr.Button("发送 (Send)", variant="primary")
        clear_button = gr.Button("清除聊天记录 (Clear)")
    # Enter in the textbox and the Send button both call `predict`.
    # Inputs: the message and the current chat history; outputs: the cleared
    # textbox and the updated history.
    msg.submit(predict, [msg, chatbot], [msg, chatbot])
    submit_button.click(predict, [msg, chatbot], [msg, chatbot])
    # The clear handler must return one value per output component:
    # "" clears the textbox, None clears the chatbot.
    clear_button.click(lambda: ("", None), None, [msg, chatbot], queue=False)
    gr.Examples(
        examples=[
            "什么是K线图?",              # What is a candlestick chart?
            "MACD指标的金叉代表什么?",   # What does a MACD golden cross signal?
            "网格交易有什么风险?",       # What are the risks of grid trading?
            "RSI指标怎么看超买超卖?",    # How do you read overbought/oversold on RSI?
            "量化交易能赚钱吗?"          # Can quantitative trading make money?
        ],
        inputs=msg,
        label="示例问题 (Example Questions)"
    )
# --- Launch the interface ---
print("--- Launching Gradio Interface for HF Space ---")
# On HF Spaces, queue().launch() is sufficient: share=True is unnecessary
# because Spaces provides its own public URL, and inbrowser=True only matters
# for local testing.
demo.queue().launch(debug=False)  # set debug=True for more verbose logs