"""FastAPI service exposing Amazon Chronos (t5-tiny) forecasts plus simple SL/TP suggestions."""

import logging
import math
import os

import numpy as np
import torch
import uvicorn
from chronos import ChronosPipeline  # This works ONLY if chronos-forecasting is installed
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool

logger = logging.getLogger(__name__)

app = FastAPI()

# Global variable for the model pipeline; set by the startup hook below.
pipeline = None

# Number of future steps forecast per request.
PREDICTION_LENGTH = 12
# num_samples=20 gives enough for quantiles without killing RAM
NUM_SAMPLES = 20
# Percentiles returned as (lower band, median, upper band).
QUANTILE_PERCENTS = (10, 50, 90)


@app.on_event("startup")
def load_model():
    """Load the Chronos pipeline into the module-level ``pipeline`` global."""
    global pipeline
    print("Loading Chronos model...")
    # Using 'tiny' for stability on HF basic instances
    # Ensure dtype is torch.bfloat16 for efficiency
    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-tiny",
        device_map="cpu",  # Use "cpu" if no GPU, or "auto" if you have a T4
        torch_dtype=torch.bfloat16,
    )
    print("Model loaded successfully.")


@app.get("/")
def home():
    """Health-check endpoint."""
    return {"status": "Model is running", "info": "Send POST to /predict"}


def _forecast_quantiles(context):
    """Sample a forecast for ``context`` and return its (low, median, high) percentile curves.

    ``context`` is a float32 tensor of shape [1, seq_len]. Each returned array
    holds one value per forecast step.
    """
    forecast = pipeline.predict(
        context, prediction_length=PREDICTION_LENGTH, num_samples=NUM_SAMPLES
    )
    # forecast shape is [batch, samples, horizon] -> take the single batch entry.
    # detach/cpu/float keep .numpy() working even if the pipeline hands back a
    # GPU or bfloat16 tensor (e.g. when device_map is switched to "auto").
    samples = forecast[0].detach().cpu().float().numpy()
    low_bound, median_pred, high_bound = np.percentile(samples, QUANTILE_PERCENTS, axis=0)
    return low_bound, median_pred, high_bound


def _trade_levels(entry_price, first_low, rr_ratio):
    """Return ``(stop_loss, take_profit)`` for a long entry at ``entry_price``.

    The stop loss sits at the first-step lower-percentile forecast. If that is
    not below the entry, it falls back to 2% below the entry. The take profit is
    placed ``rr_ratio`` times the risk above the entry.
    """
    sl = first_low
    if sl >= entry_price:
        # Safety fallback: 2% *below* entry, also for zero/negative prices
        # (where entry * 0.98 would not be below the entry).
        sl = entry_price * (0.98 if entry_price >= 0 else 1.02)
    risk = entry_price - sl
    tp = entry_price + risk * rr_ratio
    return sl, tp


@app.post("/predict")
async def get_forecast(request: Request):
    """Forecast the next steps of ``history`` and suggest stop-loss / take-profit levels.

    Expects a JSON body ``{"history": [numbers...], "rr_ratio": number (default 2.0)}``.
    Returns prediction bands and SL/TP suggestions. Failures are reported as an
    ``{"error": message}`` object rather than raised.
    """
    if pipeline is None:
        return {"error": "Model is not loaded yet"}

    try:
        body = await request.json()
        history = body.get("history", [])
        user_rr = float(body.get("rr_ratio", 2.0))

        if not history:
            return {"error": "History array is empty"}
        if not isinstance(history, list):
            return {"error": "History must be a JSON array of numbers"}
        # NaN/inf would propagate into the response and make JSON encoding fail.
        if not math.isfinite(user_rr) or user_rr <= 0:
            return {"error": "rr_ratio must be a positive finite number"}

        entry_price = float(history[-1])
        if not math.isfinite(entry_price):
            return {"error": "Last history value must be a finite number"}

        # 1. Convert to tensor and add batch dimension [1, seq_len];
        # Chronos expects a batch dimension.
        context = torch.tensor(history, dtype=torch.float32).unsqueeze(0)

        # 2. predict() is synchronous and CPU-bound; run it in a worker thread
        # so it does not block the event loop (and every other request).
        low_bound, median_pred, high_bound = await run_in_threadpool(
            _forecast_quantiles, context
        )

        # 3. Trading logic: SL from the lower band, TP from the reward/risk ratio.
        sl, tp = _trade_levels(entry_price, float(low_bound[0]), user_rr)
        first_pred = float(median_pred[0])

        return {
            "prediction": median_pred.tolist(),
            "upper_bound": high_bound.tolist(),
            "lower_bound": low_bound.tolist(),
            "suggested_sl": round(sl, 4),
            "suggested_tp": round(tp, 4),
            "verdict": "Entry Confirmed" if first_pred > entry_price else "Wait for better entry",
        }
    except Exception as e:
        # Keep the existing error contract, but don't lose the traceback.
        logger.exception("Forecast request failed")
        return {"error": str(e)}


if __name__ == "__main__":
    # Port 7860 is required for Hugging Face Spaces
    uvicorn.run(app, host="0.0.0.0", port=7860)