File size: 7,738 Bytes
b152a33 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 | #!/usr/bin/env python3
"""
Example usage script for Pre-trained-v2 models
Demonstrates how to load and use the physics-based 3D object deformation models
"""
import torch
import numpy as np
import trimesh
import json
import os
from typing import Dict, List, Tuple
class PhysicsDeformationModel:
    """Wrapper class for loading and using the pre-trained deformation models.

    Expects three files in ``model_dir``: ``<model_name>-encoder.pt`` and
    ``<model_name>-decoder.pt`` (state dicts loaded into placeholder
    ``torch.nn.Sequential`` networks) and ``<model_name>.obj`` (the reference
    mesh). Impact conditions are normalized, passed through the encoder and
    decoder, and the decoder output is used as vertex positions for the
    reference mesh.
    """

    # Normalization constants used by prepare_input_conditions. These are the
    # example's placeholder values; adjust them to match your training data.
    IMPACT_POINT_CENTER = (0.0, 0.5, 0.0)
    IMPACT_POINT_SCALE = 0.5
    VELOCITY_SCALE = 10.0
    FORCE_SCALE = 1000.0

    def __init__(self, model_dir: str, model_name: str):
        """
        Initialize the model.

        Args:
            model_dir: Directory containing the model files
            model_name: Name of the model (e.g., 'base', 'pot')

        Raises:
            FileNotFoundError: If the encoder, decoder or mesh file is
                missing from ``model_dir`` (the message lists which ones).
        """
        self.model_dir = model_dir
        self.model_name = model_name

        # Expected model files
        self.encoder_path = os.path.join(model_dir, f"{model_name}-encoder.pt")
        self.decoder_path = os.path.join(model_dir, f"{model_name}-decoder.pt")
        self.mesh_path = os.path.join(model_dir, f"{model_name}.obj")

        # Report exactly which files are missing rather than a generic message.
        missing = [
            path
            for path in (self.encoder_path, self.decoder_path, self.mesh_path)
            if not os.path.exists(path)
        ]
        if missing:
            raise FileNotFoundError(
                f"Model files not found in {model_dir}: {', '.join(missing)}"
            )

        # force="mesh" guarantees a single Trimesh. Without it, a multi-object
        # OBJ loads as a Scene, which has no settable ``vertices`` and would
        # break apply_deformation_to_mesh.
        self.reference_mesh = trimesh.load(self.mesh_path, force="mesh")

        # Placeholder architectures; replace them with your real networks.
        self.encoder = self._load_encoder()
        self.decoder = self._load_decoder()

    @staticmethod
    def _load_weights(module: torch.nn.Module, path: str) -> torch.nn.Module:
        """Load the state dict at ``path`` into ``module``, set eval mode, return it."""
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # from trusted sources, or pass weights_only=True if your torch
        # version supports it.
        module.load_state_dict(torch.load(path, map_location='cpu'))
        module.eval()
        return module

    def _load_encoder(self):
        """Build the placeholder encoder and load its pre-trained weights."""
        # This is a placeholder - implement based on your actual encoder architecture.
        # NOTE(review): the first layer takes 9 input features, but
        # prepare_input_conditions produces 7 (3 impact point + 3 velocity +
        # 1 force). As written, the encoder call in predict_deformation fails
        # with a shape-mismatch error. Reconcile the two with your real model.
        encoder = torch.nn.Sequential(
            torch.nn.Linear(9, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
        )
        return self._load_weights(encoder, self.encoder_path)

    def _load_decoder(self):
        """Build the placeholder decoder and load its pre-trained weights."""
        # This is a placeholder - implement based on your actual decoder architecture.
        # NOTE(review): the final layer outputs 3 values, i.e. a single vertex.
        # A decoder that predicts a whole mesh needs 3 * num_vertices outputs.
        decoder = torch.nn.Sequential(
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 3),
        )
        return self._load_weights(decoder, self.decoder_path)

    def prepare_input_conditions(self, impact_point: List[float],
                                 velocity: List[float],
                                 force: float) -> torch.Tensor:
        """
        Prepare input conditions for the model.

        Args:
            impact_point: [x, y, z] coordinates of impact point (any
                length-3 sequence or 1-D array)
            velocity: [vx, vy, vz] velocity vector (any length-3 sequence or
                1-D array)
            force: Impact force magnitude

        Returns:
            Float32 tensor of shape (1, 7). It holds the normalized impact
            point, velocity and force, with a leading batch dimension.

        Raises:
            ValueError: If impact_point or velocity does not have exactly
                three components.
        """
        # Convert explicitly instead of using list concatenation. Tuples
        # would raise TypeError, and NumPy arrays would be added element-wise.
        point = np.asarray(impact_point, dtype=np.float32)
        vel = np.asarray(velocity, dtype=np.float32)
        if point.shape != (3,) or vel.shape != (3,):
            raise ValueError(
                "impact_point and velocity must each have 3 components, got "
                f"shapes {point.shape} and {vel.shape}"
            )
        input_data = np.concatenate(
            [point, vel, np.asarray([force], dtype=np.float32)]
        )

        # Normalize to match the training data distribution. These constants
        # are placeholders; adjust the class attributes for your training data.
        input_data[:3] = (
            input_data[:3] - np.asarray(self.IMPACT_POINT_CENTER)
        ) / self.IMPACT_POINT_SCALE                                # impact point
        input_data[3:6] = input_data[3:6] / self.VELOCITY_SCALE   # velocity
        input_data[6] = input_data[6] / self.FORCE_SCALE          # force
        return torch.tensor(input_data, dtype=torch.float32).unsqueeze(0)

    def predict_deformation(self, impact_point: List[float],
                            velocity: List[float],
                            force: float) -> np.ndarray:
        """
        Predict object deformation given impact conditions.

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            Array of shape (N, 3) with the decoder output reshaped into
            per-vertex xyz rows. With the placeholder decoder, N is 1.
        """
        input_tensor = self.prepare_input_conditions(impact_point, velocity, force)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            latent = self.encoder(input_tensor)
            deformation = self.decoder(latent)

        # Drop the batch dimension and lay the flat output out as xyz rows.
        return deformation.squeeze().numpy().reshape(-1, 3)

    def apply_deformation_to_mesh(self, impact_point: List[float],
                                  velocity: List[float],
                                  force: float) -> trimesh.Trimesh:
        """
        Apply deformation to a copy of the reference mesh.

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            A copy of the reference mesh whose vertices are replaced by the
            predicted positions. The reference mesh itself is not modified.

        Raises:
            ValueError: If the number of predicted vertices differs from the
                reference mesh's. Assigning them anyway would leave faces
                indexing vertices that no longer exist.
        """
        deformed_vertices = self.predict_deformation(impact_point, velocity, force)

        expected_shape = self.reference_mesh.vertices.shape
        if deformed_vertices.shape != expected_shape:
            raise ValueError(
                f"Predicted {deformed_vertices.shape[0]} vertices but the reference "
                f"mesh has {expected_shape[0]}; the decoder output does not match "
                "the mesh topology."
            )

        deformed_mesh = self.reference_mesh.copy()
        deformed_mesh.vertices = deformed_vertices
        return deformed_mesh

    def save_deformed_mesh(self, output_path: str, impact_point: List[float],
                           velocity: List[float], force: float):
        """
        Predict the deformed mesh and export it to ``output_path``.

        Args:
            output_path: Path to save the deformed mesh (trimesh infers the
                format from the extension)
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude
        """
        deformed_mesh = self.apply_deformation_to_mesh(impact_point, velocity, force)
        deformed_mesh.export(output_path)
