init
Browse files- Dockerfile +35 -0
- README.md +129 -11
- README_spaces.md +26 -0
- app.py +11 -0
- main.py +268 -0
- models/data_processor.py +178 -0
- models/forecast_models.py +586 -0
- requirements.txt +23 -0
- run.py +53 -0
- test_api.py +125 -0
- train_catboost.py +316 -0
- utils/config.py +45 -0
- utils/logger.py +31 -0
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for AgriPredict Analysis Service
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies.
# gcc/g++ are needed to build wheels for scientific packages;
# curl is required by the HEALTHCHECK below and is NOT present in the
# slim base image — without it the health check always fails.
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better layer caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create non-root user and hand over ownership of the app directory
RUN useradd --create-home --shell /bin/bash app \
    && chown -R app:app /app
USER app

# main.py binds to $PORT (defaulting to 8000); set it so the app actually
# listens on the port that is exposed and health-checked.
ENV PORT=7860

# Expose port (Hugging Face Spaces convention)
EXPOSE 7860

# Health check against the FastAPI /health endpoint
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# Start the application
CMD ["python", "main.py"]
|
README.md
CHANGED
|
@@ -1,11 +1,129 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AgriPredict Analysis Service
|
| 2 |
+
|
| 3 |
+
A FastAPI-based service for advanced agricultural demand forecasting using multiple ML models including ensemble methods, statistical models, and machine learning algorithms.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **Multi-Model Forecasting**: Ensemble, ARIMA, Exponential Smoothing, CatBoost, and more
|
| 8 |
+
- **Scenario Planning**: Optimistic, pessimistic, and realistic forecast scenarios
|
| 9 |
+
- **Confidence Intervals**: Uncertainty quantification for all predictions
|
| 10 |
+
- **Revenue Projections**: Automatic revenue forecasting based on demand predictions
|
| 11 |
+
- **Real-time Processing**: Asynchronous processing for high performance
|
| 12 |
+
- **RESTful API**: Clean, documented API endpoints
|
| 13 |
+
|
| 14 |
+
## API Endpoints
|
| 15 |
+
|
| 16 |
+
### Health Check
|
| 17 |
+
```
|
| 18 |
+
GET /health
|
| 19 |
+
```
|
| 20 |
+
Returns service health status and version information.
|
| 21 |
+
|
| 22 |
+
### Generate Forecast
|
| 23 |
+
```
|
| 24 |
+
POST /forecast
|
| 25 |
+
```
|
| 26 |
+
Generate demand forecast using specified models and parameters.
|
| 27 |
+
|
| 28 |
+
**Request Body:**
|
| 29 |
+
```json
|
| 30 |
+
{
|
| 31 |
+
"product_id": "string",
|
| 32 |
+
"historical_data": [
|
| 33 |
+
{
|
| 34 |
+
"date": "2024-01-01",
|
| 35 |
+
"quantity": 100.0,
|
| 36 |
+
"price": 25.0
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"days": 30,
|
| 40 |
+
"selling_price": 25.0,
|
| 41 |
+
"models": ["ensemble"],
|
| 42 |
+
"include_confidence": true,
|
| 43 |
+
"scenario": "realistic"
|
| 44 |
+
}
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### List Models
|
| 48 |
+
```
|
| 49 |
+
GET /models
|
| 50 |
+
```
|
| 51 |
+
Returns list of available forecasting models.
|
| 52 |
+
|
| 53 |
+
## Models Available
|
| 54 |
+
|
| 55 |
+
1. **Ensemble** - Combines multiple models for best accuracy
|
| 56 |
+
2. **SMA** - Simple Moving Average (basic trend analysis)
|
| 57 |
+
3. **WMA** - Weighted Moving Average (recent data weighted more)
|
| 58 |
+
4. **ES** - Exponential Smoothing (seasonal trend analysis)
|
| 59 |
+
5. **ARIMA** - Statistical time series model
|
| 60 |
+
6. **CatBoost** - Machine learning model (ready for training)
|
| 61 |
+
|
| 62 |
+
## Usage
|
| 63 |
+
|
| 64 |
+
### Local Development
|
| 65 |
+
|
| 66 |
+
1. Install dependencies:
|
| 67 |
+
```bash
|
| 68 |
+
pip install -r requirements.txt
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
2. Run the service:
|
| 72 |
+
```bash
|
| 73 |
+
python main.py
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
The API will be available at `http://localhost:8000` (or whatever port is set in the `PORT` environment variable).
|
| 77 |
+
|
| 78 |
+
### API Documentation
|
| 79 |
+
|
| 80 |
+
Once running, visit `http://localhost:8000/docs` for interactive API documentation.
|
| 81 |
+
|
| 82 |
+
## Deployment
|
| 83 |
+
|
| 84 |
+
This service is designed to run on Hugging Face Spaces with the following configuration:
|
| 85 |
+
|
| 86 |
+
- **Runtime**: Python 3.10+
|
| 87 |
+
- **Framework**: FastAPI
|
| 88 |
+
- **GPU**: Not required (CPU-only ML models)
|
| 89 |
+
- **Memory**: 2GB minimum recommended
|
| 90 |
+
|
| 91 |
+
## Training the CatBoost Model
|
| 92 |
+
|
| 93 |
+
The CatBoost model is currently using a placeholder implementation. To train it with real data:
|
| 94 |
+
|
| 95 |
+
1. Prepare your training dataset with features like:
|
| 96 |
+
- Historical prices and quantities
|
| 97 |
+
- Date-based features (day of week, month, etc.)
|
| 98 |
+
- Lag features (previous days' data)
|
| 99 |
+
- Rolling statistics
|
| 100 |
+
|
| 101 |
+
2. Train the model using the prepared dataset
|
| 102 |
+
|
| 103 |
+
3. Replace the placeholder implementation in `models/forecast_models.py`
|
| 104 |
+
|
| 105 |
+
## Architecture
|
| 106 |
+
|
| 107 |
+
```
|
| 108 |
+
analysis-service/
|
| 109 |
+
├── main.py # FastAPI application
|
| 110 |
+
├── models/
|
| 111 |
+
│ ├── forecast_models.py # Forecasting algorithms
|
| 112 |
+
│ └── data_processor.py # Data processing utilities
|
| 113 |
+
├── utils/
|
| 114 |
+
│ ├── config.py # Configuration settings
|
| 115 |
+
│ └── logger.py # Logging setup
|
| 116 |
+
└── requirements.txt # Python dependencies
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
## Contributing
|
| 120 |
+
|
| 121 |
+
1. Fork the repository
|
| 122 |
+
2. Create a feature branch
|
| 123 |
+
3. Make your changes
|
| 124 |
+
4. Add tests if applicable
|
| 125 |
+
5. Submit a pull request
|
| 126 |
+
|
| 127 |
+
## License
|
| 128 |
+
|
| 129 |
+
MIT License - see LICENSE file for details.
|
README_spaces.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
title: AgriPredict Analysis Service
|
| 2 |
+
emoji: 🌾
|
| 3 |
+
colorFrom: green
|
| 4 |
+
colorTo: blue
|
| 5 |
+
sdk: docker
sdk_version: null
|
| 7 |
+
app_file: main.py
|
| 8 |
+
pinned: false
|
| 9 |
+
|
| 10 |
+
# Hugging Face Spaces Configuration for AgriPredict Analysis Service
|
| 11 |
+
# This service provides advanced agricultural demand forecasting
|
| 12 |
+
|
| 13 |
+
# Python version requirement
|
| 14 |
+
python_version: "3.10"
|
| 15 |
+
|
| 16 |
+
# Build configuration
|
| 17 |
+
build:
|
| 18 |
+
python_version: "3.10"
|
| 19 |
+
|
| 20 |
+
# Environment variables
|
| 21 |
+
env:
|
| 22 |
+
PORT: 7860
|
| 23 |
+
PYTHONPATH: /app
|
| 24 |
+
|
| 25 |
+
# Startup command
|
| 26 |
+
start_command: "python main.py"
|
app.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: AgriPredict Analysis Service
|
| 3 |
+
emoji: 🌾
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
sdk_version: null
|
| 8 |
+
app_file: main.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
main.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AgriPredict Analysis Service
|
| 3 |
+
A FastAPI-based service for agricultural demand forecasting using multiple ML models.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
import os
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Any, Optional

import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Import our custom modules
from models.forecast_models import ForecastEngine
from models.data_processor import DataProcessor
from utils.config import settings
from utils.logger import setup_logger
|
| 23 |
+
|
| 24 |
+
# Setup logging
|
| 25 |
+
logger = setup_logger(__name__)
|
| 26 |
+
|
| 27 |
+
# Lifespan context manager for startup/shutdown events
|
| 28 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Logs one message when the service starts and another once it shuts
    down cleanly; no other resources are acquired or released here.
    """
    logger.info("Starting AgriPredict Analysis Service")
    yield
    logger.info("Shutting down AgriPredict Analysis Service")
|
| 35 |
+
|
| 36 |
+
# Create FastAPI app
|
| 37 |
+
# Create FastAPI app
app = FastAPI(
    title="AgriPredict Analysis Service",
    description="Advanced agricultural demand forecasting using ensemble ML models",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware for Next.js integration.
# NOTE: Starlette's CORSMiddleware does not expand wildcards inside
# allow_origins entries, so "https://*.huggingface.co" would never match any
# real subdomain; subdomain matching must go through allow_origin_regex.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:3001",
        "https://huggingface.co",
        # Deploy-time origin. The "*" fallback combined with
        # allow_credentials=True is rejected by browsers for credentialed
        # requests — set FRONTEND_URL explicitly in production.
        os.getenv("FRONTEND_URL", "*")
    ],
    allow_origin_regex=r"https://.*\.huggingface\.co",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 58 |
+
|
| 59 |
+
# Data Models
|
| 60 |
+
class DemandData(BaseModel):
    """A single historical demand observation for one product on one date."""
    date: str = Field(..., description="ISO date string")
    quantity: float = Field(..., gt=0, description="Demand quantity")
    price: float = Field(..., gt=0, description="Price per unit")
|
| 64 |
+
|
| 65 |
+
class ForecastRequest(BaseModel):
    """Request payload for POST /forecast.

    Carries the product's historical demand series plus options controlling
    which models run, the forecast horizon, and the applied scenario.
    """
    product_id: str = Field(..., description="Product identifier")
    historical_data: List[DemandData] = Field(..., min_items=3, description="Historical demand data")
    days: int = Field(..., ge=1, le=365, description="Forecast horizon in days")
    selling_price: Optional[float] = Field(None, gt=0, description="Selling price for revenue calculation")
    # NOTE(review): date_from/date_to are accepted but never referenced by the
    # /forecast handler in this file — presumably reserved for future
    # filtering; confirm before relying on them.
    date_from: Optional[str] = Field(None, description="Start date for historical data filter")
    date_to: Optional[str] = Field(None, description="End date for historical data filter")
    models: Optional[List[str]] = Field(["ensemble"], description="Models to use for forecasting")
    include_confidence: Optional[bool] = Field(True, description="Include confidence intervals")
    scenario: Optional[str] = Field("realistic", description="Forecast scenario")
|
| 75 |
+
|
| 76 |
+
class ForecastDataPoint(BaseModel):
    """One forecasted point in the returned series."""
    date: str = Field(..., description="Forecast date")
    predicted_value: float = Field(..., description="Predicted demand/price")
    confidence_lower: Optional[float] = Field(None, description="Lower confidence bound")
    confidence_upper: Optional[float] = Field(None, description="Upper confidence bound")
    model_used: Optional[str] = Field(None, description="Model that generated this prediction")
|
| 82 |
+
|
| 83 |
+
class RevenueProjection(BaseModel):
    """Projected revenue for a single forecast date."""
    date: str = Field(..., description="Projection date")
    projected_quantity: float = Field(..., description="Projected quantity")
    selling_price: float = Field(..., description="Selling price")
    projected_revenue: float = Field(..., description="Projected revenue")
    confidence_lower: Optional[float] = Field(None, description="Lower revenue confidence")
    confidence_upper: Optional[float] = Field(None, description="Upper revenue confidence")
|
| 90 |
+
|
| 91 |
+
class ForecastResponse(BaseModel):
    """Response body for POST /forecast."""
    forecast_data: List[ForecastDataPoint] = Field(..., description="Forecast data points")
    revenue_projection: Optional[List[RevenueProjection]] = Field(None, description="Revenue projections")
    models_used: List[str] = Field(..., description="Models used in forecasting")
    summary: str = Field(..., description="AI-generated summary in Markdown")
    confidence: Optional[float] = Field(None, description="Overall forecast confidence")
    scenario: Optional[str] = Field(None, description="Applied scenario")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
| 99 |
+
|
| 100 |
+
# Dependency injection
|
| 101 |
+
def get_forecast_engine() -> ForecastEngine:
    """FastAPI dependency: provide a fresh ForecastEngine per request."""
    engine = ForecastEngine()
    return engine
|
| 104 |
+
|
| 105 |
+
def get_data_processor() -> DataProcessor:
    """FastAPI dependency: provide a fresh DataProcessor per request."""
    processor = DataProcessor()
    return processor
|
| 108 |
+
|
| 109 |
+
# API Endpoints
|
| 110 |
+
# API Endpoints
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns service status, name, a timezone-aware UTC timestamp and the
    service version. Probed by the container HEALTHCHECK.
    """
    # datetime.utcnow() is deprecated (returns a naive datetime);
    # use an explicit UTC zone instead.
    return {
        "status": "healthy",
        "service": "analysis-service",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "version": "1.0.0"
    }
|
| 119 |
+
|
| 120 |
+
@app.post("/forecast", response_model=ForecastResponse)
async def generate_forecast(
    request: ForecastRequest,
    forecast_engine: ForecastEngine = Depends(get_forecast_engine),
    data_processor: DataProcessor = Depends(get_data_processor)
):
    """
    Generate demand forecast using ensemble ML models.

    Pipeline: clean/validate the historical series, run the requested
    models, optionally price the forecast into a revenue projection, then
    attach a textual summary, an overall confidence score and metadata.

    Raises:
        HTTPException 400: fewer than 3 usable historical data points.
        HTTPException 500: any unexpected failure during forecasting.
    """
    try:
        logger.info(f"Generating forecast for product {request.product_id}")

        # Process and validate data
        df = data_processor.process_historical_data(request.historical_data)

        if len(df) < 3:
            raise HTTPException(
                status_code=400,
                detail="Insufficient historical data. Need at least 3 data points."
            )

        # Generate forecast
        forecast_result = await forecast_engine.generate_forecast(
            df=df,
            days=request.days,
            models=request.models or ["ensemble"],
            include_confidence=request.include_confidence,
            scenario=request.scenario
        )

        # Calculate revenue projection if selling price provided
        revenue_projection = None
        if request.selling_price and request.selling_price > 0:
            revenue_projection = forecast_engine.calculate_revenue_projection(
                forecast_data=forecast_result["forecast_data"],
                selling_price=request.selling_price,
                historical_data=df
            )

        # Generate AI summary
        summary = forecast_engine.generate_summary(
            forecast_data=forecast_result["forecast_data"],
            historical_data=df,
            models_used=forecast_result["models_used"],
            scenario=request.scenario
        )

        # Calculate overall confidence
        confidence = forecast_engine.calculate_overall_confidence(
            forecast_data=forecast_result["forecast_data"]
        )

        # Prepare metadata
        metadata = {
            "data_points": len(df),
            "forecast_horizon": request.days,
            "product_id": request.product_id,
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "scenario": request.scenario
        }

        response = ForecastResponse(
            forecast_data=forecast_result["forecast_data"],
            revenue_projection=revenue_projection,
            models_used=forecast_result["models_used"],
            summary=summary,
            confidence=confidence,
            scenario=request.scenario,
            metadata=metadata
        )

        logger.info(f"Successfully generated forecast for product {request.product_id}")
        return response

    except HTTPException:
        # Re-raise unchanged: the broad handler below would otherwise
        # convert the intentional 400 above into a misleading 500.
        raise
    except Exception as e:
        logger.error(f"Forecast generation failed: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Forecast generation failed: {str(e)}"
        )
|
| 200 |
+
|
| 201 |
+
@app.get("/models")
async def list_available_models():
    """List all available forecasting models"""
    # (id, display name, description, category) — expanded into the
    # response schema below; order is the order shown to clients.
    catalog = [
        ("ensemble", "Ensemble (Recommended)", "Combines multiple models for best accuracy", "ensemble"),
        ("sma", "Simple Moving Average", "Basic trend analysis", "statistical"),
        ("wma", "Weighted Moving Average", "Recent data weighted more", "statistical"),
        ("es", "Exponential Smoothing", "Seasonal trend analysis", "statistical"),
        ("arima", "ARIMA", "Statistical time series model", "statistical"),
        ("catboost", "CatBoost", "Machine learning model", "ml"),
    ]
    return {
        "models": [
            {"id": model_id, "name": name, "description": desc, "type": kind}
            for model_id, name, desc, kind in catalog
        ]
    }
|
| 244 |
+
|
| 245 |
+
# Error handlers
|
| 246 |
+
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Serialize HTTPExceptions into plain {"detail": ...} JSON bodies."""
    payload = {"detail": exc.detail}
    return JSONResponse(status_code=exc.status_code, content=payload)
|
| 252 |
+
|
| 253 |
+
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Last-resort handler: log the full traceback, return an opaque 500.

    Clients never see the exception detail (avoids leaking internals);
    diagnostics go to the service log instead.
    """
    # exc_info keeps the traceback in the log record — str(exc) alone loses it.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error"}
    )
|
| 260 |
+
|
| 261 |
+
if __name__ == "__main__":
    import uvicorn

    # reload=True spawns a file-watcher and is for local development only;
    # the Dockerfile uses this entry point in production, so auto-reload is
    # opt-in via the RELOAD environment variable instead of hard-coded.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8000)),
        reload=os.getenv("RELOAD", "false").lower() in ("1", "true", "yes")
    )
|
models/data_processor.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data processing utilities for AgriPredict Analysis Service
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
from utils.logger import setup_logger
|
| 10 |
+
from utils.config import settings
|
| 11 |
+
|
| 12 |
+
logger = setup_logger(__name__)
|
| 13 |
+
|
| 14 |
+
class DataProcessor:
    """Handles data processing and validation for forecasting."""

    def __init__(self):
        # Module-level logger (configured in utils.logger), shared by instances.
        self.logger = logger

    def process_historical_data(self, historical_data: List[Any]) -> pd.DataFrame:
        """
        Process and validate historical demand data.

        Args:
            historical_data: List of demand data points — plain dicts or
                Pydantic models exposing .dict() (the /forecast endpoint
                passes a List[DemandData]).

        Returns:
            Cleaned DataFrame with date/quantity/price columns, sorted by
            date, deduplicated, capped at settings.MAX_DATA_POINTS rows.

        Raises:
            ValueError: if a required column is missing after conversion.
        """
        try:
            self.logger.info(f"Processing {len(historical_data)} historical data points")

            # pd.DataFrame only builds named columns from mappings; Pydantic
            # model objects must be converted to dicts first or the required-
            # column check below would reject valid input.
            records = [
                item.dict() if hasattr(item, "dict") else item
                for item in historical_data
            ]
            df = pd.DataFrame(records)

            # Validate required columns
            required_columns = ['date', 'quantity', 'price']
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                raise ValueError(f"Missing required columns: {missing_columns}")

            # Convert date column
            df['date'] = pd.to_datetime(df['date'])

            # Coerce to numeric; unparsable values become NaN and are dropped below.
            df['quantity'] = pd.to_numeric(df['quantity'], errors='coerce')
            df['price'] = pd.to_numeric(df['price'], errors='coerce')

            # Remove invalid data (missing or non-positive values)
            df = df.dropna(subset=['quantity', 'price'])
            df = df[df['quantity'] > 0]
            df = df[df['price'] > 0]

            # Sort chronologically; keep the latest record for duplicate dates
            df = df.sort_values('date').reset_index(drop=True)
            df = df.drop_duplicates(subset=['date'], keep='last')

            # Keep only the most recent MAX_DATA_POINTS rows
            if len(df) > settings.MAX_DATA_POINTS:
                self.logger.warning(f"Limiting data from {len(df)} to {settings.MAX_DATA_POINTS} points")
                df = df.tail(settings.MAX_DATA_POINTS)

            self.logger.info(f"Successfully processed {len(df)} data points")
            return df

        except Exception as e:
            self.logger.error(f"Data processing failed: {str(e)}")
            raise

    def validate_data_quality(self, df: pd.DataFrame) -> Dict[str, Any]:
        """
        Validate data quality and return metrics.

        Args:
            df: Processed DataFrame

        Returns:
            Dictionary with quality metrics; empty dict if computation fails.
        """
        try:
            quality_metrics = {
                'total_points': len(df),
                'date_range': {
                    'start': df['date'].min().isoformat() if len(df) > 0 else None,
                    'end': df['date'].max().isoformat() if len(df) > 0 else None
                },
                'missing_values': {
                    'quantity': df['quantity'].isnull().sum(),
                    'price': df['price'].isnull().sum()
                },
                'outliers': {
                    'quantity': self._detect_outliers(df['quantity']),
                    'price': self._detect_outliers(df['price'])
                },
                'data_completeness': self._calculate_completeness(df)
            }

            return quality_metrics

        except Exception as e:
            self.logger.error(f"Quality validation failed: {str(e)}")
            return {}

    def _detect_outliers(self, series: pd.Series) -> int:
        """Count outliers via the 1.5*IQR rule; 0 if it cannot be computed."""
        try:
            q1 = series.quantile(0.25)
            q3 = series.quantile(0.75)
            iqr = q3 - q1
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr

            outliers = ((series < lower_bound) | (series > upper_bound)).sum()
            return int(outliers)
        # Narrowed from bare `except:` so SystemExit/KeyboardInterrupt pass through.
        except Exception:
            return 0

    def _calculate_completeness(self, df: pd.DataFrame) -> float:
        """Percentage of non-null quantity/price cells; 0.0 for empty input."""
        try:
            total_cells = len(df) * 2  # quantity and price columns
            if total_cells == 0:
                # Empty frame previously hit ZeroDivisionError, silently
                # swallowed by a bare except; handle it explicitly.
                return 0.0
            missing_cells = df[['quantity', 'price']].isnull().sum().sum()
            completeness = ((total_cells - missing_cells) / total_cells) * 100
            return round(completeness, 2)
        except Exception:
            return 0.0

    def prepare_features_for_ml(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare features for machine learning models.

        Args:
            df: Processed DataFrame (date/quantity/price columns)

        Returns:
            DataFrame with date-part, lag, rolling and price-change features;
            the original df is returned unchanged if engineering fails.
        """
        try:
            feature_df = df.copy()

            # Date-based features
            feature_df['day_of_week'] = feature_df['date'].dt.dayofweek
            feature_df['month'] = feature_df['date'].dt.month
            feature_df['day_of_month'] = feature_df['date'].dt.day
            feature_df['quarter'] = feature_df['date'].dt.quarter

            # Lag features (only when enough history exists for the lag)
            for lag in [1, 7, 14, 30]:
                if len(feature_df) > lag:
                    feature_df[f'price_lag_{lag}'] = feature_df['price'].shift(lag)
                    feature_df[f'quantity_lag_{lag}'] = feature_df['quantity'].shift(lag)

            # Rolling statistics
            for window in [7, 14, 30]:
                if len(feature_df) > window:
                    feature_df[f'price_rolling_mean_{window}'] = feature_df['price'].rolling(window).mean()
                    feature_df[f'price_rolling_std_{window}'] = feature_df['price'].rolling(window).std()
                    feature_df[f'quantity_rolling_mean_{window}'] = feature_df['quantity'].rolling(window).mean()

            # Price change features
            feature_df['price_change'] = feature_df['price'].pct_change()
            feature_df['price_change_7d'] = feature_df['price'].pct_change(7)

            # Volume-weighted features
            feature_df['value'] = feature_df['quantity'] * feature_df['price']

            # Drop rows with NaN values created by lag/rolling features
            feature_df = feature_df.dropna()

            self.logger.info(f"Created {len(feature_df.columns) - len(df.columns)} additional features")
            return feature_df

        except Exception as e:
            self.logger.error(f"Feature engineering failed: {str(e)}")
            return df
|
models/forecast_models.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Forecasting models for AgriPredict Analysis Service
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from typing import List, Dict, Any, Optional
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
import asyncio
|
| 11 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 12 |
+
import traceback
|
| 13 |
+
|
| 14 |
+
# Import ML libraries (will be available when deployed).
# Fix: probe statsmodels and catboost independently. The original single
# try block set BOTH flags to False whenever EITHER import failed, so a
# missing catboost silently disabled the statsmodels-based models too.
try:
    from statsmodels.tsa.holtwinters import ExponentialSmoothing
    from statsmodels.tsa.arima.model import ARIMA
    STATS_MODELS_AVAILABLE = True
except ImportError:
    STATS_MODELS_AVAILABLE = False

try:
    from catboost import CatBoostRegressor
    CATBOOST_AVAILABLE = True
except ImportError:
    CATBOOST_AVAILABLE = False
|
| 24 |
+
|
| 25 |
+
from utils.logger import setup_logger
|
| 26 |
+
from utils.config import settings
|
| 27 |
+
|
| 28 |
+
logger = setup_logger(__name__)
|
| 29 |
+
|
| 30 |
+
@dataclass
class ForecastResult:
    """Container for the forecast produced by a single model."""
    # Predicted price for each forecast day (length == forecast horizon).
    values: List[float]
    # Optional per-day lower confidence bound (same length as `values`).
    confidence_lower: Optional[List[float]] = None
    # Optional per-day upper confidence bound (same length as `values`).
    confidence_upper: Optional[List[float]] = None
    # Short identifier of the producing model, e.g. "SMA", "Ensemble".
    model_name: str = ""
|
| 37 |
+
|
| 38 |
+
class ForecastEngine:
    """Main forecasting engine with multiple models"""

    def __init__(self):
        # Reuse the module-level logger for all engine instances.
        self.logger = logger
        # Thread pool used to run blocking model fits off the asyncio event loop.
        self.executor = ThreadPoolExecutor(max_workers=4)
|
| 44 |
+
|
| 45 |
+
async def generate_forecast(
    self,
    df: pd.DataFrame,
    days: int,
    models: List[str],
    include_confidence: bool = True,
    scenario: str = "realistic"
) -> Dict[str, Any]:
    """
    Generate forecast using specified models.

    Args:
        df: Historical data DataFrame (must contain a 'price' column)
        days: Number of days to forecast
        models: List of model names to use (e.g. 'sma', 'wma', 'ensemble')
        include_confidence: Whether to include confidence intervals
        scenario: Forecast scenario (optimistic, pessimistic, realistic)

    Returns:
        Dictionary with 'forecast_data', 'models_used' and 'scenario' keys.

    Raises:
        Exception: any unexpected failure is logged and re-raised.
    """
    try:
        self.logger.info(f"Generating {days}-day forecast using models: {models}")

        # Apply scenario adjustment on a copy so the caller's frame is untouched.
        scenario_multiplier = self._get_scenario_multiplier(scenario)
        df = df.copy()
        df['price'] = df['price'] * scenario_multiplier

        # Run each requested model in the thread pool: model fitting is
        # blocking work and must not run on the event loop itself.
        forecast_tasks = []
        model_results = {}

        # Fix: use get_running_loop() instead of the deprecated
        # asyncio.get_event_loop(), which is unreliable inside coroutines
        # on Python 3.10+.
        loop = asyncio.get_running_loop()

        for model_name in models:
            if model_name.lower() == 'ensemble':
                # Ensemble is derived below from the other models' results.
                continue
            elif hasattr(self, f'_generate_{model_name.lower()}_forecast'):
                task = loop.run_in_executor(
                    self.executor,
                    getattr(self, f'_generate_{model_name.lower()}_forecast'),
                    df.copy(),
                    days,
                    include_confidence
                )
                forecast_tasks.append((model_name, task))

        # Wait for all model forecasts. Individual model failures are
        # logged and skipped rather than failing the whole request.
        if forecast_tasks:
            results = await asyncio.gather(*[task for _, task in forecast_tasks], return_exceptions=True)

            for (model_name, _), result in zip(forecast_tasks, results):
                if isinstance(result, Exception):
                    self.logger.warning(f"Model {model_name} failed: {str(result)}")
                    continue

                if result and result.values:
                    model_results[model_name] = result

        # If no models succeeded, use fallback
        if not model_results:
            self.logger.warning("All models failed, using fallback forecast")
            fallback_result = self._generate_fallback_forecast(df, days)
            model_results['Fallback'] = fallback_result

        # Generate ensemble forecast if requested
        if 'ensemble' in [m.lower() for m in models]:
            ensemble_result = self._generate_ensemble_forecast(model_results, days, include_confidence)
            model_results['Ensemble'] = ensemble_result

        # Prepare final forecast data
        final_forecast = self._prepare_forecast_data(model_results, df, days)

        return {
            "forecast_data": final_forecast,
            "models_used": list(model_results.keys()),
            "scenario": scenario
        }

    except Exception as e:
        self.logger.error(f"Forecast generation failed: {str(e)}")
        raise
|
| 127 |
+
|
| 128 |
+
def _get_scenario_multiplier(self, scenario: str) -> float:
|
| 129 |
+
"""Get multiplier for scenario adjustment"""
|
| 130 |
+
multipliers = {
|
| 131 |
+
'optimistic': 1.1, # 10% increase
|
| 132 |
+
'pessimistic': 0.9, # 10% decrease
|
| 133 |
+
'realistic': 1.0 # No change
|
| 134 |
+
}
|
| 135 |
+
return multipliers.get(scenario.lower(), 1.0)
|
| 136 |
+
|
| 137 |
+
def _generate_sma_forecast(
    self,
    df: pd.DataFrame,
    days: int,
    include_confidence: bool = True
) -> ForecastResult:
    """Simple Moving Average forecast: repeat the latest 7-day mean price.

    Raises:
        ValueError: if fewer than 7 historical rows are available.
    """
    try:
        if len(df) < 7:
            raise ValueError("Insufficient data for SMA")

        window = min(7, len(df))
        level = df['price'].rolling(window=window).mean().iloc[-1]

        # Guard against a NaN tail by falling back to the global mean.
        if pd.isna(level):
            level = df['price'].mean()

        predictions = [float(level)] * days

        # Half a standard deviation on each side as a rough interval.
        lower = upper = None
        if include_confidence:
            spread = df['price'].std() * 0.5
            lower = [p - spread for p in predictions]
            upper = [p + spread for p in predictions]

        return ForecastResult(
            values=predictions,
            confidence_lower=lower,
            confidence_upper=upper,
            model_name="SMA"
        )

    except Exception as e:
        self.logger.error(f"SMA forecast failed: {str(e)}")
        raise
|
| 171 |
+
|
| 172 |
+
def _generate_wma_forecast(
    self,
    df: pd.DataFrame,
    days: int,
    include_confidence: bool = True
) -> ForecastResult:
    """Weighted Moving Average forecast.

    The last 7 prices are averaged with linearly increasing weights so the
    most recent observations dominate; the resulting level is repeated for
    every forecast day.

    Raises:
        ValueError: if fewer than 7 historical rows are available.
    """
    try:
        if len(df) < 7:
            raise ValueError("Insufficient data for WMA")

        window = min(7, len(df))
        # Linear weights 1..window, normalised to sum to 1.
        weights = np.arange(1, window + 1)
        weights = weights / weights.sum()

        level = (df['price'].tail(window) * weights).sum()
        if pd.isna(level):
            level = df['price'].mean()

        predictions = [float(level)] * days

        lower = upper = None
        if include_confidence:
            spread = df['price'].std() * 0.3
            lower = [p - spread for p in predictions]
            upper = [p + spread for p in predictions]

        return ForecastResult(
            values=predictions,
            confidence_lower=lower,
            confidence_upper=upper,
            model_name="WMA"
        )

    except Exception as e:
        self.logger.error(f"WMA forecast failed: {str(e)}")
        raise
|
| 209 |
+
|
| 210 |
+
def _generate_es_forecast(
    self,
    df: pd.DataFrame,
    days: int,
    include_confidence: bool = True
) -> ForecastResult:
    """Exponential Smoothing (Holt-Winters, additive weekly seasonality) forecast.

    Raises:
        ImportError: if statsmodels is not installed.
        ValueError: if fewer than 7 historical rows are available.
    """
    try:
        if not STATS_MODELS_AVAILABLE:
            raise ImportError("statsmodels not available")

        if len(df) < 7:
            raise ValueError("Insufficient data for Exponential Smoothing")

        # Prepare data for exponential smoothing
        ts_data = df.set_index('date')['price']

        model = ExponentialSmoothing(ts_data, seasonal='add', seasonal_periods=7)
        fitted_model = model.fit()

        forecast = fitted_model.forecast(days)
        values = forecast.values.tolist()

        # Get confidence intervals if available
        if include_confidence:
            try:
                # NOTE(review): get_prediction() without arguments yields
                # *in-sample* intervals; taking the tail is only a rough
                # proxy for out-of-sample uncertainty — confirm before
                # relying on these bounds.
                pred = fitted_model.get_prediction()
                confidence_intervals = pred.conf_int()
                confidence_lower = confidence_intervals.iloc[:, 0].tail(days).values.tolist()
                confidence_upper = confidence_intervals.iloc[:, 1].tail(days).values.tolist()
            except Exception:
                # Fix: narrow the previously bare `except:` so that
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                std_dev = df['price'].std()
                confidence_lower = [v - std_dev for v in values]
                confidence_upper = [v + std_dev for v in values]
        else:
            confidence_lower = None
            confidence_upper = None

        return ForecastResult(
            values=values,
            confidence_lower=confidence_lower,
            confidence_upper=confidence_upper,
            model_name="ES"
        )

    except Exception as e:
        self.logger.error(f"ES forecast failed: {str(e)}")
        raise
|
| 259 |
+
|
| 260 |
+
def _generate_arima_forecast(
    self,
    df: pd.DataFrame,
    days: int,
    include_confidence: bool = True
) -> ForecastResult:
    """ARIMA(5, 1, 0) forecast.

    Raises:
        ImportError: if statsmodels is not installed.
        ValueError: if fewer than 10 historical rows are available.
    """
    try:
        if not STATS_MODELS_AVAILABLE:
            raise ImportError("statsmodels not available")

        if len(df) < 10:
            raise ValueError("Insufficient data for ARIMA")

        # Prepare data
        ts_data = df.set_index('date')['price']

        model = ARIMA(ts_data, order=(5, 1, 0))
        fitted_model = model.fit()

        forecast = fitted_model.forecast(days)
        values = forecast.values.tolist()

        # Get confidence intervals
        if include_confidence:
            try:
                pred = fitted_model.get_forecast(days)
                confidence_intervals = pred.conf_int()
                confidence_lower = confidence_intervals.iloc[:, 0].values.tolist()
                confidence_upper = confidence_intervals.iloc[:, 1].values.tolist()
            except Exception:
                # Fix: narrow the previously bare `except:` so that
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                std_dev = df['price'].std()
                confidence_lower = [v - std_dev for v in values]
                confidence_upper = [v + std_dev for v in values]
        else:
            confidence_lower = None
            confidence_upper = None

        return ForecastResult(
            values=values,
            confidence_lower=confidence_lower,
            confidence_upper=confidence_upper,
            model_name="ARIMA"
        )

    except Exception as e:
        self.logger.error(f"ARIMA forecast failed: {str(e)}")
        raise
|
| 309 |
+
|
| 310 |
+
def _generate_catboost_forecast(
    self,
    df: pd.DataFrame,
    days: int,
    include_confidence: bool = True
) -> ForecastResult:
    """CatBoost forecast (placeholder for future training).

    Until a trained model is shipped, this extrapolates the mean daily
    percentage change away from the last observed price.

    Raises:
        ImportError: if the catboost package is not installed.
        ValueError: if fewer than 10 historical rows are available.
    """
    try:
        if not CATBOOST_AVAILABLE:
            raise ImportError("CatBoost not available")

        if len(df) < 10:
            raise ValueError("Insufficient data for CatBoost")

        # For now, use a simple fallback since model isn't trained yet.
        # This will be replaced with the actual trained model later.
        self.logger.info("Using CatBoost placeholder (model not trained yet)")

        recent_trend = df['price'].pct_change().mean()
        last_price = df['price'].iloc[-1]

        # Ramp the average daily trend in linearly over the horizon.
        predictions = [
            float(last_price * (1 + (recent_trend * (step + 1) / days)))
            for step in range(days)
        ]

        lower = upper = None
        if include_confidence:
            spread = df['price'].std()
            lower = [p - spread for p in predictions]
            upper = [p + spread for p in predictions]

        return ForecastResult(
            values=predictions,
            confidence_lower=lower,
            confidence_upper=upper,
            model_name="CatBoost"
        )

    except Exception as e:
        self.logger.error(f"CatBoost forecast failed: {str(e)}")
        raise
|
| 353 |
+
|
| 354 |
+
def _generate_fallback_forecast(self, df: pd.DataFrame, days: int) -> ForecastResult:
    """Fallback forecast using simple average.

    Used when every requested model fails; repeats the mean historical
    price with deliberately wide (+/- 2 std) confidence bands. Never
    raises — on any failure it returns fixed neutral values instead.
    """
    try:
        avg_price = df['price'].mean()
        # Fix: an empty frame makes .mean() return NaN, which previously
        # propagated NaN forecast values; use the same neutral constant
        # the exception path uses.
        if pd.isna(avg_price):
            avg_price = 100.0
        values = [float(avg_price)] * days

        # Wide confidence intervals for fallback
        std_dev = df['price'].std() if len(df) > 1 else avg_price * 0.1
        if pd.isna(std_dev):
            std_dev = avg_price * 0.1
        confidence_lower = [v - std_dev * 2 for v in values]
        confidence_upper = [v + std_dev * 2 for v in values]

        return ForecastResult(
            values=values,
            confidence_lower=confidence_lower,
            confidence_upper=confidence_upper,
            model_name="Fallback"
        )

    except Exception as e:
        self.logger.error(f"Fallback forecast failed: {str(e)}")
        # Ultimate fallback: fixed neutral values so the API still responds.
        return ForecastResult(
            values=[100.0] * days,
            confidence_lower=[80.0] * days,
            confidence_upper=[120.0] * days,
            model_name="Fallback"
        )
|
| 381 |
+
|
| 382 |
+
def _generate_ensemble_forecast(
    self,
    model_results: Dict[str, ForecastResult],
    days: int,
    include_confidence: bool = True
) -> ForecastResult:
    """Average the per-day predictions of every successful model.

    Confidence bounds are the per-day means of the individual models'
    bounds; when no model supplied bounds, +/- one standard deviation of
    the ensemble values is used instead.

    Raises:
        ValueError: if no usable model predictions are available.
    """
    try:
        if not model_results:
            raise ValueError("No model results available for ensemble")

        # Collect each model's predictions, truncated to the horizon.
        all_values = [
            result.values[:days]
            for result in model_results.values()
            if len(result.values) >= days
        ]
        if not all_values:
            raise ValueError("No valid predictions for ensemble")

        # Per-day mean across models.
        ensemble_values = [
            np.mean([series[day] for series in all_values if day < len(series)])
            for day in range(days)
        ]

        confidence_lower = None
        confidence_upper = None
        if include_confidence:
            all_lower = [
                result.confidence_lower[:days]
                for result in model_results.values()
                if result.confidence_lower and len(result.confidence_lower) >= days
            ]
            all_upper = [
                result.confidence_upper[:days]
                for result in model_results.values()
                if result.confidence_upper and len(result.confidence_upper) >= days
            ]

            if all_lower and all_upper:
                confidence_lower = [np.mean([lo[day] for lo in all_lower]) for day in range(days)]
                confidence_upper = [np.mean([hi[day] for hi in all_upper]) for day in range(days)]
            else:
                # Fallback confidence intervals
                std_dev = np.std(ensemble_values)
                confidence_lower = [v - std_dev for v in ensemble_values]
                confidence_upper = [v + std_dev for v in ensemble_values]

        return ForecastResult(
            values=ensemble_values,
            confidence_lower=confidence_lower,
            confidence_upper=confidence_upper,
            model_name="Ensemble"
        )

    except Exception as e:
        self.logger.error(f"Ensemble forecast failed: {str(e)}")
        raise
|
| 440 |
+
|
| 441 |
+
def _prepare_forecast_data(
|
| 442 |
+
self,
|
| 443 |
+
model_results: Dict[str, ForecastResult],
|
| 444 |
+
df: pd.DataFrame,
|
| 445 |
+
days: int
|
| 446 |
+
) -> List[Dict[str, Any]]:
|
| 447 |
+
"""Prepare final forecast data for API response"""
|
| 448 |
+
try:
|
| 449 |
+
last_date = df['date'].max()
|
| 450 |
+
|
| 451 |
+
forecast_data = []
|
| 452 |
+
for i in range(days):
|
| 453 |
+
forecast_date = last_date + timedelta(days=i+1)
|
| 454 |
+
|
| 455 |
+
# Use ensemble if available, otherwise use first available model
|
| 456 |
+
if 'Ensemble' in model_results:
|
| 457 |
+
result = model_results['Ensemble']
|
| 458 |
+
else:
|
| 459 |
+
result = next(iter(model_results.values()))
|
| 460 |
+
|
| 461 |
+
data_point = {
|
| 462 |
+
"date": forecast_date.isoformat(),
|
| 463 |
+
"predicted_value": round(result.values[i], 2),
|
| 464 |
+
"model_used": result.model_name
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
if result.confidence_lower and i < len(result.confidence_lower):
|
| 468 |
+
data_point["confidence_lower"] = round(result.confidence_lower[i], 2)
|
| 469 |
+
|
| 470 |
+
if result.confidence_upper and i < len(result.confidence_upper):
|
| 471 |
+
data_point["confidence_upper"] = round(result.confidence_upper[i], 2)
|
| 472 |
+
|
| 473 |
+
forecast_data.append(data_point)
|
| 474 |
+
|
| 475 |
+
return forecast_data
|
| 476 |
+
|
| 477 |
+
except Exception as e:
|
| 478 |
+
self.logger.error(f"Forecast data preparation failed: {str(e)}")
|
| 479 |
+
raise
|
| 480 |
+
|
| 481 |
+
def calculate_revenue_projection(
    self,
    forecast_data: List[Dict[str, Any]],
    selling_price: float,
    historical_data: pd.DataFrame
) -> List[Dict[str, Any]]:
    """Project daily revenue over the forecast horizon.

    Quantity is assumed to stay at its historical mean; revenue is that
    quantity times the caller-supplied selling price. Price confidence
    bounds, when present, are scaled by the same quantity. Returns an
    empty list on any failure.
    """
    try:
        # Expected daily volume: historical mean quantity.
        avg_quantity = historical_data['quantity'].mean()

        projections = []
        for point in forecast_data:
            entry = {
                "date": point["date"],
                "projected_quantity": round(float(avg_quantity), 2),
                "selling_price": round(float(selling_price), 2),
                "projected_revenue": round(float(avg_quantity * selling_price), 2)
            }

            # Scale price confidence bounds into revenue bounds when present.
            if "confidence_lower" in point:
                entry["confidence_lower"] = round(point["confidence_lower"] * avg_quantity, 2)
            if "confidence_upper" in point:
                entry["confidence_upper"] = round(point["confidence_upper"] * avg_quantity, 2)

            projections.append(entry)

        return projections

    except Exception as e:
        self.logger.error(f"Revenue projection calculation failed: {str(e)}")
        return []
|
| 517 |
+
|
| 518 |
+
def generate_summary(
    self,
    forecast_data: List[Dict[str, Any]],
    historical_data: pd.DataFrame,
    models_used: List[str],
    scenario: str
) -> str:
    """Generate AI-like summary of forecast results.

    Returns a Markdown report; on any failure a short error string is
    returned instead of raising.
    """
    try:
        # Calculate key metrics
        forecast_values = [point["predicted_value"] for point in forecast_data]
        avg_forecast = np.mean(forecast_values)
        avg_historical = historical_data['price'].mean()

        trend = "increasing" if avg_forecast > avg_historical else "decreasing"
        # Fix: guard against a zero historical mean price, which previously
        # produced a division error / infinite percentage.
        if avg_historical:
            change_percent = abs((avg_forecast - avg_historical) / avg_historical * 100)
        else:
            change_percent = 0.0

        # Generate summary
        summary = f"""# Price Forecast Summary

## Overview
Based on historical demand data, the forecast shows a **{trend}** trend over the next {len(forecast_data)} days using {scenario} scenario.

## Key Metrics
- **Average Historical Price**: ${avg_historical:.2f}
- **Average Forecasted Price**: ${avg_forecast:.2f}
- **Expected Change**: {change_percent:.1f}% {trend}
- **Models Used**: {', '.join(models_used)}
- **Forecast Horizon**: {len(forecast_data)} days

## Analysis
The forecast combines multiple statistical and machine learning models to provide reliable predictions. Confidence intervals are included to help assess prediction uncertainty.

## Recommendations
{'Consider increasing inventory to meet potential higher demand.' if trend == 'increasing' else 'Monitor market conditions closely as prices may decline.'}
Track actual prices against this forecast and adjust strategies accordingly."""

        return summary

    except Exception as e:
        self.logger.error(f"Summary generation failed: {str(e)}")
        return "Forecast summary generation failed."
|
| 560 |
+
|
| 561 |
+
def calculate_overall_confidence(self, forecast_data: List[Dict[str, Any]]) -> Optional[float]:
    """Derive a single 0-100 confidence score from the forecast points.

    Narrow confidence intervals (relative to the predicted value) score
    high. Returns None when no point carries both bounds, or on failure.
    """
    try:
        scores = []
        for point in forecast_data:
            if "confidence_lower" not in point or "confidence_upper" not in point:
                continue
            predicted = point["predicted_value"]
            if predicted == 0:
                continue
            # Interval width relative to the prediction, mapped onto 0-100.
            relative_width = (point["confidence_upper"] - point["confidence_lower"]) / predicted
            scores.append(max(0, min(100, 100 - (relative_width * 50))))

        return round(np.mean(scores), 1) if scores else None

    except Exception as e:
        self.logger.error(f"Confidence calculation failed: {str(e)}")
        return None
|
requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core FastAPI dependencies
|
| 2 |
+
fastapi==0.104.1
|
| 3 |
+
uvicorn[standard]==0.24.0
|
| 4 |
+
pydantic==2.5.0
|
| 5 |
+
|
| 6 |
+
# Data processing
|
| 7 |
+
pandas==2.1.4
|
| 8 |
+
numpy==1.26.2
|
| 9 |
+
|
| 10 |
+
# Machine Learning & Statistics
|
| 11 |
+
scikit-learn==1.3.2
|
| 12 |
+
statsmodels==0.14.0
|
| 13 |
+
catboost==1.2.2
|
| 14 |
+
joblib==1.3.2
|
| 15 |
+
|
| 16 |
+
# Utilities
|
| 17 |
+
python-multipart==0.0.6
|
| 18 |
+
httpx==0.25.2
|
| 19 |
+
requests==2.31.0
|
| 20 |
+
|
| 21 |
+
# Optional: For development and testing (can be removed for production)
|
| 22 |
+
pytest==7.4.3
|
| 23 |
+
pytest-asyncio==0.21.1
|
run.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Development script for AgriPredict Analysis Service
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
def install_dependencies():
    """Install Python dependencies"""
    print("Installing dependencies...")
    # check=True: abort with CalledProcessError if pip fails.
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], check=True)
|
| 15 |
+
|
| 16 |
+
def run_service():
    """Run the FastAPI service"""
    print("Starting AgriPredict Analysis Service...")
    print("API will be available at: http://localhost:8000")
    print("API documentation at: http://localhost:8000/docs")

    # Set environment variables: put the project root on PYTHONPATH so
    # main.py can import the local packages regardless of the caller's cwd.
    env = os.environ.copy()
    env["PYTHONPATH"] = str(Path(__file__).parent)

    # Blocks until the service process exits. NOTE: no check=True here,
    # so a non-zero exit from the service is not treated as an error.
    subprocess.run([sys.executable, "main.py"], env=env)
|
| 27 |
+
|
| 28 |
+
def train_model():
    """Train the CatBoost model with artificial data"""
    print("Training CatBoost model with artificial data...")
    # check=True: propagate a failing training run as CalledProcessError.
    subprocess.run([sys.executable, "train_catboost.py"], check=True)
|
| 32 |
+
|
| 33 |
+
def main():
    """Dispatch the sub-command given on the command line.

    Prints usage when no command is supplied; unknown commands are
    reported but do not raise.
    """
    if len(sys.argv) < 2:
        print("Usage: python run.py [install|run|train]")
        print("  install - Install dependencies")
        print("  run - Run the service")
        print("  train - Train the CatBoost model")
        return

    command = sys.argv[1].lower()

    # Map each supported command to its handler.
    handlers = {
        "install": install_dependencies,
        "run": run_service,
        "train": train_model,
    }
    handler = handlers.get(command)
    if handler is not None:
        handler()
    else:
        print(f"Unknown command: {command}")

if __name__ == "__main__":
    main()
|
test_api.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Example script showing how to use the AgriPredict Analysis Service API
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
def generate_sample_data(days: int = 30):
    """Generate sample historical data for testing.

    Produces one record per day ending today, with randomly jittered
    quantity (floored at 1) and price (floored at 5).
    """
    start = datetime.now() - timedelta(days=days)
    records = []
    for offset in range(days):
        day = start + timedelta(days=offset)
        # Two draws keep the quantity roughly centred between 50 and 150.
        qty = random.randint(50, 150) + random.randint(-20, 20)
        cost = round(20 + random.uniform(-5, 5), 2)
        records.append({
            "date": day.strftime("%Y-%m-%d"),
            "quantity": max(1, qty),  # Ensure positive quantity
            "price": max(5, cost)  # Ensure positive price
        })
    return records
|
| 29 |
+
|
| 30 |
+
def test_health_check(base_url: str = "http://localhost:8000"):
    """Test the health check endpoint.

    Prints a pass/fail line and the JSON body on success; never raises
    (connection errors are caught and reported).
    """
    print("Testing health check...")
    try:
        response = requests.get(f"{base_url}/health")
        if response.status_code == 200:
            print("✅ Health check passed")
            print(f"Response: {response.json()}")
        else:
            print(f"❌ Health check failed: {response.status_code}")
    except Exception as e:
        print(f"❌ Health check error: {e}")
|
| 42 |
+
|
| 43 |
+
def test_list_models(base_url: str = "http://localhost:8000"):
    """Test the list models endpoint.

    Prints the model count and one line per model on success; never
    raises (connection errors are caught and reported).
    """
    print("\nTesting list models...")
    try:
        response = requests.get(f"{base_url}/models")
        if response.status_code == 200:
            print("✅ Models list retrieved")
            # assumes the response body is {"models": [{"name": ..., "id": ...}, ...]}
            models = response.json()["models"]
            print(f"Available models: {len(models)}")
            for model in models:
                print(f" - {model['name']} ({model['id']})")
        else:
            print(f"❌ Models list failed: {response.status_code}")
    except Exception as e:
        print(f"❌ Models list error: {e}")
|
| 58 |
+
|
| 59 |
+
def test_forecast_generation(base_url: str = "http://localhost:8000"):
    """Test forecast generation.

    Posts 30 days of synthetic history to /forecast and prints a short
    report of the returned forecast; never raises (errors are caught
    and reported).
    """
    print("\nTesting forecast generation...")

    # Generate sample data
    historical_data = generate_sample_data(30)

    # Prepare forecast request
    forecast_request = {
        "product_id": "sample_crop",
        "historical_data": historical_data,
        "days": 14,
        "selling_price": 25.0,
        "models": ["ensemble"],
        "include_confidence": True,
        "scenario": "realistic"
    }

    try:
        response = requests.post(
            f"{base_url}/forecast",
            json=forecast_request,
            headers={"Content-Type": "application/json"}
        )

        if response.status_code == 200:
            print("✅ Forecast generated successfully")
            result = response.json()

            print(f"Models used: {result['models_used']}")
            print(f"Forecast points: {len(result['forecast_data'])}")
            print(f"Confidence: {result.get('confidence', 'N/A')}%")

            if result.get('revenue_projection'):
                print(f"Revenue projections: {len(result['revenue_projection'])}")

            # Show first few forecast points
            print("\nFirst 3 forecast points:")
            for i, point in enumerate(result['forecast_data'][:3]):
                # Width of the confidence interval shown as ± after the value.
                print(f" Day {i+1}: {point['predicted_value']:.2f} "
                      f"(±{point.get('confidence_upper', 0) - point.get('confidence_lower', 0):.2f})")

        else:
            print(f"❌ Forecast failed: {response.status_code}")
            print(f"Error: {response.text}")

    except Exception as e:
        print(f"❌ Forecast error: {e}")
|
| 107 |
+
|
| 108 |
+
def main():
    """Main test function.

    Runs the three endpoint checks against a locally running service.
    """
    print("🚀 AgriPredict Analysis Service API Test")
    print("=" * 50)

    # Test with local service (change URL for deployed service)
    base_url = "http://localhost:8000"

    # Run tests
    test_health_check(base_url)
    test_list_models(base_url)
    test_forecast_generation(base_url)

    print("\n" + "=" * 50)
    print("API test completed!")

if __name__ == "__main__":
    main()
|
train_catboost.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CatBoost Model Training Script for AgriPredict
|
| 3 |
+
This script demonstrates how to train the CatBoost model with artificial agricultural data.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
from catboost import CatBoostRegressor, Pool
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
| 12 |
+
import joblib
|
| 13 |
+
import os
|
| 14 |
+
from typing import Dict, Any
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
# Setup logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
class CatBoostTrainer:
    """CatBoost model trainer for agricultural demand forecasting.

    Responsibilities: generate synthetic market data, engineer time-series
    features, train a CatBoost regressor that predicts *price*, evaluate it,
    and persist/restore the fitted model.
    """

    def __init__(self):
        # Populated by train_model() or load_model().
        self.model = None
        self.feature_names = None

    def generate_artificial_data(self, n_samples: int = 1000) -> pd.DataFrame:
        """
        Generate artificial agricultural data for training.

        Args:
            n_samples: Number of daily samples to generate (starting 2023-01-01)

        Returns:
            DataFrame with artificial agricultural data, including lag,
            rolling-window and price-change features; rows made NaN by the
            lag/rolling windows are dropped.
        """
        logger.info(f"Generating {n_samples} artificial data samples")

        # One row per calendar day.
        start_date = datetime(2023, 1, 1)
        dates = [start_date + timedelta(days=i) for i in range(n_samples)]

        np.random.seed(42)  # For reproducible results

        data = []

        for date in dates:
            # Seasonal patterns: yearly sine wave peaking mid-year.
            day_of_year = date.timetuple().tm_yday
            seasonal_factor = 1 + 0.3 * np.sin(2 * np.pi * day_of_year / 365)

            # Base demand with seasonal variation.
            base_quantity = np.random.normal(100, 20) * seasonal_factor

            # Price influenced by season plus gaussian noise.
            base_price = 25 + 5 * np.sin(2 * np.pi * day_of_year / 365)
            price_noise = np.random.normal(0, 2)
            price = base_price + price_noise

            # Weak negative price elasticity: higher price -> lower quantity.
            quantity_noise = np.random.normal(0, 15)
            quantity = base_quantity + quantity_noise - 0.1 * (price - 25)

            # Clamp to sensible positive values.
            quantity = max(1, quantity)
            price = max(5, price)

            data.append({
                'date': date,
                'quantity': round(quantity, 2),
                'price': round(price, 2),
                'day_of_week': date.weekday(),
                'month': date.month,
                'day_of_month': date.day,
                'quarter': (date.month - 1) // 3 + 1,
                'is_weekend': 1 if date.weekday() >= 5 else 0,
                'season': self._get_season(date.month)  # categorical (string)
            })

        df = pd.DataFrame(data)

        # Lag features: yesterday, last week, two weeks, last month.
        for lag in [1, 7, 14, 30]:
            df[f'price_lag_{lag}'] = df['price'].shift(lag)
            df[f'quantity_lag_{lag}'] = df['quantity'].shift(lag)

        # Rolling statistics over weekly/biweekly/monthly windows.
        for window in [7, 14, 30]:
            df[f'price_rolling_mean_{window}'] = df['price'].rolling(window).mean()
            df[f'price_rolling_std_{window}'] = df['price'].rolling(window).std()
            df[f'quantity_rolling_mean_{window}'] = df['quantity'].rolling(window).mean()

        # Relative price momentum.
        df['price_change'] = df['price'].pct_change()
        df['price_change_7d'] = df['price'].pct_change(7)

        # The first max(lag, window) rows contain NaNs from the shifts above.
        df = df.dropna().reset_index(drop=True)

        logger.info(f"Generated dataset with {len(df)} samples and {len(df.columns)} features")
        return df

    def _get_season(self, month: int) -> str:
        """Map a month number (1-12) to a meteorological season name."""
        if month in [12, 1, 2]:
            return 'winter'
        elif month in [3, 4, 5]:
            return 'spring'
        elif month in [6, 7, 8]:
            return 'summer'
        else:
            return 'fall'

    def prepare_features(self, df: pd.DataFrame) -> tuple:
        """
        Prepare features for training.

        Args:
            df: Input DataFrame produced by generate_artificial_data()

        Returns:
            Tuple of (X, y, feature_names) where y is the 'price' column
        """
        # Exclude the target ('price'), the raw quantity and the date column.
        exclude_cols = ['date', 'quantity', 'price']
        feature_cols = [col for col in df.columns if col not in exclude_cols]

        X = df[feature_cols]
        y = df['price']  # We're predicting price

        logger.info(f"Prepared {len(feature_cols)} features for training")
        return X, y, feature_cols

    def train_model(self, X_train, y_train, X_val=None, y_val=None, **kwargs) -> CatBoostRegressor:
        """
        Train CatBoost model.

        Args:
            X_train: Training features (DataFrame; string columns are treated
                as categorical)
            y_train: Training target
            X_val: Validation features (optional, enables early stopping)
            y_val: Validation target (optional)
            **kwargs: Additional CatBoost parameters overriding the defaults

        Returns:
            Trained CatBoost model (also stored on self.model)
        """
        default_params = {
            'iterations': 1000,
            'learning_rate': 0.1,
            'depth': 6,
            'loss_function': 'MAE',
            'eval_metric': 'MAE',
            'random_seed': 42,
            'verbose': 100,
            'early_stopping_rounds': 50
        }

        # Caller-supplied parameters win over the defaults.
        default_params.update(kwargs)

        model = CatBoostRegressor(**default_params)

        # BUGFIX: CatBoost rejects non-numeric columns (e.g. 'season') unless
        # they are declared categorical, so detect object-dtype columns.
        cat_features = [c for c in X_train.columns if X_train[c].dtype == object]
        train_pool = Pool(X_train, y_train, cat_features=cat_features)

        if X_val is not None and y_val is not None:
            val_pool = Pool(X_val, y_val, cat_features=cat_features)
            model.fit(train_pool, eval_set=val_pool)
        else:
            model.fit(train_pool)

        self.model = model
        self.feature_names = list(X_train.columns)

        logger.info(f"Trained CatBoost model with {model.tree_count_} trees")
        return model

    def evaluate_model(self, X_test, y_test) -> Dict[str, float]:
        """
        Evaluate model performance on a held-out set.

        Args:
            X_test: Test features
            y_test: Test target

        Returns:
            Dictionary with 'mae', 'mse', 'rmse' and 'mape' metrics

        Raises:
            ValueError: If the model has not been trained yet
        """
        if self.model is None:
            raise ValueError("Model not trained yet")

        y_pred = self.model.predict(X_test)

        mae = mean_absolute_error(y_test, y_pred)
        mse = mean_squared_error(y_test, y_pred)
        rmse = np.sqrt(mse)

        # MAPE (Mean Absolute Percentage Error); guard against division by
        # zero for any zero-valued targets.
        y_true = np.asarray(y_test, dtype=float)
        safe_denom = np.where(y_true == 0, np.finfo(float).eps, y_true)
        mape = np.mean(np.abs((y_true - y_pred) / safe_denom)) * 100

        metrics = {
            'mae': mae,
            'mse': mse,
            'rmse': rmse,
            'mape': mape
        }

        # BUGFIX: was logger.info(".2f") which logged a literal format string.
        logger.info(
            "Evaluation metrics: MAE=%.2f, RMSE=%.2f, MAPE=%.2f%%",
            mae, rmse, mape
        )
        return metrics

    def save_model(self, filepath: str):
        """
        Save trained model to file (joblib bundle with feature names and
        training timestamp).

        Args:
            filepath: Path to save the model

        Raises:
            ValueError: If the model has not been trained yet
        """
        if self.model is None:
            raise ValueError("Model not trained yet")

        # Create the directory if the path actually contains one; a bare
        # filename would make os.makedirs("") raise.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        joblib.dump({
            'model': self.model,
            'feature_names': self.feature_names,
            'training_date': datetime.now().isoformat()
        }, filepath)

        logger.info(f"Model saved to {filepath}")

    def load_model(self, filepath: str):
        """
        Load trained model from file.

        Args:
            filepath: Path to the saved model

        Raises:
            FileNotFoundError: If the file does not exist
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Model file not found: {filepath}")

        model_data = joblib.load(filepath)
        self.model = model_data['model']
        self.feature_names = model_data['feature_names']

        logger.info(f"Model loaded from {filepath}")

    def predict(self, features: pd.DataFrame) -> np.ndarray:
        """
        Make predictions with the trained model.

        Args:
            features: Input features (must contain the training columns)

        Returns:
            Predictions array

        Raises:
            ValueError: If no model has been trained or loaded
        """
        if self.model is None:
            raise ValueError("Model not trained or loaded yet")

        # Reorder columns to match the training layout.
        if self.feature_names:
            features = features[self.feature_names]

        return self.model.predict(features)
|
| 275 |
+
|
| 276 |
+
def main():
    """Train, evaluate and persist the CatBoost demand model end to end."""
    logger.info("Starting CatBoost model training")

    trainer = CatBoostTrainer()

    # Synthetic data stands in for real market history.
    dataset = trainer.generate_artificial_data(n_samples=2000)
    X, y, feature_names = trainer.prepare_features(dataset)

    # Hold out 20% for the final test, then carve a validation split out of
    # the remainder for early stopping.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42
    )

    trainer.train_model(X_train, y_train, X_val, y_val)
    metrics = trainer.evaluate_model(X_test, y_test)

    model_path = "models/catboost_model.pkl"
    trainer.save_model(model_path)

    logger.info("Training completed successfully!")
    logger.info(f"Model saved to: {model_path}")
    logger.info(f"Test Metrics: {metrics}")

    return trainer

if __name__ == "__main__":
    trained_trainer = main()
|
utils/config.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for AgriPredict Analysis Service
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
class Settings:
    """Central configuration for the AgriPredict Analysis Service.

    Attributes read via os.getenv can be overridden through environment
    variables; everything else is a fixed application default.
    """

    # -- HTTP server --
    API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
    API_PORT: int = int(os.getenv("PORT", 8000))
    API_WORKERS: int = int(os.getenv("API_WORKERS", 1))

    # -- CORS --
    # NOTE(review): wildcard patterns like "https://*.huggingface.co" are not
    # expanded by plain origin matching — confirm whether the consumer of this
    # list needs a regex-based option instead.
    ALLOWED_ORIGINS: List[str] = [
        "http://localhost:3000",
        "http://localhost:3001",
        "https://*.huggingface.co",
        "https://huggingface.co",
        os.getenv("FRONTEND_URL", "*"),
    ]

    # -- Forecasting limits --
    DEFAULT_MODELS: List[str] = ["ensemble"]
    MAX_FORECAST_DAYS: int = 365
    MIN_HISTORICAL_DATA_POINTS: int = 3

    # -- CatBoost hyperparameters (reserved for future training) --
    CATBOOST_ITERATIONS: int = 100
    CATBOOST_LEARNING_RATE: float = 0.1
    CATBOOST_DEPTH: int = 6
    CATBOOST_VERBOSE: bool = False

    # -- Logging --
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # -- Data processing --
    DATE_FORMAT: str = "%Y-%m-%d"
    MAX_DATA_POINTS: int = 10000

# Single shared settings object used across the service.
settings = Settings()
|
utils/logger.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Logging configuration for AgriPredict Analysis Service
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import sys
|
| 7 |
+
from utils.config import settings
|
| 8 |
+
|
| 9 |
+
def setup_logger(name: str) -> logging.Logger:
    """Return a logger writing to stdout, configured from app settings."""
    log_level = getattr(logging, settings.LOG_LEVEL)

    log = logging.getLogger(name)
    log.setLevel(log_level)

    # Drop any previously attached handlers so repeated calls for the same
    # name don't produce duplicate output.
    log.handlers.clear()

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(log_level)
    handler.setFormatter(logging.Formatter(settings.LOG_FORMAT))
    log.addHandler(handler)

    return log

# Global logger instance
logger = setup_logger(__name__)
|