# ==============================================================================
# app.py - GRADIO WEB UI FOR HF SPACES
# ==============================================================================
import gradio as gr
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

print("--- Setting up Model and Tokenizer for Gradio Space ---")

# --- Configuration ---
base_model_name = "Qwen/Qwen-1_8B-Chat"
adapter_model_name = "jinv2/qwen-1.8b-chat-lora-stock-quant-edu"  # YOUR HF MODEL ID
# The device is determined by the Space hardware (CPU or GPU);
# device_map="auto" handles placement.
print(f"Base model: {base_model_name}")
print(f"Adapter model: {adapter_model_name}")
# --- Load Tokenizer ---
print(f"Loading tokenizer for {base_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# Qwen uses <|endoftext|> (token id 151643) as both its pad and eos token.
qwen_pad_eos_token_id = 151643
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = qwen_pad_eos_token_id
if tokenizer.eos_token_id is None:
    tokenizer.eos_token_id = qwen_pad_eos_token_id
print("Tokenizer configured.")
# --- Load Base Model ---
print(f"Loading base model {base_model_name}...")
# torch_dtype="auto" lets transformers pick the best dtype for the hardware:
# float16 may not be ideal on CPU Spaces, while on GPU Spaces
# device_map="auto" places the model on the GPU.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",  # let transformers decide based on hardware
    device_map="auto",
    trust_remote_code=True,
)
print("Base model loaded.")
# --- Load LoRA Adapter ---
print(f"Loading LoRA adapter from {adapter_model_name}...")
try:
    # If device_map were not "auto", the base model would need to be moved
    # to the target device before applying PEFT, e.g.:
    # model = PeftModel.from_pretrained(base_model.to(device), adapter_model_name)
    model = PeftModel.from_pretrained(base_model, adapter_model_name)
    print(f"LoRA adapter '{adapter_model_name}' loaded successfully.")
except Exception as e:
    print(f"Error loading LoRA adapter: {e}. Using base model as fallback.")
    model = base_model

model = model.eval()
print("Model is ready for inference.")
# --- Define the prediction function for Gradio ---
def predict(message, chat_history_tuples):
    print(f"\nUser message: {message}")
    # Convert Gradio's chat history (a list of [user, bot] lists) to the
    # list of (query, response) tuples expected by Qwen's model.chat.
    model_chat_history = []
    if chat_history_tuples:
        for user_turn, bot_turn in chat_history_tuples:
            if user_turn is not None and bot_turn is not None:
                model_chat_history.append((user_turn, bot_turn))
            elif user_turn is not None:  # bot_turn may still be None for the latest turn
                model_chat_history.append((user_turn, ""))

    # System prompt (Chinese): "You are a professional stock-quant education
    # assistant; answer in plain, easy-to-understand language and stay
    # friendly and patient."
    system_prompt = "你是一个专业的股票量化投教助手,请用通俗易懂的大白话回答问题,并尽量保持友好和耐心。"
    print(f"Model history passed to model.chat: {model_chat_history}")
    # model.chat returns (response, history), where history is the updated
    # list of (query, response) tuples.
    response_text, updated_model_history = model.chat(
        tokenizer,
        message,
        history=model_chat_history,
        system=system_prompt,
    )
    print(f"Model response: {response_text}")

    # Gradio's Chatbot expects a list of [user_msg, bot_msg] lists, rebuilt
    # here from the updated model history.
    gradio_display_history = [[q, r] for q, r in updated_model_history]
    # Return "" to clear the textbox, plus the updated display history.
    return "", gradio_display_history
# --- Create Gradio Interface ---
print("\n--- Building Gradio Interface ---")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 股票量化投教助手 (Stock Quant Educational Assistant) 📈💡
        微调自 **{base_model_name}**,使用了 LoRA 适配器 **{adapter_model_name}**。
        (Fine-tuned from **{base_model_name}** using the LoRA adapter **{adapter_model_name}**.)
        由 **天算AI科技研发实验室 (Natural Algorithm AI R&D Lab) - jinv2** 开发。
        (Developed by the Natural Algorithm AI R&D Lab - jinv2.)
        *这是一个实验性模型,知识有限,请勿用于真实投资决策。*
        (*This is an experimental model with limited knowledge; do not use it for real investment decisions.*)
        """
    )
    # Optional: avatar_images=(user_path, bot_path) can be set once avatar
    # files are uploaded to the Space; Kaggle paths like /kaggle/input/...
    # are not available on HF Spaces.
    chatbot = gr.Chatbot(label="聊天窗口 (Chat Window)", height=550)
    msg = gr.Textbox(label="你的问题 (Your Question):", placeholder="例如:什么是移动平均线?", lines=2)  # placeholder: "e.g., What is a moving average?"
    with gr.Row():
        submit_button = gr.Button("发送 (Send)", variant="primary")
        clear_button = gr.Button("清除聊天记录 (Clear)")

    # On submit (Enter key or Send click), `predict` receives the textbox
    # value and the current chat history, and returns a cleared textbox
    # value plus the updated history.
    msg.submit(predict, [msg, chatbot], [msg, chatbot])
    submit_button.click(predict, [msg, chatbot], [msg, chatbot])
    # The clear handler must return one value per output component:
    # "" clears the textbox and None resets the chatbot.
    clear_button.click(lambda: ("", None), None, [msg, chatbot], queue=False)
    gr.Examples(
        examples=[
            "什么是K线图?",  # What is a candlestick (K-line) chart?
            "MACD指标的金叉代表什么?",  # What does a MACD golden cross indicate?
            "网格交易有什么风险?",  # What are the risks of grid trading?
            "RSI指标怎么看超买超卖?",  # How do you read overbought/oversold from RSI?
            "量化交易能赚钱吗?",  # Can quantitative trading make money?
        ],
        inputs=msg,
        label="示例问题 (Example Questions)",
    )
# --- Launch the interface ---
print("--- Launching Gradio Interface for HF Space ---")
# On HF Spaces, queue().launch() is sufficient: share=True is unnecessary
# because Spaces provides its own URL, and inbrowser=True is only useful
# for local testing.
demo.queue().launch(debug=False)  # set debug=True for more logs if needed
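
# Dependency note (an assumption based on Qwen-1.8B-Chat's published
# requirements, not verified against this Space's build log): the
# trust_remote_code path typically needs tiktoken, einops and
# transformers_stream_generator in addition to the direct imports above,
# and device_map="auto" requires accelerate. A requirements.txt for the
# Space might therefore look like:
#   gradio
#   torch
#   transformers>=4.32.0
#   accelerate
#   peft
#   tiktoken
#   einops
#   transformers_stream_generator==0.0.4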