Hugging Face Spaces deployment log: the app failed with a runtime error on startup.
The application source that produced the error follows.
| import torch | |
| from PIL import Image | |
| from shap_e.diffusion.sample import sample_latents | |
| from shap_e.models.download import load_model # Used for loading 'transmitter', 'text300M', 'image300M' | |
| from shap_e.util.notebooks import decode_latent_mesh, create_pan_cameras, render_views | |
| # You might also need this if you are handling raw diffusion setup | |
| # from shap_e.diffusion.gaussian_diffusion import diffusion_from_config | |
| # from shap_e.models.configs import model_from_config # for more explicit config loading | |
# ---------------------------------------------------------------------------
# One-time global setup (runs at import so the models are loaded exactly once).
#
#   'transmitter' (xm) : encoder/decoder used to turn sampled latents into
#                        meshes (and it contains the image encoder).
#   'text300M'         : diffusion model conditioned on text prompts.
#   'image300M'        : diffusion model conditioned on input images.
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')  # CPU fallback for Spaces without a GPU

print(f"Loading models on {device}...")
xm = load_model('transmitter', device=device)
model_text = load_model('text300M', device=device)
model_image = load_model('image300M', device=device)
print("Models loaded successfully.")
def generate_model_from_text(prompt: str, filename: str = "model.glb") -> str:
    """Generate a 3D mesh from a text prompt and write it to *filename*.

    Args:
        prompt: Natural-language description of the desired object.
        filename: Output path. NOTE(review): shap-e's TriMesh writes
            PLY/OBJ data, not GLB, so the default ".glb" extension is
            misleading — the file contains PLY bytes. Confirm what the
            downstream viewer expects.

    Returns:
        The path the mesh was written to (same as *filename*).
    """
    # sample_latents requires an actual GaussianDiffusion instance; the
    # previous code passed diffusion=None, which fails at runtime (the
    # probable cause of the Space's "Runtime error"). Build it from the
    # published config; load_config caches on disk, so repeat calls are cheap.
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_config
    diffusion = diffusion_from_config(load_config('diffusion'))

    print(f"Generating 3D model from text: '{prompt}'")
    latents = sample_latents(
        batch_size=1,                      # generate a single model
        model=model_text,
        diffusion=diffusion,
        guidance_scale=15.0,               # standard text-to-3D guidance in shap-e examples
        model_kwargs=dict(texts=[prompt]),
        progress=True,
        clip_denoised=True,
        use_fp16=(device.type == 'cuda'),  # FP16 only on GPU
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    print("Latents sampled.")

    # Decode the latent with the transmitter (xm) into a triangle mesh.
    mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
    # TriMesh exposes write_ply/write_obj; it has no .export() method.
    with open(filename, 'wb') as f:
        mesh.write_ply(f)
    # Fixed: the original f-string had no placeholder and printed a literal.
    print(f"Model saved to {filename}")
    return filename
def generate_model_from_image(image: Image.Image, filename: str = "model.glb") -> str:
    """Generate a 3D mesh from a single reference image and write it to *filename*.

    Args:
        image: Input PIL image conditioning the generation.
        filename: Output path. NOTE(review): shap-e's TriMesh writes
            PLY/OBJ data, not GLB, so the default ".glb" extension is
            misleading — the file contains PLY bytes. Confirm what the
            downstream viewer expects.

    Returns:
        The path the mesh was written to (same as *filename*).
    """
    # sample_latents requires an actual GaussianDiffusion instance; the
    # previous code passed diffusion=None, which fails at runtime. Build it
    # from the published config (cached on disk after the first call).
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_config
    diffusion = diffusion_from_config(load_config('diffusion'))

    print("Generating 3D model from image...")
    # The official shap-e image-to-3D example passes PIL images directly in
    # model_kwargs (the model handles its own preprocessing). The previous
    # code wrapped an already-batched (1, C, H, W) tensor in a list, which
    # double-batches the input. Only ensure a 3-channel image here.
    image = image.convert('RGB')

    latents = sample_latents(
        batch_size=1,
        model=model_image,                 # image-conditioned diffusion model
        diffusion=diffusion,
        guidance_scale=3.0,                # image-to-3D typically uses lower guidance than text
        model_kwargs=dict(images=[image]),
        progress=True,
        clip_denoised=True,
        use_fp16=(device.type == 'cuda'),  # FP16 only on GPU
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    print("Latents sampled.")

    # Decode the latent with the transmitter (xm) into a triangle mesh.
    mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
    # TriMesh exposes write_ply/write_obj; it has no .export() method.
    with open(filename, 'wb') as f:
        mesh.write_ply(f)
    # Fixed: the original f-string had no placeholder and printed a literal.
    print(f"Model saved to {filename}")
    return filename