File size: 7,738 Bytes
b152a33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""
Example usage script for Pre-trained-v2 models
Demonstrates how to load and use the physics-based 3D object deformation models
"""

import torch
import numpy as np
import trimesh
import json
import os
from typing import Dict, List, Tuple

class PhysicsDeformationModel:
    """Wrapper class for loading and using the pre-trained deformation models.

    An instance locates three files inside ``model_dir``
    (``<model_name>-encoder.pt``, ``<model_name>-decoder.pt`` and
    ``<model_name>.obj``), loads the reference mesh, and builds an
    encoder/decoder pair whose weights are read from the two ``.pt`` files.

    The normalization constants and layer widths are class attributes so a
    subclass (or an instance) can override them without editing the methods.
    """

    # Layer widths of the placeholder MLPs (Linear layers separated by ReLU).
    # Replace these with the widths of your actual architecture.
    # NOTE(review): the encoder's first layer takes 9 features, while
    # prepare_input_conditions() produces 7 (3 + 3 + 1). Likewise the decoder
    # emits only 3 values, i.e. a single vertex. Reconcile both with the real
    # architecture before relying on predict_deformation().
    ENCODER_LAYER_SIZES = (9, 512, 256, 128, 64)
    DECODER_LAYER_SIZES = (64, 128, 256, 512, 3)

    # Input normalization parameters. Adjust them to match the distribution
    # of the data the checkpoint was trained on.
    IMPACT_POINT_OFFSET = (0.0, 0.5, 0.0)
    IMPACT_POINT_SCALE = 0.5
    VELOCITY_SCALE = 10.0
    FORCE_SCALE = 1000.0

    def __init__(self, model_dir: str, model_name: str):
        """
        Initialize the model

        Args:
            model_dir: Directory containing the model files
            model_name: Name of the model (e.g., 'base', 'pot')

        Raises:
            FileNotFoundError: If the encoder, decoder or mesh file is missing.
                The message lists every missing path.
        """
        self.model_dir = model_dir
        self.model_name = model_name

        # Resolve model file locations
        self.encoder_path = os.path.join(model_dir, f"{model_name}-encoder.pt")
        self.decoder_path = os.path.join(model_dir, f"{model_name}-decoder.pt")
        self.mesh_path = os.path.join(model_dir, f"{model_name}.obj")

        # Fail early, and say exactly which files are absent
        missing = [
            path
            for path in (self.encoder_path, self.decoder_path, self.mesh_path)
            if not os.path.exists(path)
        ]
        if missing:
            raise FileNotFoundError(
                f"Model files not found in {model_dir}: {', '.join(missing)}"
            )

        # Load reference mesh. force="mesh" asks trimesh for a single mesh
        # even when the OBJ would otherwise load as a multi-geometry Scene,
        # which has no usable `.vertices` attribute.
        self.reference_mesh = trimesh.load(self.mesh_path, force="mesh")

        # Initialize encoder and decoder (you'll need to implement these based on your architecture)
        self.encoder = self._load_encoder()
        self.decoder = self._load_decoder()

    @staticmethod
    def _build_mlp(layer_sizes):
        """Build a Linear/ReLU stack with no activation after the last layer.

        Args:
            layer_sizes: Sequence of widths; consecutive pairs become the
                (in_features, out_features) of each Linear layer.

        Returns:
            A ``torch.nn.Sequential`` whose Linear layers sit at the even
            indices (0, 2, 4, ...), so state-dict keys are '0.weight',
            '2.weight', and so on.
        """
        layers = []
        last_index = len(layer_sizes) - 2
        for index, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            layers.append(torch.nn.Linear(n_in, n_out))
            if index < last_index:
                layers.append(torch.nn.ReLU())
        return torch.nn.Sequential(*layers)

    def _load_network(self, layer_sizes, weights_path):
        """Build an MLP, load its state dict from ``weights_path``, set eval mode."""
        network = self._build_mlp(layer_sizes)
        # SECURITY: torch.load unpickles the file; depending on the installed
        # torch version this can execute arbitrary code. Only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        network.load_state_dict(torch.load(weights_path, map_location='cpu'))
        network.eval()
        return network

    def _load_encoder(self):
        """Load the encoder model"""
        # This is a placeholder - implement based on your actual encoder architecture
        return self._load_network(self.ENCODER_LAYER_SIZES, self.encoder_path)

    def _load_decoder(self):
        """Load the decoder model"""
        # This is a placeholder - implement based on your actual decoder architecture
        return self._load_network(self.DECODER_LAYER_SIZES, self.decoder_path)

    def prepare_input_conditions(self, impact_point: List[float],
                               velocity: List[float],
                               force: float) -> torch.Tensor:
        """
        Prepare input conditions for the model

        Lists, tuples and NumPy arrays are all accepted for the two vectors.

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            float32 tensor of shape (1, 7): normalized impact point,
            normalized velocity, normalized force.

        Raises:
            ValueError: If impact_point or velocity does not have exactly
                3 components, or force is not a single value.
        """
        point = np.asarray(impact_point, dtype=np.float32).reshape(-1)
        speed = np.asarray(velocity, dtype=np.float32).reshape(-1)
        magnitude = np.asarray(force, dtype=np.float32).reshape(-1)

        if point.size != 3:
            raise ValueError(f"impact_point must have 3 components, got {point.size}")
        if speed.size != 3:
            raise ValueError(f"velocity must have 3 components, got {speed.size}")
        if magnitude.size != 1:
            raise ValueError(f"force must be a single value, got {magnitude.size} values")

        # np.concatenate (rather than `+`) so array inputs are joined instead
        # of being added element-wise.
        input_data = np.concatenate((point, speed, magnitude))

        # Normalize to match training data distribution
        input_data[:3] = (
            input_data[:3] - np.array(self.IMPACT_POINT_OFFSET)
        ) / self.IMPACT_POINT_SCALE  # Normalize impact point
        input_data[3:6] = input_data[3:6] / self.VELOCITY_SCALE  # Normalize velocity
        input_data[6] = input_data[6] / self.FORCE_SCALE  # Normalize force

        # Add the leading batch dimension expected by the encoder
        return torch.tensor(input_data, dtype=torch.float32).unsqueeze(0)

    def predict_deformation(self, impact_point: List[float],
                          velocity: List[float],
                          force: float) -> np.ndarray:
        """
        Predict object deformation given impact conditions

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            Decoder output reshaped to an (N, 3) array of vertex positions
        """
        # Prepare input
        input_tensor = self.prepare_input_conditions(impact_point, velocity, force)

        # Run inference without tracking gradients
        with torch.no_grad():
            latent = self.encoder(input_tensor)
            deformation = self.decoder(latent)

        # Drop the batch dimension and view the flat output as xyz rows
        return deformation.squeeze().numpy().reshape(-1, 3)

    def apply_deformation_to_mesh(self, impact_point: List[float],
                                velocity: List[float],
                                force: float) -> "trimesh.Trimesh":
        """
        Apply deformation to the reference mesh

        The reference mesh itself is not modified; a copy is returned.

        Args:
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude

        Returns:
            Deformed mesh

        Raises:
            ValueError: If the predicted vertex array does not have the same
                shape as the reference mesh's vertex array.
        """
        # Get deformed vertices
        deformed_vertices = self.predict_deformation(impact_point, velocity, force)

        # The copy keeps the reference faces, which index into the vertex
        # array, so the prediction must supply one row per reference vertex.
        expected_shape = np.shape(self.reference_mesh.vertices)
        if deformed_vertices.shape != expected_shape:
            raise ValueError(
                f"Predicted vertices have shape {deformed_vertices.shape}, "
                f"but the reference mesh has vertex shape {expected_shape}"
            )

        # Create new mesh with deformed vertices
        deformed_mesh = self.reference_mesh.copy()
        deformed_mesh.vertices = deformed_vertices

        return deformed_mesh

    def save_deformed_mesh(self, output_path: str, impact_point: List[float],
                          velocity: List[float], force: float):
        """
        Save deformed mesh to file

        Args:
            output_path: Path to save the deformed mesh
            impact_point: [x, y, z] coordinates of impact point
            velocity: [vx, vy, vz] velocity vector
            force: Impact force magnitude
        """
        deformed_mesh = self.apply_deformation_to_mesh(impact_point, velocity, force)
        deformed_mesh.export(output_path)

def main():
    """Demo run: load a model, predict a deformation, save it, print stats.

    Any exception raised along the way is caught and reported on stdout
    together with a hint about model files and dependencies.
    """
    # Which checkpoint to pick up
    model_dir = "base"  # Change to your model directory
    model_name = "base"

    # Impact description, in the positional order the model methods expect:
    # impact point [x, y, z], velocity [vx, vy, vz], force magnitude.
    conditions = (
        [0.1, 0.8, 0.1],
        [0.0, -5.0, 0.0],
        500.0,
    )

    try:
        # Step 1: build the wrapper (reads the weights and the reference mesh)
        print(f"Loading {model_name} model...")
        model = PhysicsDeformationModel(model_dir, model_name)
        print("Model loaded successfully!")

        # Step 2: run inference for the chosen conditions
        print("Predicting deformation...")
        deformed_vertices = model.predict_deformation(*conditions)
        print(f"Deformation predicted. Output shape: {deformed_vertices.shape}")

        # Step 3: write the deformed mesh next to the script
        output_path = f"deformed_{model_name}.obj"
        model.save_deformed_mesh(output_path, *conditions)
        print(f"Deformed mesh saved to: {output_path}")

        # Step 4: per-vertex distance between prediction and reference mesh
        displacement = deformed_vertices - model.reference_mesh.vertices
        deformation_magnitude = np.linalg.norm(displacement, axis=1)
        print(f"Average deformation magnitude: {np.mean(deformation_magnitude):.4f}")
        print(f"Maximum deformation magnitude: {np.max(deformation_magnitude):.4f}")

    except Exception as e:
        print(f"Error: {e}")
        print("Make sure you have the correct model files and dependencies installed.")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()