# FastAPI service exposing a Chronos time-series forecast + trade-level API,
# deployed on Hugging Face Spaces. (The "Spaces: Sleeping" lines were UI
# status residue from a copy-paste, not part of the program.)
import os

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, Request

# This import works ONLY if the chronos-forecasting package is installed.
from chronos import ChronosPipeline

app = FastAPI()

# Global handle to the Chronos model; stays None until load_model() runs.
pipeline = None
def load_model():
    """Initialise the module-level `pipeline` global with a Chronos model.

    Uses the 'tiny' T5 checkpoint on CPU in bfloat16 — small enough to be
    stable on basic Hugging Face instances.
    """
    global pipeline
    print("Loading Chronos model...")
    model_id = "amazon/chronos-t5-tiny"  # 'tiny' for stability on basic HF instances
    loaded = ChronosPipeline.from_pretrained(
        model_id,
        device_map="cpu",  # use "auto" instead if a GPU (e.g. T4) is available
        torch_dtype=torch.bfloat16,  # bfloat16 for memory efficiency
    )
    pipeline = loaded
    print("Model loaded successfully.")
def home():
    """Return the health-check payload served at the root route."""
    payload = {
        "status": "Model is running",
        "info": "Send POST to /predict",
    }
    return payload
def build_trade_signal(history, forecast_np, user_rr):
    """Internal helper: turn forecast samples into the /predict response dict.

    history: 1-D price series; the last element is treated as the entry price.
    forecast_np: array-like of shape [samples, horizon] of predicted prices.
    user_rr: reward-to-risk ratio used to place the take-profit.
    """
    # Percentile bands across the sample dimension, one value per horizon step.
    low_bound = np.percentile(forecast_np, 10, axis=0)
    median_pred = np.percentile(forecast_np, 50, axis=0)
    high_bound = np.percentile(forecast_np, 90, axis=0)

    entry_price = float(history[-1])
    first_pred = float(median_pred[0])

    # Stop loss at the 10th percentile of the first forecast step. If that sits
    # at/above the entry the band is useless as a long stop, so fall back to a
    # flat 2% stop. NOTE(review): this logic assumes long positions only.
    sl = float(low_bound[0])
    if sl >= entry_price:
        sl = entry_price * 0.98  # Safety fallback

    # Take profit = entry + risk * (reward:risk ratio).
    risk = entry_price - sl
    tp = entry_price + (risk * user_rr)

    return {
        "prediction": median_pred.tolist(),
        "upper_bound": high_bound.tolist(),
        "lower_bound": low_bound.tolist(),
        "suggested_sl": round(sl, 4),
        "suggested_tp": round(tp, 4),
        "verdict": "Entry Confirmed" if first_pred > entry_price else "Wait for better entry",
    }


async def get_forecast(request: "Request"):
    """POST handler: read {"history": [...], "rr_ratio": float} from the JSON
    body and return a 12-step forecast with suggested SL/TP trade levels.

    Failures are reported in-band as {"error": ...} with HTTP 200 (original
    contract preserved).
    """
    try:
        body = await request.json()
        history = body.get("history", [])
        user_rr = float(body.get("rr_ratio", 2.0))
        if not history:
            return {"error": "History array is empty"}
        if pipeline is None:
            # load_model() has not run (or failed); report that clearly rather
            # than surfacing a cryptic AttributeError from pipeline.predict.
            return {"error": "Model is not loaded yet"}

        # Chronos expects a batch dimension: [1, seq_len].
        context = torch.tensor(history, dtype=torch.float32).unsqueeze(0)

        # num_samples=20 gives enough samples for stable 10/50/90 percentiles
        # without excessive RAM use.
        forecast = pipeline.predict(context, prediction_length=12, num_samples=20)

        # forecast shape is [batch, samples, horizon] -> drop the batch dim.
        # NOTE(review): assumes predict() returns a CPU tensor whose dtype
        # supports .numpy() directly — confirm (bfloat16 would need .float()).
        forecast_np = forecast[0].numpy()
        return build_trade_signal(history, forecast_np, user_rr)
    except Exception as e:
        # Keep the original best-effort contract: report failures as JSON.
        return {"error": str(e)}
| if __name__ == "__main__": | |
| # Port 7860 is required for Hugging Face Spaces | |
| uvicorn.run(app, host="0.0.0.0", port=7860) |