"""
Async Stock Price Predictor using Amazon Chronos T5-Small Time Series Model

Required installations:
    pip install chronos-forecasting yfinance torch numpy pandas aiohttp

Note: asyncio is part of the Python standard library; do not pip-install it.

Usage:
    python stock_predictor.py
"""
|
|
| import yfinance as yf |
| import torch |
| import numpy as np |
| from chronos import ChronosPipeline |
| import pandas as pd |
| import logging |
| import asyncio |
| import aiohttp |
| from concurrent.futures import ThreadPoolExecutor |
| from typing import Optional, Tuple, List, Dict |
| from datetime import datetime |
| import warnings |
| import time |
|
|
| |
# Route INFO-and-above records from every logger through the root handler.
logging.basicConfig(level="INFO")
# Module-scoped logger shared by every function and method in this file.
logger = logging.getLogger(name=__name__)
|
|
|
|
class AsyncStockPredictor:
    """
    An async stock price predictor using Amazon Chronos T5 time series model.

    This class fetches historical stock data asynchronously and uses the Chronos model
    to predict future stock prices and movement trends with concurrent processing.
    Blocking work (model loading, downloads, inference, indicator math) is handed to a
    ThreadPoolExecutor so the event loop itself is not blocked.

    Call ``await initialize()`` before predicting and ``await close()`` when finished.
    """

    # Minimum number of rows required before a fetch result or a prediction is accepted.
    MIN_HISTORY_POINTS = 30
    # Upper bound on the number of trailing closes fed to the model as context.
    MAX_CONTEXT_LENGTH = 512

    def __init__(self, model_name: str = "amazon/chronos-t5-small", max_workers: int = 4):
        """
        Initialize the async stock predictor; the model itself is loaded by ``initialize``.

        Args:
            model_name: Name of the Chronos model to use
            max_workers: Maximum number of worker threads for CPU-intensive tasks
        """
        self.model_name = model_name
        self.max_workers = max_workers
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.pipeline = None

    async def _run_blocking(self, func, *args):
        """Run ``func(*args)`` on this instance's thread pool and await its result."""
        # get_running_loop() is the supported way to reach the loop from inside a
        # coroutine; calling get_event_loop() there is deprecated.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)

    async def initialize(self):
        """
        Load the Chronos model on the thread pool and store it in ``self.pipeline``.

        Raises:
            Exception: Whatever model loading raised; it is logged and re-raised.
        """
        try:
            logger.info(f"Loading Chronos model: {self.model_name}")
            self.pipeline = await self._run_blocking(self._load_model)
            logger.info("Chronos model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def _load_model(self):
        """
        Load the Chronos model (CPU intensive, runs in thread pool).

        Tries ``device_map="auto"`` with bfloat16 weights first and falls back to the
        library defaults if that attempt raises.
        """
        try:
            return ChronosPipeline.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
        except Exception as e:
            logger.warning(f"Failed to load with optimized settings: {e}")
            logger.info("Attempting to load with default settings...")
            return ChronosPipeline.from_pretrained(self.model_name)

    @staticmethod
    def _flatten_columns(df: pd.DataFrame) -> pd.DataFrame:
        """
        Return ``df`` with MultiIndex columns collapsed to their first level.

        ``yf.download`` can return ``(Price, Ticker)`` MultiIndex columns even for a
        single ticker; flattening makes ``df["Close"]`` a Series rather than a
        one-column DataFrame. The input frame is not modified.
        """
        if isinstance(df.columns, pd.MultiIndex):
            df = df.copy()
            df.columns = df.columns.get_level_values(0)
        return df

    async def fetch_prices_async(self, ticker: str, period: str = "6mo",
                                 interval: str = "1d") -> Optional[pd.DataFrame]:
        """
        Fetch historical stock price data asynchronously.

        Args:
            ticker: Stock ticker symbol (e.g., 'AAPL')
            period: Time period for data (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max)
            interval: Data interval (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo)

        Returns:
            DataFrame with flat Open/High/Low/Close/Volume columns and NaN rows removed,
            or None if the download failed, was empty, or had fewer than
            MIN_HISTORY_POINTS rows.
        """
        try:
            logger.info(f"Fetching data for {ticker}")
            df = await self._run_blocking(self._fetch_data_sync, ticker, period, interval)

            if df is None or df.empty:
                logger.error(f"No data found for ticker {ticker}")
                return None

            df = self._flatten_columns(df)
            df = df[["Open", "High", "Low", "Close", "Volume"]].copy()
            df.dropna(inplace=True)

            if len(df) < self.MIN_HISTORY_POINTS:
                logger.warning(
                    f"Insufficient data for {ticker}. Got {len(df)} days, "
                    f"need at least {self.MIN_HISTORY_POINTS}"
                )
                return None

            logger.info(f"Successfully fetched {len(df)} data points for {ticker}")
            return df
        except Exception as e:
            logger.error(f"Error fetching data for {ticker}: {e}")
            return None

    def _fetch_data_sync(self, ticker: str, period: str, interval: str) -> Optional[pd.DataFrame]:
        """Synchronous data fetching (runs in thread pool); returns None on error."""
        try:
            # yfinance emits noisy warnings (e.g. about changed defaults); hide them here.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return yf.download(ticker, period=period, interval=interval, progress=False)
        except Exception as e:
            logger.error(f"Error in sync data fetch for {ticker}: {e}")
            return None

    async def predict_next_day_async(self, prices: pd.DataFrame, prediction_length: int = 1,
                                     num_samples: int = 20) -> Tuple[str, float, List[float]]:
        """
        Predict next day's price using Chronos time series model asynchronously.

        Args:
            prices: DataFrame with historical price data (must contain "Close")
            prediction_length: Number of future periods to predict
            num_samples: Number of sample predictions to generate

        Returns:
            Tuple of (trend_description, confidence_score, predicted_prices). On any
            failure the trend is an error message, confidence is 0.0 and the list is empty.
        """
        if self.pipeline is None:
            return "❌ Model not initialized", 0.0, []

        if prices is None or len(prices) < self.MIN_HISTORY_POINTS:
            return "❌ Insufficient data", 0.0, []

        try:
            return await self._run_blocking(
                self._predict_sync, prices, prediction_length, num_samples
            )
        except Exception as e:
            logger.error(f"Error during async prediction: {e}")
            return "❌ Prediction error", 0.0, []

    @staticmethod
    def _classify_trend(price_change_pct: float) -> Tuple[str, float]:
        """
        Map an expected percentage move to a trend label and a raw confidence.

        Moves beyond +/-2% are "strong" (confidence |pct|/10, capped at 0.9), moves
        beyond +/-0.5% are "moderate" (|pct|/5, capped at 0.7); anything else
        (including NaN) is sideways with confidence 0.5.
        """
        magnitude = abs(price_change_pct)
        if price_change_pct > 2.0:
            return "📈 Strong Growth Expected", min(0.9, magnitude / 10.0)
        if price_change_pct > 0.5:
            return "📈 Moderate Growth Expected", min(0.7, magnitude / 5.0)
        if price_change_pct < -2.0:
            return "📉 Strong Decline Expected", min(0.9, magnitude / 10.0)
        if price_change_pct < -0.5:
            return "📉 Moderate Decline Expected", min(0.7, magnitude / 5.0)
        return "➡️ Sideways Movement Expected", 0.5

    def _predict_sync(self, prices: pd.DataFrame, prediction_length: int,
                      num_samples: int) -> Tuple[str, float, List[float]]:
        """Synchronous prediction (runs in thread pool); see predict_next_day_async."""
        try:
            # Flatten to 1-D so a one-column "Close" frame (MultiIndex input) also works.
            closes = np.asarray(prices["Close"], dtype=np.float64).reshape(-1)
            context_length = min(len(closes), self.MAX_CONTEXT_LENGTH)
            context = closes[-context_length:]
            logger.info(f"Using {context_length} data points for prediction")

            # Shape (1, context_length): a batch containing a single series.
            context_tensor = torch.tensor(context, dtype=torch.float32).reshape(1, -1)

            with torch.no_grad():
                forecast = self.pipeline.predict(
                    context=context_tensor,
                    prediction_length=prediction_length,
                    num_samples=num_samples,
                )

            # Chronos returns [series, samples, horizon]; keep every sample for step one.
            # .float().cpu() is a no-op for float32 CPU output and guards other dtypes/devices.
            predictions = forecast[0, :, 0].float().cpu().numpy()

            mean_prediction = float(np.mean(predictions))
            std_prediction = float(np.std(predictions))

            current_price = float(closes[-1])
            price_change_pct = ((mean_prediction - current_price) / current_price) * 100

            trend, confidence = self._classify_trend(price_change_pct)

            # Widely dispersed samples mean more uncertainty: scale confidence down
            # by the relative spread, but never below 0.1.
            variance_factor = min(1.0, std_prediction / current_price)
            confidence = max(0.1, confidence * (1 - variance_factor))

            logger.info(f"Prediction: ${mean_prediction:.2f} ({price_change_pct:+.2f}%) - {trend}")
            return trend, confidence, predictions.tolist()
        except Exception as e:
            logger.error(f"Error in sync prediction: {e}", exc_info=True)
            return "❌ Prediction error", 0.0, []

    async def calculate_technical_indicators_async(self, prices: pd.DataFrame) -> dict:
        """
        Calculate basic technical indicators asynchronously.

        Args:
            prices: DataFrame with historical price data

        Returns:
            Dictionary with technical indicators, or {} on error
        """
        try:
            return await self._run_blocking(self._calculate_indicators_sync, prices)
        except Exception as e:
            logger.error(f"Error calculating technical indicators: {e}")
            return {}

    def _safe_float(self, val) -> float:
        """
        Convert a value to float, returning 0.0 for NaN.

        A Series is reduced to its only element (length 1) or its last element, which
        covers the one-column frames produced by MultiIndex price data.
        """
        if isinstance(val, pd.Series):
            val = val.iloc[0] if len(val) == 1 else val.iloc[-1]
        if pd.isna(val):
            return 0.0
        return float(val)

    def _calculate_indicators_sync(self, prices: pd.DataFrame) -> dict[str, float]:
        """
        Compute SMA-20, SMA-50, last daily % change and volume ratio (runs in thread pool).

        Values that cannot be computed (e.g. a rolling mean over too few rows, which is
        NaN) are reported as 0.0. Returns {} if anything raises.
        """
        try:
            close = prices["Close"]
            volume = prices["Volume"]

            sma_20 = self._safe_float(close.rolling(window=20).mean().iloc[-1])
            sma_50 = self._safe_float(close.rolling(window=50).mean().iloc[-1])

            current_price = self._safe_float(close.iloc[-1])
            previous_price = self._safe_float(close.iloc[-2])
            price_change = (
                ((current_price - previous_price) / previous_price) * 100
                if previous_price != 0 else 0.0
            )

            avg_volume = self._safe_float(volume.rolling(window=20).mean().iloc[-1])
            current_volume = self._safe_float(volume.iloc[-1])
            volume_ratio = current_volume / avg_volume if avg_volume != 0 else 1.0

            tech_indicators = {
                'sma_20': sma_20,
                'sma_50': sma_50,
                'price_change': price_change,
                'volume_ratio': volume_ratio,
            }
            logger.info(f"Calculated indicators: {tech_indicators}")
            return tech_indicators
        except Exception as e:
            logger.error(f"Error in sync indicator calculation: {e}", exc_info=True)
            return {}

    async def analyze_stock_async(self, ticker: str) -> str:
        """
        Perform complete stock analysis asynchronously.

        Fetches prices, then runs the prediction and indicator calculation concurrently
        and formats the combined report.

        Args:
            ticker: Stock ticker symbol

        Returns:
            Formatted analysis message (an error message on failure)
        """
        try:
            prices = await self.fetch_prices_async(ticker)
            if prices is None:
                return f"❌ Could not fetch data for {ticker}"

            (trend, confidence, predictions), indicators = await asyncio.gather(
                self.predict_next_day_async(prices),
                self.calculate_technical_indicators_async(prices),
            )

            return await self.create_analysis_message_async(
                ticker, prices, trend, confidence, predictions, indicators
            )
        except Exception as e:
            logger.error(f"Error analyzing {ticker}: {e}")
            return f"❌ Error analyzing {ticker}: {e}"

    async def create_analysis_message_async(self, ticker: str, prices: pd.DataFrame, trend: str,
                                            confidence: float,
                                            predictions: Optional[List[float]] = None,
                                            indicators: Optional[dict] = None) -> str:
        """
        Create a comprehensive analysis message asynchronously.

        Args:
            ticker: Stock ticker symbol
            prices: DataFrame with price data (DatetimeIndex, "Close" column)
            trend: Predicted trend
            confidence: Prediction confidence score in [0, 1]
            predictions: List of predicted prices (optional)
            indicators: Technical indicators dictionary (optional)

        Returns:
            Formatted analysis message
        """
        if prices is None or prices.empty:
            return f"❌ Unable to analyze {ticker} - no data available"

        try:
            last_close = self._safe_float(prices["Close"].iloc[-1])
            last_date = prices.index[-1].strftime('%Y-%m-%d')

            message_parts = [
                f"📊 **Stock Analysis: {ticker}**",
                f"📅 Date: {last_date}",
                f"💰 Current Price: ${last_close:.2f}",
                f"🔮 Prediction: {trend}",
                f"🎯 Confidence: {confidence:.1%}",
                "",
            ]

            # len() rather than truthiness so a numpy array is accepted as well as a list.
            if predictions is not None and len(predictions) > 0:
                mean_pred = np.mean(predictions)
                min_pred = np.min(predictions)
                max_pred = np.max(predictions)
                price_change = ((mean_pred - last_close) / last_close) * 100 if last_close else 0.0
                message_parts.extend([
                    "💲 **Price Predictions:**",
                    f"• Expected Price: ${mean_pred:.2f} ({price_change:+.2f}%)",
                    f"• Price Range: ${min_pred:.2f} - ${max_pred:.2f}",
                    "",
                ])

            if indicators:
                message_parts.extend([
                    "📈 **Technical Indicators:**",
                    f"• 20-day SMA: ${indicators.get('sma_20', 0):.2f}",
                    f"• 50-day SMA: ${indicators.get('sma_50', 0):.2f}",
                    f"• Daily Change: {indicators.get('price_change', 0):.2f}%",
                    f"• Volume Ratio: {indicators.get('volume_ratio', 0):.2f}x",
                    "",
                ])

            message_parts.extend([
                "⚠️ **Disclaimer:** This is AI-generated analysis, not financial advice.",
                "Predictions are based on historical patterns and may not reflect future performance.",
                "Always do your own research and consult financial advisors before investing.",
            ])

            return "\n".join(message_parts)
        except Exception as e:
            logger.error(f"Error creating message: {e}")
            return f"❌ Error creating analysis for {ticker}"

    async def analyze_multiple_stocks(self, tickers: List[str]) -> Dict[str, str]:
        """
        Analyze multiple stocks concurrently.

        Args:
            tickers: List of stock ticker symbols

        Returns:
            Dictionary mapping tickers to analysis messages; an exception raised by one
            ticker's analysis becomes that ticker's error message.
        """
        results = await asyncio.gather(
            *(self.analyze_stock_async(ticker) for ticker in tickers),
            return_exceptions=True,
        )
        return {
            ticker: (f"❌ Error analyzing {ticker}: {result}"
                     if isinstance(result, Exception) else result)
            for ticker, result in zip(tickers, results)
        }

    async def close(self):
        """Shut down the worker thread pool, waiting for in-flight jobs to finish."""
        if hasattr(self, 'executor'):
            self.executor.shutdown(wait=True)
        logger.info("AsyncStockPredictor resources cleaned up")
|
|
|
|
async def main(tickers: Optional[List[str]] = None):
    """
    Demonstrate the stock predictor.

    Loads the model, analyzes the tickers concurrently, prints each report plus
    timing statistics, and always releases the predictor's resources.

    Args:
        tickers: Symbols to analyze; defaults to AAPL, GOOGL, MSFT, TSLA, NVDA, AMD.
    """
    if tickers is None:
        tickers = ["AAPL", "GOOGL", "MSFT", "TSLA", "NVDA", "AMD"]

    predictor = AsyncStockPredictor()

    try:
        await predictor.initialize()

        print(f"\n🚀 Starting concurrent analysis of {len(tickers)} stocks...")
        # perf_counter is monotonic, so elapsed time is unaffected by wall-clock changes.
        start_time = time.perf_counter()
        results = await predictor.analyze_multiple_stocks(tickers)
        total_time = time.perf_counter() - start_time

        for ticker, analysis in results.items():
            print(f"\n{'=' * 60}")
            print(f"Analysis for {ticker}")
            print('=' * 60)
            print(analysis)

        print(f"\n🏁 Analysis completed in {total_time:.2f} seconds")
        # Guard the average so an empty ticker list does not divide by zero.
        if tickers:
            print(f"⚡ Average time per stock: {total_time / len(tickers):.2f} seconds")
    except Exception as e:
        logger.error(f"Error in main execution: {e}")
        print(f"❌ Application error: {e}")
    finally:
        await predictor.close()
|
|
|
|
def run_async_analysis():
    """
    Entry point: run ``main()`` on a fresh event loop.

    Ctrl+C and any other exception escaping ``main()`` are reported on stdout
    instead of propagating a traceback.
    """
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n👋 Analysis interrupted by user")
    except Exception as e:
        print(f"❌ Fatal error: {e}")
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script, never on import.
    run_async_analysis()
|
|