#!/usr/bin/env python3
"""
Example usage script for Pre-trained-v2 models
Demonstrates how to load and use the physics-based 3D object deformation models
"""

import json
import os
import sys
from typing import Dict, List, Tuple

import numpy as np
import torch
import trimesh


class PhysicsDeformationModel:
    """Wrapper class for loading and using the pre-trained deformation models.

    Expects three files in ``model_dir``:
    ``{model_name}-encoder.pt`` and ``{model_name}-decoder.pt`` (PyTorch
    state dicts) and ``{model_name}.obj`` (the reference mesh).
    """

    # Number of condition features produced by prepare_input_conditions:
    # impact point (3) + velocity (3) + force magnitude (1).
    CONDITION_DIM = 7

    # Normalization applied to the raw conditions. These values are
    # placeholders and must match the statistics used at training time --
    # NOTE(review): confirm against the training pipeline.
    IMPACT_POINT_OFFSET = (0.0, 0.5, 0.0)
    IMPACT_POINT_SCALE = 0.5
    VELOCITY_SCALE = 10.0
    FORCE_SCALE = 1000.0

    def __init__(self, model_dir: str, model_name: str):
        """
        Initialize the model

        Args:
            model_dir: Directory containing the model files
            model_name: Name of the model (e.g., 'base', 'pot')

        Raises:
            FileNotFoundError: If the encoder, decoder or mesh file is missing.
        """
        self.model_dir = model_dir
        self.model_name = model_name

        # Resolve the expected model file locations
        self.encoder_path = os.path.join(model_dir, f"{model_name}-encoder.pt")
        self.decoder_path = os.path.join(model_dir, f"{model_name}-decoder.pt")
        self.mesh_path = os.path.join(model_dir, f"{model_name}.obj")

        # Report exactly which files are missing instead of a generic message
        missing = [
            path
            for path in (self.encoder_path, self.decoder_path, self.mesh_path)
            if not os.path.exists(path)
        ]
        if missing:
            raise FileNotFoundError(
                f"Model files not found in {model_dir}: {', '.join(missing)}"
            )

        # Load reference mesh. force="mesh" makes trimesh return a single
        # Trimesh even for OBJ files with several objects/groups; a plain
        # load may return a Scene, which has no .vertices to replace.
        # NOTE(review): trimesh processes meshes on load by default (e.g.
        # merging duplicate vertices); if the model was trained on the raw
        # OBJ vertex order, process=False may be required -- confirm.
        self.reference_mesh = trimesh.load(self.mesh_path, force="mesh")

        # Initialize encoder and decoder (you'll need to implement these based on your architecture)
        self.encoder = self._load_encoder()
        self.decoder = self._load_decoder()

    @staticmethod
    def _load_weights(module: torch.nn.Module, path: str) -> torch.nn.Module:
        """Load a state dict from ``path`` into ``module`` on CPU and set eval mode."""
        # The checkpoints are plain state dicts, so weights_only=True (PyTorch
        # >= 1.13) is sufficient and prevents a tampered .pt file from
        # executing arbitrary code while being unpickled.
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        module.load_state_dict(state_dict)
        module.eval()
        return module

    def _load_encoder(self) -> torch.nn.Module:
        """Load the encoder model.

        Placeholder MLP mapping the CONDITION_DIM condition features to a
        64-dimensional latent; it must match the architecture of the
        checkpoint at ``self.encoder_path``.
        """
        encoder = torch.nn.Sequential(
            # Input width must equal the number of features built by
            # prepare_input_conditions, otherwise every forward pass fails.
            torch.nn.Linear(self.CONDITION_DIM, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
        )
        return self._load_weights(encoder, self.encoder_path)

    def _load_decoder(self) -> torch.nn.Module:
        """Load the decoder model.

        Placeholder MLP mapping the 64-dimensional latent to vertex
        coordinates; it must match the checkpoint at ``self.decoder_path``.
        """
        # NOTE(review): a 3-wide output yields a single (x, y, z) per forward
        # pass; decoding a whole mesh would presumably need
        # 3 * num_vertices outputs -- confirm against the real checkpoint.
        decoder = torch.nn.Sequential(
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 3),
        )
        return self._load_weights(decoder, self.decoder_path)

    def prepare_input_conditions(self, impact_point: List[float], velocity: List[float],
                                 force: float) -> torch.Tensor:
        """
        Prepare input conditions for the model

        Args:
            impact_point: [x, y, z] coordinates of impact point (any 3-element sequence/array)
            velocity: [vx, vy, vz] velocity vector (any 3-element sequence/array)
            force: Impact force magnitude (scalar)

        Returns:
            Float32 tensor of shape (1, CONDITION_DIM) for the encoder

        Raises:
            ValueError: If impact_point or velocity does not have 3 components.
        """
        # Convert explicitly rather than relying on list "+": that breaks on
        # tuples and silently adds element-wise for numpy arrays.
        point = np.asarray(impact_point, dtype=np.float32).reshape(-1)
        vel = np.asarray(velocity, dtype=np.float32).reshape(-1)
        if point.size != 3:
            raise ValueError(f"impact_point must have 3 components, got {point.size}")
        if vel.size != 3:
            raise ValueError(f"velocity must have 3 components, got {vel.size}")

        input_data = np.concatenate(
            [point, vel, np.asarray([float(force)], dtype=np.float32)]
        )

        # Normalize to match training data distribution
        input_data[:3] = (input_data[:3] - np.asarray(self.IMPACT_POINT_OFFSET)) / self.IMPACT_POINT_SCALE
        input_data[3:6] = input_data[3:6] / self.VELOCITY_SCALE
        input_data[6] = input_data[6] / self.FORCE_SCALE

        return torch.tensor(input_data, dtype=torch.float32).unsqueeze(0)

    def predict_deformation(self, impact_point: List[float], velocity: List[float],
                            force: float) -> np.ndarray:
        """
        Predict object deformation given impact conditions

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            Predicted vertex positions as an array of shape (K, 3), where K is
            the decoder output width divided by 3.
        """
        input_tensor = self.prepare_input_conditions(impact_point, velocity, force)

        # Inference only: no autograd graph needed
        with torch.no_grad():
            latent = self.encoder(input_tensor)
            deformation = self.decoder(latent)

        # Drop the batch dimension and view the flat output as (x, y, z) rows
        return deformation.reshape(-1, 3).cpu().numpy()

    def apply_deformation_to_mesh(self, impact_point: List[float], velocity: List[float],
                                  force: float) -> trimesh.Trimesh:
        """
        Apply deformation to the reference mesh

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            A copy of the reference mesh with its vertices replaced by the prediction

        Raises:
            ValueError: If the number of predicted vertices differs from the
                reference mesh (the faces would index missing vertices).
        """
        deformed_vertices = self.predict_deformation(impact_point, velocity, force)

        expected = len(self.reference_mesh.vertices)
        if deformed_vertices.shape != (expected, 3):
            raise ValueError(
                f"Model predicted {len(deformed_vertices)} vertices but the reference "
                f"mesh has {expected}; the decoder output does not match the mesh"
            )

        # Keep the reference topology (faces) and swap in the new positions
        deformed_mesh = self.reference_mesh.copy()
        deformed_mesh.vertices = deformed_vertices
        return deformed_mesh

    def save_deformed_mesh(self, output_path: str, impact_point: List[float],
                           velocity: List[float], force: float):
        """
        Save deformed mesh to file

        Args:
            output_path: Path to save the deformed mesh (format inferred by trimesh from the extension)
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude
        """
        deformed_mesh = self.apply_deformation_to_mesh(impact_point, velocity, force)
        deformed_mesh.export(output_path)


