Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from stable_baselines3 import PPO | |
| import base64 | |
| import io | |
| from PIL import Image | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from typing import List | |
# ASGI application object; the static-file mount and the CORS middleware
# configured further down in this module are attached to it.
app: FastAPI = FastAPI()
# Model wrapper used by the /predict endpoint. The endpoint itself must be
# registered before StaticFiles is mounted at "/".
class Pipeline:
    """Wrap a saved PPO model behind a dict-in / dict-out call interface.

    Calling an instance with ``{"observation": [...]}`` decodes the frames,
    runs the model, and returns ``{"action": [...]}``.  Every failure is
    reported as ``{"error": "<message>"}`` instead of being raised, so the
    caller decides how to surface it.
    """

    # Number of encoded frames that make up one observation.
    NUM_FRAMES = 8
    # Size handed to ``Image.resize`` as (width, height); the resulting
    # per-frame arrays therefore have shape (150, 100).
    FRAME_SIZE = (100, 150)

    def __init__(self):
        """Load the ``model_RL_Cosmic_latest`` PPO checkpoint onto the CPU."""
        print("Loading the model...")
        self.model = PPO.load(
            "model_RL_Cosmic_latest",
            device='cpu',
            custom_objects={"map_location": torch.device('cpu')},
        )
        print("Model loaded successfully.")

    def __call__(self, inputs):
        """Predict an action for the observation carried in *inputs*.

        Args:
            inputs: Mapping with an ``"observation"`` entry holding the list
                of encoded frames.

        Returns:
            ``{"action": <list>}`` on success, or ``{"error": <str>}`` when
            the observation is missing or cannot be preprocessed.
        """
        print("Pipeline called with inputs:", inputs)
        obs_data = inputs.get("observation")
        if obs_data is None:
            print("No observation provided.")
            return {"error": "No observation provided"}

        obs = self.preprocess_observation(obs_data)
        # preprocess_observation signals failure with an error dict rather
        # than an exception; pass that dict straight through.
        if isinstance(obs, dict) and "error" in obs:
            print("Error in preprocessing observation:", obs["error"])
            return obs

        print("Observation preprocessed successfully. Shape:", obs.shape)
        action, _ = self.model.predict(obs, deterministic=True)
        print("Model predicted action:", action)
        return {"action": action.tolist()}

    @staticmethod
    def _decode_frame_bytes(encoded_frame):
        """Return the raw bytes of one base64-encoded frame.

        Accepts both a data URL (``"data:image/png;base64,<payload>"``) and a
        bare base64 payload without any prefix.
        """
        _, separator, payload = encoded_frame.partition(",")
        # Without a comma there is no data-URL header to strip: the whole
        # string is the payload.
        return base64.b64decode(payload if separator else encoded_frame)

    def preprocess_observation(self, obs_data):
        """Turn a list of encoded frames into a batched float32 array.

        Each frame is decoded, converted to greyscale, resized to
        ``FRAME_SIZE`` and scaled to the range [0, 1].

        Args:
            obs_data: List of exactly ``NUM_FRAMES`` base64 strings, with or
                without a data-URL prefix.

        Returns:
            An array of shape ``(1, NUM_FRAMES, 150, 100)`` on success, or
            ``{"error": <str>}`` if the input has the wrong type/length or a
            frame cannot be decoded.
        """
        print("Preprocessing observation...")
        if not (isinstance(obs_data, list) and len(obs_data) == self.NUM_FRAMES):
            error_msg = "Invalid observation format or insufficient frames"
            print(error_msg)
            return {"error": error_msg}

        try:
            frames = []
            for idx, encoded_frame in enumerate(obs_data, start=1):
                print(f"Decoding frame {idx}/{self.NUM_FRAMES}")
                img_bytes = self._decode_frame_bytes(encoded_frame)
                # Close the decoded source image promptly; convert() and
                # resize() each return a new, independent image.
                with Image.open(io.BytesIO(img_bytes)) as source_img:
                    img = source_img.convert('L').resize(self.FRAME_SIZE)
                frames.append(np.array(img, dtype=np.float32) / 255.0)
            obs = np.stack(frames, axis=0)  # Shape: (8, 150, 100)
            return obs[np.newaxis, ...]  # Add batch dimension
        except Exception as e:
            # Boundary handler: any decoding failure (bad base64, unreadable
            # image, non-string element) becomes an error dict for the caller.
            error_msg = f"Failed to process observation: {e}"
            print(error_msg)
            return {"error": error_msg}
# Single module-level Pipeline instance; constructing it loads the model
# once, at import time.
pipeline: Pipeline = Pipeline()
# Request body schema for a prediction request.
class ObservationInput(BaseModel):
    # Encoded frames, one string per frame, passed to the pipeline as-is.
    observation: List[str]
@app.post("/predict")
def predict(input_data: ObservationInput):
    """Run the model on a stack of encoded frames and return its action.

    Args:
        input_data: Request body whose ``observation`` field is the list of
            encoded frame strings.

    Returns:
        The pipeline's result dict (``{"action": [...]}``) on success.

    Raises:
        HTTPException: Status 400 carrying the pipeline's error message when
            the pipeline returns an ``"error"`` entry.
    """
    print("Received a prediction request.")
    print("Received observation data type:", type(input_data.observation))
    result = pipeline({"observation": input_data.observation})
    # The pipeline reports failures as {"error": ...}; translate that into
    # an HTTP error response instead of returning it as a normal payload.
    if "error" in result:
        print("Error in pipeline:", result["error"])
        raise HTTPException(status_code=400, detail=result["error"])
    print("Returning result:", result)
    return result
from fastapi.middleware.cors import CORSMiddleware

# Serve the contents of ./static at the root path. This mount is deliberately
# placed after the other route definitions in this module.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

# Permissive CORS policy: any origin, any method (including POST), any header.
# NOTE(review): wildcard origins together with allow_credentials=True is
# worth confirming against the FastAPI/Starlette CORS documentation; narrow
# allow_origins to specific domains if credentials are really needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)