import os
import re
from datetime import datetime

import backtrader as bt
import google.generativeai as genai
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yfinance as yf
| |
|
| |
|
| | |
# Gemini client configuration (runs at import time).
MODEL_ID = "models/gemini-2.0-flash"
# Prefer a key supplied through the environment so a real secret never has to
# be written into this file; the placeholder remains the fallback so editing
# the constant by hand still works as before.
API_KEY = os.environ.get("GOOGLE_API_KEY", "ENTER API_KEY HERE")
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(MODEL_ID)
| |
|
| |
|
| |
|
| |
|
| | |
# A complete fenced block: opening fence, optional language tag, body, closing fence.
_FENCED_BLOCK_RE = re.compile(r"```[\w+-]*[ \t]*\n(.*?)```", re.DOTALL)
# A lone fence marker (with optional language tag), used when no block is complete.
_FENCE_MARKER_RE = re.compile(r"```[\w+-]*")


def _strip_markdown_fences(text: str) -> str:
    """Return *text* with Markdown code fences removed and whitespace trimmed."""
    blocks = _FENCED_BLOCK_RE.findall(text)
    if blocks:
        # Keep only the fenced content: any prose around it is not Python.
        return "\n".join(block.strip("\n") for block in blocks).strip()
    # Unterminated or inline fences: drop the markers and any language tag.
    return _FENCE_MARKER_RE.sub("", text).strip()


def get_strategy_code_from_gemini(user_input: str) -> str:
    """Ask the Gemini model for a Backtrader strategy matching *user_input*.

    The request is embedded in a fixed prompt, sent through the module-level
    ``model``, and the reply text is returned with Markdown code fences
    removed.

    Args:
        user_input: Natural-language description of the desired strategy.

    Returns:
        The generated source code as a string, stripped of code fences and
        of leading/trailing whitespace.
    """
    prompt = f"""
    You are a financial assistant that writes Backtrader strategies.


    Take this user request: "{user_input}"
    Generate a complete Backtrader strategy (Python) using bt.Strategy with these requirements:
    1. Initial capital is $100
    2. Only whole share orders (no fractional shares)
    3. Must include both buy and sell logic
    4. Strategy should have at least one indicator (like SMA, RSI, etc.)
    5. Include proper order management (check for pending orders)
    6. Include logging of trades
    7. Output only valid Python code (no explanation or markdown formatting)
    """
    response = model.generate_content(prompt)
    # Models often wrap code in fences despite requirement 7; strip them and
    # trim afterwards so no blank lines or language tags are left behind.
    return _strip_markdown_fences(response.text)
| |
|
| |
|
| | |
def create_strategy_from_code(code_string: str) -> type:
    """Execute generated source and return the strategy class it defines.

    SECURITY NOTE: this runs ``exec`` on model-generated text with full
    interpreter privileges. Only use it with output you are prepared to run
    as trusted code (or inside a sandbox).

    Args:
        code_string: Python source expected to define a ``bt.Strategy``
            subclass.

    Returns:
        The first ``bt.Strategy`` subclass defined by *code_string*.

    Raises:
        ValueError: If executing the code fails, or if it defines no
            ``bt.Strategy`` subclass.
    """
    # Execute in ONE namespace. With separate globals/locals mappings, exec
    # treats the code like a class body, so functions and methods defined by
    # the generated code cannot see its own top-level imports and helpers.
    # Working on a copy also keeps the generated code from mutating this
    # module's globals.
    preexisting = dict(globals())
    namespace = dict(preexisting)
    try:
        exec(code_string, namespace)
    except Exception as e:
        raise ValueError(f"Error creating strategy from code: {str(e)}") from e

    for name, obj in namespace.items():
        if name in preexisting and preexisting[name] is obj:
            continue  # Was already in this module; not produced by the code.
        if (
            isinstance(obj, type)
            and issubclass(obj, bt.Strategy)
            # issubclass(X, X) is True, so a re-imported base class must be
            # excluded explicitly.
            and obj is not bt.Strategy
        ):
            return obj
    raise ValueError("No valid strategy class found in Gemini output.")
| |
|
| |
|
| | |
# Matches a script-entry guard at the start of a line, with either quote
# style and flexible spacing: if __name__ == "__main__":
_MAIN_GUARD_RE = re.compile(
    r"""^[ \t]*if\s+__name__\s*==\s*(['"])__main__\1\s*:""",
    re.MULTILINE,
)


def extract_strategy_only(code_str: str) -> str:
    """Return *code_str* without its ``if __name__ == "__main__":`` section.

    Everything from the first script-entry guard onwards is dropped. Both
    single- and double-quoted guards are recognised. If no guard is present
    the whole input is returned.

    Args:
        code_str: Python source code.

    Returns:
        The source preceding the guard (or all of it), with surrounding
        whitespace stripped.
    """
    guard = _MAIN_GUARD_RE.search(code_str)
    head = code_str if guard is None else code_str[: guard.start()]
    return head.strip()
| |
|
| |
|
| | |
# Upper-case tokens that commonly appear in strategy requests but are
# indicator names or keywords rather than ticker symbols. Heuristic list.
_NON_TICKER_TOKENS = frozenset({
    "RSI", "SMA", "EMA", "WMA", "MACD", "ATR", "ADX", "CCI", "ROC", "OBV",
    "VWAP", "STOCH", "BUY", "SELL", "AND", "OR", "NOT", "USD",
})


def _extract_ticker(user_input: str, default: str = "AAPL") -> str:
    """Return the first 2-5 letter upper-case token that is not an indicator name."""
    for candidate in re.findall(r'\b[A-Z]{2,5}\b', user_input):
        if candidate not in _NON_TICKER_TOKENS:
            return candidate
    return default


def full_workflow(user_input: str):
    """Parse a strategy request, generate the strategy code, and print it.

    Args:
        user_input: Natural-language request, optionally containing a ticker
            symbol and up to two ``YYYY-MM-DD`` dates (start, then end).
    """
    # Taking the very first upper-case token would pick "RSI" from a request
    # like "Create RSI strategy for MSFT"; indicator acronyms are skipped.
    ticker = _extract_ticker(user_input)

    dates = re.findall(r'(\d{4}-\d{2}-\d{2})', user_input)
    start_date = dates[0] if len(dates) > 0 else '2022-01-01'
    end_date = dates[1] if len(dates) > 1 else '2023-01-01'
    # NOTE(review): ticker, start_date and end_date are parsed but not consumed
    # anywhere in this function; presumably meant for a data-download /
    # backtest step. Confirm before removing.

    # NOTE(review): the glyph before "Generating" looks like a mis-decoded
    # emoji; left untouched because the intended character is unknown.
    print("\n馃 Generating strategy...")
    strategy_code = get_strategy_code_from_gemini(user_input)

    # Show only the strategy definition, without any script-entry section.
    print(extract_strategy_only(strategy_code))
| |
|
| |
|
| | |
# Script entry point: run the workflow once on a sample request.
if __name__ == "__main__":
    user_input = (
        "Create RSI strategy for MSFT, buy below 30 sell above 70, "
        "from 2021-01-01 to 2022-12-31"
    )
    full_workflow(user_input)
| |
|
# Alias: exposes the Gemini strategy generator under the name `generate_answer`.
generate_answer = get_strategy_code_from_gemini
| |
|