Spaces:
Runtime error
Runtime error
| import asyncio | |
| import logging | |
| from abc import ABC, abstractmethod | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from .joint_config import JointConfig | |
# Module-level logger; the inference engine below logs device selection,
# unexpected camera names and cleanup through it.
logger = logging.getLogger(__name__)
class BaseInferenceEngine(ABC):
    """
    Base class for all inference engines.

    This class provides common functionality for:
    - Image preprocessing and normalization
    - Joint data handling and validation
    - Model loading and management
    - Action prediction interface

    Subclasses must implement the async methods ``load_policy`` and ``predict``;
    the class cannot be instantiated until both are overridden.
    """

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        device: str | None = None,
    ):
        """
        Args:
            policy_path: Location of the policy. Stored on the instance and used in
                log messages; actual loading is left to ``load_policy``.
            camera_names: Cameras whose images ``preprocess_images`` accepts; images
                from any other camera are skipped with a warning.
            device: Torch device string. When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
        """
        self.policy_path = policy_path
        self.camera_names = camera_names

        # Device selection: honour an explicit device, otherwise prefer CUDA.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info("Using device: %s", self.device)

        # Model and preprocessing
        self.policy = None
        self.image_transforms = {}  # {camera_name: transform}
        self.stats = None  # Dataset statistics for normalization

        # State tracking
        self.is_loaded = False
        self.last_images = {}
        self.last_joint_positions = None

    @abstractmethod
    async def load_policy(self):
        """Load the policy model. Must be implemented by subclasses."""

    @abstractmethod
    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """
        Run inference. Must be implemented by subclasses.

        Args:
            images: Dictionary of {camera_name: image_array}.
            joint_positions: Current joint positions.
            **kwargs: Engine-specific options.

        Returns:
            The predicted action array (presumably in the joint order expected by
            ``get_joint_commands_with_names`` — confirm per subclass).
        """

    def preprocess_images(
        self, images: dict[str, np.ndarray]
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess images for inference.

        Images from cameras not listed in ``self.camera_names`` are skipped with a
        warning. Non-``uint8`` arrays are assumed to hold values in ``[0, 1]``.
        They are scaled to ``[0, 255]`` and clipped before the ``uint8`` cast, so
        out-of-range values saturate instead of wrapping around.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}. Values that are
                not numpy arrays (e.g. PIL images) are passed to the transform
                unchanged.

        Returns:
            Dictionary of {camera_name: preprocessed_tensor}, each tensor moved to
            ``self.device``.
        """
        processed_images = {}
        for camera_name, image in images.items():
            if camera_name not in self.camera_names:
                logger.warning("Unexpected camera: %s", camera_name)
                continue

            # Convert numpy array to PIL Image if needed
            if isinstance(image, np.ndarray):
                if image.dtype != np.uint8:
                    # Multiply by a float so integer inputs cannot overflow, then
                    # clip: casting out-of-range floats to uint8 is
                    # platform-dependent in NumPy.
                    image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
                pil_image = Image.fromarray(image)
            else:
                pil_image = image

            # Apply a camera-specific transform if one is registered
            if camera_name in self.image_transforms:
                tensor = self.image_transforms[camera_name](pil_image)
            else:
                # Default preprocessing: resize to 224x224 and scale to [0, 1]
                tensor = self._default_image_transform(pil_image)

            processed_images[camera_name] = tensor.to(self.device)
        return processed_images

    def _default_image_transform(self, image: Image.Image) -> torch.Tensor:
        """
        Default image preprocessing.

        Resizes to 224x224 with Lanczos resampling and scales pixel values to
        ``[0, 1]`` as float32. Multi-channel (HWC) images are returned as CHW.
        """
        # Resize to 224x224 (common size for vision models)
        image = image.resize((224, 224), Image.Resampling.LANCZOS)

        # Convert to tensor and normalize to [0, 1]
        tensor = torch.from_numpy(np.array(image)).float() / 255.0

        # Rearrange from HWC to CHW.
        # NOTE(review): single-channel images stay 2-D (HW) with no channel
        # axis — confirm downstream models accept that.
        if len(tensor.shape) == 3:
            tensor = tensor.permute(2, 0, 1)

        return tensor

    def preprocess_joint_positions(self, joint_positions: np.ndarray) -> torch.Tensor:
        """
        Preprocess joint positions for inference.

        Args:
            joint_positions: Array of joint positions in standard order

        Returns:
            Float32 joint tensor on ``self.device``, passed through
            ``_normalize_joints`` when ``self.stats`` has ``joint_stats``.
        """
        # Validate and clamp joint values
        joint_positions = JointConfig.validate_joint_values(joint_positions)

        # Convert to tensor
        joint_tensor = torch.from_numpy(joint_positions).float().to(self.device)

        # Normalize if we have dataset statistics. Use an explicit None check:
        # truth-testing array-like stats would raise.
        if self.stats is not None and hasattr(self.stats, "joint_stats"):
            joint_tensor = self._normalize_joints(joint_tensor)

        return joint_tensor

    def _normalize_joints(self, joint_tensor: torch.Tensor) -> torch.Tensor:
        """Normalize joint values using dataset statistics (currently a no-op)."""
        # This would use the actual dataset statistics
        # For now, we assume joints are already normalized
        return joint_tensor

    def get_joint_commands_with_names(self, action: np.ndarray) -> list[dict]:
        """
        Convert action array to joint commands with names.

        Args:
            action: Array of joint actions in standard order

        Returns:
            List of joint command dictionaries built by
            ``JointConfig.create_joint_commands``.
        """
        # Validate action values
        action = JointConfig.validate_joint_values(action)

        # Create commands with AI names (always use AI names for output)
        return JointConfig.create_joint_commands(action)

    def reset(self):
        """Reset the inference engine state (cached inputs and policy state)."""
        self.last_images = {}
        self.last_joint_positions = None

        # Clear any model-specific state
        if hasattr(self.policy, "reset"):
            self.policy.reset()

    def _release_policy(self) -> None:
        """Synchronously drop the policy, free cached CUDA memory, mark unloaded."""
        self.policy = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.is_loaded = False

    async def cleanup(self):
        """Clean up resources: release the policy and cached GPU memory."""
        self._release_policy()
        logger.info("Cleaned up inference engine for %s", self.policy_path)

    def __del__(self):
        """Best-effort release of the policy when the engine is garbage-collected."""
        # Do not schedule the async cleanup() here: asyncio.create_task needs a
        # running event loop (usually absent during GC), and the work involved
        # is synchronous anyway.
        if getattr(self, "policy", None) is None:
            return
        try:
            self._release_policy()
        except Exception:
            # Destructors can run during interpreter shutdown, when module
            # globals such as ``torch`` may already be torn down; an exception
            # here could only be printed and ignored, so swallow it deliberately.
            pass