# sac-trader-bot / app.py
# Last commit 03eaaca by monstaws: "Update app.py" (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
# ==========================================
# 1. NETWORK DEFINITION
# ==========================================
class Actor(nn.Module):
    """Policy network mapping a state vector to an action mean and a clamped log-std.

    Three ReLU hidden layers feed two parallel linear heads (``mean`` and
    ``log_std``). The attribute names fc1/fc2/fc3/mean/log_std become the keys
    of ``state_dict()``, so they must match the keys stored in saved checkpoints.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Layers are created in this exact order: it fixes the state_dict
        # layout and the order in which default random initialisation is drawn.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        # Bounds applied to the predicted log standard deviation.
        self.LOG_STD_MIN, self.LOG_STD_MAX = -20, 2

    def forward(self, state):
        """Return ``(mean, log_std)``; both have ``action_dim`` as their last dimension."""
        hidden = state
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        log_std = self.log_std(hidden).clamp(self.LOG_STD_MIN, self.LOG_STD_MAX)
        return self.mean(hidden), log_std
# ==========================================
# 2. DYNAMIC SCALER
# ==========================================
class DynamicScaler:
    """Per-call z-score standardiser.

    Each call to ``fit_transform`` computes column statistics from the data it
    is given and applies them immediately; nothing is stored on the instance.
    """

    def fit_transform(self, data, *, clip=5.0, eps=1e-8):
        """Standardise ``data`` column-wise (along axis 0) and clip the result.

        Args:
            data: array-like of numbers with samples along axis 0 (e.g. a list
                of feature rows). A flat 1-D sequence is treated as one column.
            clip: results are clipped to ``[-clip, clip]``.
            eps: columns whose standard deviation is below this are treated as
                constant and are only mean-centred (divided by 1).

        Returns:
            float32 ndarray with the same shape as ``data``.

        Raises:
            ValueError: if ``data`` has no samples, or cannot be converted to a
                rectangular float array.
        """
        X = np.array(data, dtype=np.float32)
        if X.ndim == 0 or X.shape[0] == 0:
            raise ValueError("DynamicScaler.fit_transform needs at least one sample")
        mean = X.mean(axis=0)
        scale = X.std(axis=0)
        # np.where instead of in-place masked assignment: for 1-D input the std
        # is a 0-d NumPy scalar, which does not support item assignment.
        scale = np.where(scale < eps, 1.0, scale).astype(np.float32)
        return np.clip((X - mean) / scale, -clip, clip)
# ==========================================
# 3. SETUP
# ==========================================
app = FastAPI()

# Actor input size. predict() assembles it as [market features][portfolio
# features]; it must equal the state_dim the checkpoint was trained with,
# otherwise load_state_dict() rejects the fc1 weight shape.
STATE_DIM = 53
ACTION_DIM = 1
HIDDEN_DIM = 256
MODEL_PATH = "sac_v9_pytorch_best_eval.pt"
device = torch.device("cpu")
actor = Actor(STATE_DIM, ACTION_DIM, HIDDEN_DIM).to(device)
scaler = DynamicScaler()

# Load model.
# Loading is best-effort so the server still starts and can report the
# problem, but predict() refuses to answer while MODEL_LOADED is False:
# otherwise clients would receive signals from randomly initialised weights.
MODEL_LOADED = False
MODEL_LOAD_ERROR = None
try:
    print("Loading model...")
    # NOTE(security): torch.load uses pickle unless weights_only=True (the
    # default only from PyTorch 2.6). Only point MODEL_PATH at a trusted file.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    actor.load_state_dict(checkpoint['actor'])
    MODEL_LOADED = True
    print(f"✅ Model loaded! Expecting {STATE_DIM} inputs.")
except Exception as e:
    MODEL_LOAD_ERROR = f"{type(e).__name__}: {e}"
    print(f"❌ Error loading model: {e}")
# Always switch to inference mode. Actor uses only Linear layers and ReLU, so
# this does not change its outputs, but it keeps the intent explicit.
actor.eval()
# ==========================================
# 4. ENDPOINT
# ==========================================
class InputData(BaseModel):
    # Rows of raw market features; the last row is treated as the current one.
    history: List[List[float]]
    # Portfolio features appended after the market features
    # (5 values in the deployed setup — TODO confirm against training env).
    portfolio: List[float]


def _build_state(history, portfolio_values):
    """Assemble the STATE_DIM-long actor input from raw request data.

    The history is z-scored column-wise over the request window; its last row
    is truncated or zero-padded to ``STATE_DIM - len(portfolio)`` values and
    the portfolio values are appended.

    Raises:
        ValueError: if the portfolio alone is longer than STATE_DIM, or the
            assembled state contains NaN/inf (which would otherwise yield a NaN
            action that cannot be serialised to JSON).
    """
    portfolio = np.asarray(portfolio_values, dtype=np.float32)
    needed_market = STATE_DIM - portfolio.shape[0]
    if needed_market < 0:
        raise ValueError(
            f"portfolio has {portfolio.shape[0]} values but the model takes "
            f"only {STATE_DIM} inputs in total"
        )
    current_market = scaler.fit_transform(history)[-1]
    # NOTE(review): the market row is cut to its first `needed_market` columns
    # (or zero-padded). This only matches training if the training env fed the
    # same columns in the same order — confirm against the feature pipeline.
    market_part = current_market[:needed_market]
    shortfall = needed_market - market_part.shape[0]
    if shortfall > 0:
        market_part = np.pad(market_part, (0, shortfall))
    state = np.concatenate([market_part, portfolio]).astype(np.float32)
    if not np.all(np.isfinite(state)):
        raise ValueError(
            "state contains NaN or infinite values; check history/portfolio"
        )
    return state


@app.post("/predict")
def predict(data: InputData):
    """Return the actor's deterministic action for the latest history row.

    The action is ``tanh(mean)`` in [-1, 1]; the ``log_std`` output is ignored.
    Positive actions map to "BUY", everything else to "SELL", and
    ``confidence`` is ``|action|``. Failures are returned as a normal JSON
    body of the form ``{"error": message}``.
    """
    if not MODEL_LOADED:
        return {"error": f"Model not loaded: {MODEL_LOAD_ERROR}"}
    try:
        # With a single row every column's std is 0, so all scaled features
        # would collapse to 0.
        if len(data.history) < 2:
            return {"error": "Need at least 2 rows of history"}
        state = _build_state(data.history, data.portfolio)
        state_tensor = torch.as_tensor(
            state, dtype=torch.float32, device=device
        ).unsqueeze(0)
        with torch.no_grad():
            mean, _ = actor(state_tensor)
        action = torch.tanh(mean).item()
        return {
            "action": action,
            "signal": "BUY" if action > 0 else "SELL",
            "confidence": abs(action),
        }
    except Exception as e:
        # Request boundary: report the failure to the client instead of a 500.
        return {"error": str(e)}