| | |
| | """ |
| | Example usage script for Pre-trained-v2 models |
| | Demonstrates how to load and use the physics-based 3D object deformation models |
| | """ |
| |
|
| | import torch |
| | import numpy as np |
| | import trimesh |
| | import json |
| | import os |
| | from typing import Dict, List, Tuple |
| |
|
class PhysicsDeformationModel:
    """Wrapper class for loading and using the pre-trained deformation models.

    A model lives in ``model_dir`` as three files sharing the ``model_name``
    prefix: ``<name>-encoder.pt`` and ``<name>-decoder.pt`` (state dicts for the
    MLPs built by ``_build_mlp``) and ``<name>.obj`` (the reference mesh).
    """

    # Layer widths of the encoder/decoder MLPs. They must match the checkpoint
    # tensors exactly, because load_state_dict is strict by default.
    # NOTE(review): the encoder takes 9 input features, but
    # prepare_input_conditions produces only 7 (3 impact + 3 velocity + 1 force);
    # confirm the intended input layout against the training code.
    ENCODER_LAYER_SIZES = (9, 512, 256, 128, 64)
    # NOTE(review): the decoder emits 3 values per input row, i.e. a single
    # vertex per impact; confirm how its output is meant to cover a whole mesh.
    DECODER_LAYER_SIZES = (64, 128, 256, 512, 3)

    # Input normalisation applied by prepare_input_conditions.
    # NOTE(review): presumably these mirror training-time preprocessing — verify.
    _IMPACT_CENTER = (0.0, 0.5, 0.0)
    _IMPACT_SCALE = 0.5
    _VELOCITY_SCALE = 10.0
    _FORCE_SCALE = 1000.0

    def __init__(self, model_dir: str, model_name: str):
        """
        Initialize the model.

        Args:
            model_dir: Directory containing the model files
            model_name: Name of the model (e.g., 'base', 'pot'); used as the
                prefix of every model file name

        Raises:
            FileNotFoundError: If any of the encoder, decoder or mesh files is
                missing; the message lists every missing path.
        """
        self.model_dir = model_dir
        self.model_name = model_name

        self.encoder_path = os.path.join(model_dir, f"{model_name}-encoder.pt")
        self.decoder_path = os.path.join(model_dir, f"{model_name}-decoder.pt")
        self.mesh_path = os.path.join(model_dir, f"{model_name}.obj")

        missing = [
            path
            for path in (self.encoder_path, self.decoder_path, self.mesh_path)
            if not os.path.exists(path)
        ]
        if missing:
            raise FileNotFoundError(
                f"Model files not found in {model_dir}: {', '.join(missing)}"
            )

        # force="mesh" asks trimesh for a single Trimesh even when the file
        # would otherwise load as a Scene; the methods below treat
        # reference_mesh as one mesh (they read and replace its `.vertices`).
        # NOTE(review): trimesh processes meshes on load by default (e.g. it
        # merges duplicate vertices), which can change vertex count/order
        # relative to the .obj file; if the decoder predicts vertices in file
        # order, consider passing process=False.
        self.reference_mesh = trimesh.load(self.mesh_path, force="mesh")

        self.encoder = self._load_encoder()
        self.decoder = self._load_decoder()

    @staticmethod
    def _build_mlp(layer_sizes):
        """Build Linear layers of the given widths with a ReLU between each pair (none after the last)."""
        layers = []
        for index, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            if index:
                layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Linear(n_in, n_out))
        # Resulting indices (Linear at 0, 2, 4, ...) give state-dict keys such
        # as "0.weight", "2.weight", which the checkpoints must use.
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _load_weights(module, path):
        """Load the state dict at ``path`` into ``module`` on CPU and put it in eval mode."""
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute arbitrary code
        # on load (available since PyTorch 1.13, the default from 2.6).
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        module.load_state_dict(state_dict)
        module.eval()
        return module

    def _load_encoder(self):
        """Load the encoder model (ENCODER_LAYER_SIZES) from ``self.encoder_path``."""
        return self._load_weights(self._build_mlp(self.ENCODER_LAYER_SIZES), self.encoder_path)

    def _load_decoder(self):
        """Load the decoder model (DECODER_LAYER_SIZES) from ``self.decoder_path``."""
        return self._load_weights(self._build_mlp(self.DECODER_LAYER_SIZES), self.decoder_path)

    @staticmethod
    def _as_vector3(values, name):
        """Return ``values`` as a flat float32 array, checking it has exactly 3 entries."""
        vector = np.asarray(values, dtype=np.float32).ravel()
        if vector.size != 3:
            raise ValueError(f"{name} must have exactly 3 components, got {vector.size}")
        return vector

    def prepare_input_conditions(self, impact_point: List[float],
                                 velocity: List[float],
                                 force: float) -> torch.Tensor:
        """
        Prepare input conditions for the model.

        Args:
            impact_point: [x, y, z] coordinates of impact point (list, tuple or array)
            velocity: [vx, vy, vz] velocity vector (list, tuple or array)
            force: Impact force magnitude

        Returns:
            Float32 tensor of shape (1, 7): normalised impact point, velocity
            and force, with a leading batch dimension.

        Raises:
            ValueError: If impact_point or velocity does not have 3 components.
        """
        # Concatenate explicitly: `impact_point + velocity + [force]` only
        # concatenates Python lists — NumPy arrays would be added element-wise.
        input_data = np.concatenate([
            self._as_vector3(impact_point, "impact_point"),
            self._as_vector3(velocity, "velocity"),
            np.asarray(force, dtype=np.float32).reshape(1),
        ])

        input_data[:3] = (input_data[:3] - np.array(self._IMPACT_CENTER)) / self._IMPACT_SCALE
        input_data[3:6] = input_data[3:6] / self._VELOCITY_SCALE
        input_data[6] = input_data[6] / self._FORCE_SCALE

        return torch.tensor(input_data, dtype=torch.float32).unsqueeze(0)

    def _encoder_input_size(self):
        """Return ``in_features`` of the encoder's first Linear layer, or None if it has none."""
        for module in self.encoder.modules():
            if isinstance(module, torch.nn.Linear):
                return module.in_features
        return None

    def predict_deformation(self, impact_point: List[float],
                            velocity: List[float],
                            force: float) -> np.ndarray:
        """
        Predict object deformation given impact conditions.

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            Predicted vertex positions as a NumPy array of shape (N, 3).

        Raises:
            ValueError: If the number of prepared input features does not
                match the encoder's input width.
        """
        input_tensor = self.prepare_input_conditions(impact_point, velocity, force)

        # Report a feature-count mismatch clearly instead of letting it surface
        # as a matrix-multiplication shape error inside torch.
        expected = self._encoder_input_size()
        provided = input_tensor.shape[-1]
        if expected is not None and expected != provided:
            raise ValueError(
                f"Encoder expects {expected} input features, "
                f"but {provided} were prepared from the impact conditions"
            )

        with torch.no_grad():
            latent = self.encoder(input_tensor)
            deformation = self.decoder(latent)

        # Drop the batch dimension and regroup the outputs as (x, y, z) rows.
        return deformation.squeeze().numpy().reshape(-1, 3)

    def apply_deformation_to_mesh(self, impact_point: List[float],
                                  velocity: List[float],
                                  force: float) -> "trimesh.Trimesh":
        """
        Apply deformation to a copy of the reference mesh.

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            A copy of the reference mesh whose vertices are the predicted positions.

        Raises:
            ValueError: If the prediction does not provide exactly one position
                per reference vertex.
        """
        deformed_vertices = self.predict_deformation(impact_point, velocity, force)

        # The copied faces index into the vertex array, so a prediction with a
        # different vertex count would yield a corrupt mesh.
        expected_shape = np.shape(self.reference_mesh.vertices)
        if deformed_vertices.shape != expected_shape:
            raise ValueError(
                f"Predicted vertices have shape {deformed_vertices.shape}, "
                f"but the reference mesh has {expected_shape}"
            )

        deformed_mesh = self.reference_mesh.copy()
        deformed_mesh.vertices = deformed_vertices

        return deformed_mesh

    def save_deformed_mesh(self, output_path: str, impact_point: List[float],
                           velocity: List[float], force: float):
        """
        Save deformed mesh to file.

        Args:
            output_path: Path to save the deformed mesh (format chosen by trimesh
                from the file extension)
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude
        """
        deformed_mesh = self.apply_deformation_to_mesh(impact_point, velocity, force)
        deformed_mesh.export(output_path)
