# text-to-3d / model_generator.py
# jainarham's picture
# Update model_generator.py
# 72d268d verified
"""
Advanced 3D Model Generator
Supports materials, complex arrangements, and optimized generation
"""
import trimesh
import numpy as np
from typing import Dict, Any, List, Tuple
import io
import logging
# Module-level logger, named after this module via __name__; used by the
# generator classes below for info and error messages.
logger = logging.getLogger(__name__)
class AdvancedModelGenerator:
    """Build a coloured 3D scene from a list of primitive-shape descriptions.

    ``generate`` takes a dict with an ``"objects"`` list. Each object is a
    dict that may carry:

    * ``"type"``: one of the keys registered in ``self.generators``
      (unknown or missing types fall back to a cube),
    * ``"scale"``: uniform scale factor (default ``1.0``),
    * ``"scale_axes"``: per-axis factors ``{"x", "y", "z"}``, read by the
      cube and cylinder generators only,
    * ``"color"``: hex colour string (default ``"#808080"``),
    * ``"position"`` / ``"rotation"``: ``{"x", "y", "z"}`` dicts; rotation
      values are degrees (they are converted with ``np.radians``).

    All meshes are merged and exported to both GLB and OBJ.
    """

    # Colour returned whenever a hex string cannot be interpreted.
    _DEFAULT_RGBA = (128, 128, 128, 255)

    def __init__(self):
        self._init_generators()
        logger.info("Advanced Model Generator initialized")

    def _init_generators(self):
        """Register the shape-name -> generator-method dispatch table."""
        self.generators = {
            "cube": self._gen_cube,
            "sphere": self._gen_sphere,
            "cylinder": self._gen_cylinder,
            "cone": self._gen_cone,
            "torus": self._gen_torus,
            "pyramid": self._gen_pyramid,
            "capsule": self._gen_capsule,
            "plane": self._gen_plane,
        }

    def _hex_to_rgba(self, hex_color: str) -> Tuple[int, int, int, int]:
        """Convert a hex colour string to an ``(r, g, b, a)`` tuple of 0-255 ints.

        Accepts ``RGB``, ``RGBA``, ``RRGGBB`` and ``RRGGBBAA`` with or without
        a leading ``#``. Alpha defaults to 255 when not given. Anything that
        cannot be parsed (wrong length, non-hex characters, non-string input)
        yields opaque mid-grey instead of raising, so one bad colour does not
        abort the whole scene.
        """
        if not isinstance(hex_color, str):
            return self._DEFAULT_RGBA

        digits = hex_color.strip().lstrip('#')
        if len(digits) in (3, 4):
            # Shorthand form: every digit stands for a doubled pair (f -> ff).
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) == 6:
            digits += "ff"  # no alpha supplied -> fully opaque
        if len(digits) != 8:
            return self._DEFAULT_RGBA

        try:
            r, g, b, a = (int(digits[i:i + 2], 16) for i in range(0, 8, 2))
        except ValueError:
            return self._DEFAULT_RGBA
        return (r, g, b, a)

    def _gen_cube(self, params: Dict) -> "trimesh.Trimesh":
        """Return a box of unit edge length scaled by ``scale`` and ``scale_axes``."""
        scale = params.get("scale", 1.0)
        # ``or {}`` also covers an explicit ``"scale_axes": None``.
        scale_axes = params.get("scale_axes") or {}
        return trimesh.creation.box(extents=[
            1.0 * scale * scale_axes.get("x", 1),
            1.0 * scale * scale_axes.get("y", 1),
            1.0 * scale * scale_axes.get("z", 1),
        ])

    def _gen_sphere(self, params: Dict) -> "trimesh.Trimesh":
        """Return an icosphere of radius ``0.5 * scale``."""
        scale = params.get("scale", 1.0)
        return trimesh.creation.icosphere(subdivisions=3, radius=0.5 * scale)

    def _gen_cylinder(self, params: Dict) -> "trimesh.Trimesh":
        """Return a 32-sided cylinder; ``scale_axes`` x drives radius, y drives height.

        NOTE(review): trimesh documents its cylinder as built along the Z
        axis, so the "y" factor lengthens the mesh along Z before any
        rotation is applied - confirm this matches the callers' axis
        convention.
        """
        scale = params.get("scale", 1.0)
        scale_axes = params.get("scale_axes") or {}
        return trimesh.creation.cylinder(
            radius=0.4 * scale * scale_axes.get("x", 1),
            height=1.0 * scale * scale_axes.get("y", 1),
            sections=32,
        )

    def _gen_cone(self, params: Dict) -> "trimesh.Trimesh":
        """Return a 32-sided cone with radius ``0.5 * scale`` and height ``scale``."""
        scale = params.get("scale", 1.0)
        return trimesh.creation.cone(radius=0.5 * scale, height=1.0 * scale, sections=32)

    @staticmethod
    def _torus_grid(major_r: float, minor_r: float,
                    major_sections: int = 32,
                    minor_sections: int = 16) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(vertices, faces)`` for a closed torus around the Z axis.

        Both angles are sampled with ``endpoint=False``: the wrap-around in
        the face indices (``% cols`` / ``% rows``) already closes the
        surface, so including the 2*pi sample would duplicate the first
        ring/column and create a strip of zero-area triangles at the seam.

        Vertices are laid out row-major, one row per tube angle and one
        column per ring angle; each grid cell contributes two triangles.
        """
        u = np.linspace(0, 2 * np.pi, major_sections, endpoint=False)
        v = np.linspace(0, 2 * np.pi, minor_sections, endpoint=False)
        u, v = np.meshgrid(u, v)
        u, v = u.ravel(), v.ravel()

        ring = major_r + minor_r * np.cos(v)
        vertices = np.column_stack([
            ring * np.cos(u),
            ring * np.sin(u),
            minor_r * np.sin(v),
        ])

        rows, cols = minor_sections, major_sections
        i, j = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
        i, j = i.ravel(), j.ravel()
        i_next = (i + 1) % rows
        j_next = (j + 1) % cols
        p1 = i * cols + j
        p2 = i * cols + j_next
        p3 = i_next * cols + j_next
        p4 = i_next * cols + j
        # Interleave the two triangles of each cell: [p1,p2,p3], [p1,p3,p4], ...
        faces = np.stack(
            [np.column_stack([p1, p2, p3]), np.column_stack([p1, p3, p4])],
            axis=1,
        ).reshape(-1, 3)
        return vertices, faces

    def _gen_torus(self, params: Dict) -> "trimesh.Trimesh":
        """Return a torus with ring radius ``0.4 * scale`` and tube radius ``0.15 * scale``."""
        scale = params.get("scale", 1.0)
        vertices, faces = self._torus_grid(0.4 * scale, 0.15 * scale)
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
        mesh.fix_normals()
        return mesh

    def _gen_pyramid(self, params: Dict) -> "trimesh.Trimesh":
        """Return a square pyramid: apex at ``+0.5*scale`` on Y, base at ``-0.5*scale``."""
        scale = params.get("scale", 1.0)
        half = 0.5 * scale
        vertices = np.array([
            [0, half, 0],           # apex
            [-half, -half, -half],  # base corners
            [half, -half, -half],
            [half, -half, half],
            [-half, -half, half],
        ])
        faces = np.array([
            [0, 1, 2], [0, 2, 3], [0, 3, 4], [0, 4, 1],  # four sides
            [1, 3, 2], [1, 4, 3],                        # base, as two triangles
        ])
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
        mesh.fix_normals()
        return mesh

    def _gen_capsule(self, params: Dict) -> "trimesh.Trimesh":
        """Return a capsule with radius ``0.3 * scale`` and height ``0.6 * scale``."""
        scale = params.get("scale", 1.0)
        return trimesh.creation.capsule(radius=0.3 * scale, height=0.6 * scale)

    def _gen_plane(self, params: Dict) -> "trimesh.Trimesh":
        """Return a thin slab: ``2*scale`` wide in X and Z, fixed 0.1 thick in Y."""
        scale = params.get("scale", 1.0)
        return trimesh.creation.box(extents=[2.0 * scale, 0.1, 2.0 * scale])

    def _apply_color(self, mesh: "trimesh.Trimesh", color: str) -> "trimesh.Trimesh":
        """Give every vertex of ``mesh`` the same colour and return the mesh."""
        rgba = np.array(self._hex_to_rgba(color), dtype=np.uint8)
        vertex_colors = np.tile(rgba, (len(mesh.vertices), 1))
        mesh.visual = trimesh.visual.ColorVisuals(mesh, vertex_colors=vertex_colors)
        return mesh

    def _apply_transform(self, mesh: "trimesh.Trimesh", position: Dict,
                         rotation: Dict) -> "trimesh.Trimesh":
        """Rotate ``mesh`` about the origin, then translate it; return the mesh.

        ``rotation`` holds per-axis angles in degrees and ``position`` holds
        per-axis offsets. Missing or ``None`` dicts and components count as 0.
        Rotation is applied first so that it does not swing the object
        around the origin after it has been moved.
        """
        if rotation:
            rx, ry, rz = (np.radians(rotation.get(axis) or 0) for axis in ("x", "y", "z"))
            mesh.apply_transform(trimesh.transformations.euler_matrix(rx, ry, rz))

        position = position or {}
        translation = np.array([position.get(axis) or 0 for axis in ("x", "y", "z")])
        mesh.apply_translation(translation)
        return mesh

    def _build_mesh(self, obj: Dict[str, Any]) -> "trimesh.Trimesh":
        """Create, colour and place the mesh for one object description."""
        shape_type = obj.get("type", "cube")
        generator = self.generators.get(shape_type)
        if generator is None:
            # Keep the lenient fallback, but make the substitution visible.
            logger.warning("Unknown shape type %r; falling back to cube", shape_type)
            generator = self._gen_cube

        mesh = generator(obj)
        mesh = self._apply_color(mesh, obj.get("color", "#808080"))
        return self._apply_transform(
            mesh,
            obj.get("position", {"x": 0, "y": 0, "z": 0}),
            obj.get("rotation", {"x": 0, "y": 0, "z": 0}),
        )

    @staticmethod
    def _export_bytes(mesh: "trimesh.Trimesh", file_type: str) -> bytes:
        """Export ``mesh`` in ``file_type`` format through an in-memory buffer."""
        buffer = io.BytesIO()
        mesh.export(buffer, file_type=file_type)
        return buffer.getvalue()

    def generate(self, model_params: Dict[str, Any]) -> Dict[str, Any]:
        """Generate 3D model from parameters.

        Args:
            model_params: dict whose ``"objects"`` entry lists the shapes to
                build (see the class docstring for the per-object keys).

        Returns:
            On success: ``{"success": True, "glb_data", "obj_data",
            "vertex_count", "face_count"}``. On failure (no objects, or any
            exception while building/exporting): ``{"success": False,
            "error": <message>}``. This method never raises.
        """
        try:
            objects = model_params.get("objects", [])
            if not objects:
                return {"success": False, "error": "No objects to generate"}

            meshes = [self._build_mesh(obj) for obj in objects]

            # A single mesh needs no merging.
            if len(meshes) == 1:
                combined = meshes[0]
            else:
                combined = trimesh.util.concatenate(meshes)

            glb_data = self._export_bytes(combined, 'glb')
            obj_data = self._export_bytes(combined, 'obj')

            logger.info("Generated: %d meshes, GLB: %d bytes", len(meshes), len(glb_data))
            return {
                "success": True,
                "glb_data": glb_data,
                "obj_data": obj_data,
                "vertex_count": len(combined.vertices),
                "face_count": len(combined.faces),
            }
        except Exception as e:
            # Boundary handler: callers get an error dict, the log gets the traceback.
            logger.exception("Generation error: %s", e)
            return {"success": False, "error": str(e)}
class ModelGenerator:
    """Backward-compatible facade over ``AdvancedModelGenerator``.

    Keeps the older class name available; all work is forwarded to an
    ``AdvancedModelGenerator`` instance created at construction time.
    """

    def __init__(self):
        # The single delegate that performs the actual generation.
        self._generator = AdvancedModelGenerator()

    def generate(self, model_params: Dict[str, Any]) -> Dict[str, Any]:
        """Forward ``model_params`` to the wrapped generator and return its result unchanged."""
        delegate = self._generator
        result = delegate.generate(model_params)
        return result