# CosmicVoyage_RL / inference.py
import base64
import io
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from stable_baselines3 import PPO

# Initialize FastAPI
app = FastAPI()

# Allow cross-origin requests from the browser front end.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this to specific domains if needed
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods (including POST)
    allow_headers=["*"],
)

# Define the /predict endpoint before mounting StaticFiles, so the catch-all
# static mount does not shadow the API route.
class Pipeline:
    def __init__(self):
        # Load the PPO policy; device='cpu' maps all tensors onto the CPU.
        print("Loading the model...")
        self.model = PPO.load(
            "model_RL_Cosmic_latest",
            device="cpu",
            custom_objects={"map_location": torch.device("cpu")},
        )
        print("Model loaded successfully.")

    def __call__(self, inputs):
        print("Pipeline called with inputs:", inputs)
        obs_data = inputs.get("observation")
        if obs_data is None:
            print("No observation provided.")
            return {"error": "No observation provided"}
        obs = self.preprocess_observation(obs_data)
        if isinstance(obs, dict) and "error" in obs:
            print("Error in preprocessing observation:", obs["error"])
            return obs
        print("Observation preprocessed successfully. Shape:", obs.shape)
        action, _ = self.model.predict(obs, deterministic=True)
        print("Model predicted action:", action)
        return {"action": action.tolist()}

    def preprocess_observation(self, obs_data):
        print("Preprocessing observation...")
        try:
            if isinstance(obs_data, list) and len(obs_data) == 8:
                frames = []
                for idx, encoded_frame in enumerate(obs_data):
                    print(f"Decoding frame {idx + 1}/8")
                    # Strip the "data:image/...;base64," prefix before decoding.
                    img_bytes = base64.b64decode(encoded_frame.split(",")[1])
                    # PIL's resize takes (width, height), so this yields a
                    # 150x100 (H x W) grayscale array.
                    img = Image.open(io.BytesIO(img_bytes)).convert("L").resize((100, 150))
                    img_array = np.array(img, dtype=np.float32) / 255.0
                    frames.append(img_array)
                obs = np.stack(frames, axis=0)  # Shape: (8, 150, 100)
                obs = obs[np.newaxis, ...]  # Add batch dimension: (1, 8, 150, 100)
                return obs
            else:
                error_msg = "Invalid observation format: expected a list of 8 base64-encoded frames"
                print(error_msg)
                return {"error": error_msg}
        except Exception as e:
            error_msg = f"Failed to process observation: {e}"
            print(error_msg)
            return {"error": error_msg}

# Instantiate once at import time so the model is loaded before the first request.
pipeline = Pipeline()

class ObservationInput(BaseModel):
    observation: List[str]

@app.post("/predict")
def predict(input_data: ObservationInput):
    print("Received a prediction request.")
    print("Received observation data type:", type(input_data.observation))
    result = pipeline({"observation": input_data.observation})
    if "error" in result:
        print("Error in pipeline:", result["error"])
        raise HTTPException(status_code=400, detail=result["error"])
    print("Returning result:", result)
    return result

# Mount the static files after defining other routes
app.mount("/", StaticFiles(directory="static", html=True), name="static")
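
# Convenience entry point for running the API directly; a sketch only. The
# production container presumably launches the server itself, so this block
# matters only for local runs. uvicorn being installed and port 7860 (the
# Hugging Face Spaces default) are assumptions, not confirmed by this file.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)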