| |
|
def main(model_dir: str = "base", model_name: str = "base") -> int:
    """Example usage of the PhysicsDeformationModel.

    Loads a model, predicts the deformation for one example impact, saves the
    deformed mesh as ``deformed_<model_name>.obj`` and prints simple statistics.

    Args:
        model_dir: Directory containing the model files.
        model_name: Model file-name prefix (e.g. 'base', 'pot').

    Returns:
        Process exit status: 0 on success, 1 if any step failed.
    """
    # Example impact conditions.
    impact_point = [0.1, 0.8, 0.1]
    velocity = [0.0, -5.0, 0.0]
    force = 500.0

    # Broad catch is deliberate: this is the script's top-level boundary, and
    # the failure is reported through the returned exit status.
    try:
        print(f"Loading {model_name} model...")
        model = PhysicsDeformationModel(model_dir, model_name)
        print("Model loaded successfully!")

        print("Predicting deformation...")
        # Run the network once; the same mesh is exported and measured below.
        deformed_mesh = model.apply_deformation_to_mesh(impact_point, velocity, force)
        deformed_vertices = np.asarray(deformed_mesh.vertices)
        print(f"Deformation predicted. Output shape: {deformed_vertices.shape}")

        # Check before exporting so a mismatched prediction never produces a
        # corrupt file or a silently broadcast statistic.
        original_vertices = np.asarray(model.reference_mesh.vertices)
        if deformed_vertices.shape != original_vertices.shape:
            raise ValueError(
                f"Predicted vertices have shape {deformed_vertices.shape}, "
                f"but the reference mesh has {original_vertices.shape}"
            )

        output_path = f"deformed_{model_name}.obj"
        deformed_mesh.export(output_path)
        print(f"Deformed mesh saved to: {output_path}")

        deformation_magnitude = np.linalg.norm(deformed_vertices - original_vertices, axis=1)
        print(f"Average deformation magnitude: {np.mean(deformation_magnitude):.4f}")
        print(f"Maximum deformation magnitude: {np.max(deformation_magnitude):.4f}")

    except Exception as e:
        print(f"Error: {e}")
        print("Make sure you have the correct model files and dependencies installed.")
        return 1

    return 0
| |
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None means 0),
    # so shell callers and CI can detect failures.
    raise SystemExit(main())
| |
|