def main() -> int:
    """Example usage of the PhysicsDeformationModel.

    Returns:
        Process exit status: 0 on success, 1 if any step failed.
    """
    # Example parameters
    model_dir = "base"  # Change to your model directory
    model_name = "base"

    # Impact conditions
    impact_point = [0.1, 0.8, 0.1]  # [x, y, z]
    velocity = [0.0, -5.0, 0.0]  # [vx, vy, vz]
    force = 500.0  # Force magnitude

    try:
        # Initialize model
        print(f"Loading {model_name} model...")
        model = PhysicsDeformationModel(model_dir, model_name)
        print("Model loaded successfully!")

        # Predict deformation
        print("Predicting deformation...")
        deformed_vertices = model.predict_deformation(impact_point, velocity, force)
        print(f"Deformation predicted. Output shape: {deformed_vertices.shape}")

        # Save deformed mesh (validates that the prediction matches the mesh)
        output_path = f"deformed_{model_name}.obj"
        model.save_deformed_mesh(output_path, impact_point, velocity, force)
        print(f"Deformed mesh saved to: {output_path}")

        # Display some statistics
        original_vertices = model.reference_mesh.vertices
        deformation_magnitude = np.linalg.norm(deformed_vertices - original_vertices, axis=1)
        print(f"Average deformation magnitude: {np.mean(deformation_magnitude):.4f}")
        print(f"Maximum deformation magnitude: {np.max(deformation_magnitude):.4f}")
    except Exception as e:
        # Top-level boundary of the demo: report on stderr and signal failure
        # through the exit status instead of exiting 0.
        print(f"Error: {e}", file=sys.stderr)
        print("Make sure you have the correct model files and dependencies installed.", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())