def main():
    """Run an end-to-end example of PhysicsDeformationModel.

    Loads the model, predicts the deformation for a fixed set of impact
    conditions, saves the deformed mesh as ``deformed_<model_name>.obj`` and
    prints simple deformation statistics.

    Returns:
        0 on success, or 1 if any step failed (the error is printed to stderr).
    """
    import sys

    # Example parameters
    model_dir = "base"  # Change to your model directory
    model_name = "base"

    # Impact conditions
    impact_point = [0.1, 0.8, 0.1]  # [x, y, z]
    velocity = [0.0, -5.0, 0.0]     # [vx, vy, vz]
    force = 500.0                   # Force magnitude

    try:
        print(f"Loading {model_name} model...")
        model = PhysicsDeformationModel(model_dir, model_name)
        print("Model loaded successfully!")

        print("Predicting deformation...")
        deformed_vertices = model.predict_deformation(impact_point, velocity, force)
        print(f"Deformation predicted. Output shape: {deformed_vertices.shape}")

        output_path = f"deformed_{model_name}.obj"
        model.save_deformed_mesh(output_path, impact_point, velocity, force)
        print(f"Deformed mesh saved to: {output_path}")

        original_vertices = model.reference_mesh.vertices
        # Compare per vertex only when the shapes agree. Otherwise NumPy would
        # broadcast, e.g. a (1, 3) prediction against (V, 3), and report
        # meaningless statistics.
        if deformed_vertices.shape != original_vertices.shape:
            print(
                "Skipping deformation statistics: predicted shape "
                f"{deformed_vertices.shape} != reference shape {original_vertices.shape}"
            )
        else:
            deformation_magnitude = np.linalg.norm(
                deformed_vertices - original_vertices, axis=1
            )
            print(f"Average deformation magnitude: {np.mean(deformation_magnitude):.4f}")
            print(f"Maximum deformation magnitude: {np.max(deformation_magnitude):.4f}")
    except Exception as e:
        # Top-level boundary of the example: report the error, then signal
        # failure to the caller instead of exiting as if everything succeeded.
        print(f"Error: {e}", file=sys.stderr)
        print("Make sure you have the correct model files and dependencies installed.",
              file=sys.stderr)
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None maps to 0).
    raise SystemExit(main())
|