jinv2 committed
Commit d1876d2 · verified · 1 Parent(s): bf5a721

Upload 2 files

Files changed (2)
  1. app.py +137 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,137 @@
+ # ==============================================================================
+ # app.py - GRADIO WEB UI FOR HF SPACES
+ # ==============================================================================
+ import gradio as gr
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ print("--- Setting up Model and Tokenizer for Gradio Space ---")
+
+ # --- Configuration ---
+ base_model_name = "Qwen/Qwen-1_8B-Chat"
+ adapter_model_name = "jinv2/qwen-1.8b-chat-lora-stock-quant-edu"  # YOUR HF MODEL ID
+ # The device is determined by the Space hardware (CPU or GPU);
+ # device_map="auto" will handle this.
+ print(f"Base model: {base_model_name}")
+ print(f"Adapter model: {adapter_model_name}")
+
+ # --- Load Tokenizer ---
+ print(f"Loading tokenizer for {base_model_name}...")
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
+ qwen_pad_eos_token_id = 151643
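+ # 151643 is the <|endoftext|> id in the Qwen tokenizer; it is reused here as
+ # both pad and eos token, since the base model defines no separate pad token.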
+ if tokenizer.pad_token_id is None:
+     tokenizer.pad_token_id = qwen_pad_eos_token_id
+ if tokenizer.eos_token_id is None:
+     tokenizer.eos_token_id = qwen_pad_eos_token_id
+ print("Tokenizer configured.")
+
+
+ # --- Load Base Model ---
+ print(f"Loading base model {base_model_name}...")
+ # Use torch_dtype="auto" to let transformers pick the best dtype for the hardware:
+ # on CPU Spaces float16 might not be ideal; on GPU Spaces device_map="auto"
+ # should place the model on the GPU.
+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_name,
+     torch_dtype="auto",  # let transformers decide based on hardware
+     device_map="auto",
+     trust_remote_code=True
+ )
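+ # A pinned CPU-only load would also work here (sketch; this Space relies on
+ # "auto" instead):
+ # base_model = AutoModelForCausalLM.from_pretrained(
+ #     base_model_name, torch_dtype=torch.float32, trust_remote_code=True)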
+ print("Base model loaded.")
+
+ # --- Load LoRA Adapter ---
+ print(f"Loading LoRA adapter from {adapter_model_name}...")
+ try:
+     # If device_map were not "auto", the base model would need moving to the
+     # target device before applying PEFT, e.g.:
+     # model = PeftModel.from_pretrained(base_model.to(device), adapter_model_name)
+     model = PeftModel.from_pretrained(base_model, adapter_model_name)
+     print(f"LoRA adapter '{adapter_model_name}' loaded successfully.")
+ except Exception as e:
+     print(f"Error loading LoRA adapter: {e}. Using base model as fallback.")
+     model = base_model
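+ # Optionally, the LoRA weights could be merged into the base model for
+ # slightly faster inference (assumes the adapter loaded successfully):
+ # model = model.merge_and_unload()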
+
+ model = model.eval()
+ print("Model is ready for inference.")
+
+ # --- Define the prediction function for Gradio ---
+ def predict(message, chat_history_tuples):
+     print(f"\nUser message: {message}")
+
+     # Convert Gradio's chat history (a list of [user, bot] pairs) to the
+     # list-of-tuples format that Qwen's model.chat expects.
+     model_chat_history = []
+     if chat_history_tuples:  # list of [user, bot] pairs
+         for user_turn, bot_turn in chat_history_tuples:
+             if user_turn is not None and bot_turn is not None:  # both present
+                 model_chat_history.append((user_turn, bot_turn))
+             elif user_turn is not None:  # bot_turn may be None on the first turn
+                 model_chat_history.append((user_turn, ""))
+
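+     # e.g. a Gradio history of [["q1", "a1"]] becomes [("q1", "a1")] here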
+
+     # System prompt (Chinese): "You are a professional stock-quant education
+     # assistant; answer in plain, easy-to-understand language, and stay
+     # friendly and patient."
+     system_prompt = "你是一个专业的股票量化投教助手,请用通俗易懂的大白话回答问题,并尽量保持友好和耐心。"
+
+     print(f"Model history passed to model.chat: {model_chat_history}")
+     # model.chat returns (response, history_tuples), where history_tuples is
+     # the updated list of (query, response) tuples.
+     response_text, updated_model_history = model.chat(
+         tokenizer,
+         message,
+         history=model_chat_history,
+         system=system_prompt
+     )
+     print(f"Model response: {response_text}")
+
+     # gr.Chatbot expects a list of lists: [[user_msg1, bot_msg1], ...],
+     # which we rebuild from updated_model_history.
+     gradio_display_history = []
+     for q, r in updated_model_history:
+         gradio_display_history.append([q, r])
+
+     # Return an empty string to clear the textbox, plus the updated history.
+     return "", gradio_display_history
+
+
+ # --- Create Gradio Interface ---
+ print("\n--- Building Gradio Interface ---")
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     # Markdown banner (Chinese): fine-tuned from the base model with the LoRA
+     # adapter; developed by Natural Algorithm AI R&D Lab (jinv2); experimental
+     # model, not for real investment decisions.
+     gr.Markdown(
+         f"""
+         # 股票量化投教助手 (Stock Quant Educational Assistant) 📈💡
+         微调自 **{base_model_name}**,使用了 LoRA 适配器 **{adapter_model_name}**。
+         由 **天算AI科技研发实验室 (Natural Algorithm AI R&D Lab) - jinv2** 开发。
+         *这是一个实验性模型,知识有限,请勿用于真实投资决策。*
+         """
+     )
+     # avatar_images is optional; the original /kaggle/input/... paths do not
+     # exist on a Space, so upload avatar files and reference them here instead.
+     chatbot = gr.Chatbot(label="聊天窗口", height=550)
+     msg = gr.Textbox(label="你的问题:", placeholder="例如:什么是移动平均线?", lines=2)
+
+     with gr.Row():
+         submit_button = gr.Button("发送 (Send)", variant="primary")
+         clear_button = gr.Button("清除聊天记录 (Clear)")
+
+     # On submit (Enter or the Send button), `predict` takes `msg` and
+     # `chatbot` (the history) as inputs and returns `msg` (cleared) and
+     # `chatbot` (updated) as outputs.
+     msg.submit(predict, [msg, chatbot], [msg, chatbot])
+     submit_button.click(predict, [msg, chatbot], [msg, chatbot])
+
+     # Clears the textbox and the chat history: two outputs, so two return values.
+     clear_button.click(lambda: ("", None), None, [msg, chatbot], queue=False)
+
+     # Example questions (Chinese): what a candlestick chart is; what a MACD
+     # golden cross means; risks of grid trading; reading RSI overbought /
+     # oversold levels; whether quant trading is profitable.
+     gr.Examples(
+         examples=[
+             "什么是K线图?",
+             "MACD指标的金叉代表什么?",
+             "网格交易有什么风险?",
+             "RSI指标怎么看超买超卖?",
+             "量化交易能赚钱吗?"
+         ],
+         inputs=msg,
+         label="示例问题 (Example Questions)"
+     )
+
+ # --- Launch the interface ---
+ print("--- Launching Gradio Interface for HF Space ---")
+ # For HF Spaces, queue().launch() is usually sufficient.
+ # share=True is not needed, as Spaces provides its own URL.
+ # inbrowser=True can be useful for local testing.
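+ # e.g. a local-testing variant (sketch, not used on the Space):
+ # demo.queue().launch(debug=True, inbrowser=True)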
+ demo.queue().launch(debug=False)  # debug=True for more logs if needed
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio
+ transformers>=4.40.1
+ peft>=0.9.0
+ torch>=2.0.0
+ accelerate>=0.28.0
+ sentencepiece
+ # bitsandbytes  # add if the GPU build issues are resolved and you want QLoRA; otherwise omit
+ # Qwen-specific extras, in case transformers does not handle them automatically:
+ # tiktoken  (the Qwen-1.8B tokenizer may use this, or sentencepiece)
+ # transformers_stream_generator  (Qwen-1.8B may need this for streaming)