File size: 4,125 Bytes
851d856
 
 
2d2f0b3
851d856
2d2f0b3
 
 
851d856
2d2f0b3
 
851d856
 
2d2f0b3
 
 
 
 
 
 
 
 
 
 
 
 
 
851d856
 
2d2f0b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851d856
2d2f0b3
851d856
 
 
2d2f0b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851d856
2d2f0b3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from PIL import Image
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_model # Used for loading 'transmitter', 'text300M', 'image300M'
from shap_e.util.notebooks import decode_latent_mesh, create_pan_cameras, render_views
# You might also need this if you are handling raw diffusion setup
# from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
# from shap_e.models.configs import model_from_config # for more explicit config loading

# Determine device (CPU for Spaces without GPU, or CUDA if available).
# All models and input tensors below are placed on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load models
# The 'transmitter' model includes the image encoder (CLIP) and is used for image input.
# The 'image300M' model is the diffusion model for image-conditioned generation.
# The 'text300M' model is the diffusion model for text-conditioned generation.
# Load these once globally to avoid repeated loading.
# NOTE: load_model downloads weights on first use, so this block performs
# network I/O and can take a while on a cold start.
print(f"Loading models on {device}...")
xm = load_model('transmitter', device=device) # Contains the image encoder (e.g., CLIP); also used to decode latents to meshes
model_text = load_model('text300M', device=device)
model_image = load_model('image300M', device=device) # This is the diffusion model for image input

# Diffusion configuration is often loaded automatically or integrated into the pipeline
# diffusion = diffusion_from_config(load_config('diffusion')) # If you need explicit diffusion object
# NOTE(review): the functions below pass diffusion=None to sample_latents —
# confirm against the installed shap-e version; the official examples build
# an explicit diffusion object as in the commented line above.

print("Models loaded successfully.")


def generate_model_from_text(prompt: str, filename: str = "model.glb") -> str:
    """Generate a 3D mesh from a text prompt and save it to *filename*.

    Args:
        prompt: Natural-language description of the object to generate.
        filename: Output path for the exported mesh.

    Returns:
        The path the mesh was written to (same as *filename*).
    """
    # Fix: sample_latents needs a real diffusion object — the original passed
    # diffusion=None, which shap-e's sampler cannot use. Build the standard
    # configuration here (function-local imports keep this change self-contained).
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_config
    diffusion = diffusion_from_config(load_config('diffusion'))

    print(f"Generating 3D model from text: '{prompt}'")
    latents = sample_latents(
        batch_size=1,  # generate a single model
        model=model_text,
        diffusion=diffusion,
        guidance_scale=15.0,  # value used by the official shap-e text-to-3D example
        model_kwargs=dict(texts=[prompt]),
        progress=True,
        clip_denoised=True,
        use_fp16=(device.type == 'cuda'),  # FP16 only on GPU
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    print("Latents sampled.")

    # Decode with the transmitter (xm); tri_mesh() yields shap-e's TriMesh.
    mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
    # NOTE(review): shap-e's TriMesh documents write_ply/write_obj; confirm an
    # `export` method exists in the installed version, or convert via trimesh
    # if a .glb file is actually required.
    mesh.export(filename)
    print(f"Model saved to {filename}")  # fix: f-string previously had no placeholder
    return filename

def generate_model_from_image(image: Image.Image, filename: str = "model.glb") -> str:
    """Generate a 3D mesh from a reference image and save it to *filename*.

    Args:
        image: Input PIL image to condition the generation on.
        filename: Output path for the exported mesh.

    Returns:
        The path the mesh was written to (same as *filename*).
    """
    # Fix: sample_latents needs a real diffusion object — the original passed
    # diffusion=None, which shap-e's sampler cannot use. Build the standard
    # configuration here (function-local imports keep this change self-contained).
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_config
    diffusion = diffusion_from_config(load_config('diffusion'))

    print("Generating 3D model from image...")

    # Fix: the official shap-e image-to-3D example passes PIL images directly
    # in model_kwargs; the model's internal image encoder handles resizing and
    # normalization, so the previous manual torchvision tensor preprocessing
    # is unnecessary (and the tensor format it produced was unverified).
    # Just ensure a 3-channel image (drops alpha, converts grayscale).
    image = image.convert('RGB')

    latents = sample_latents(
        batch_size=1,
        model=model_image,  # image-conditioned diffusion model
        diffusion=diffusion,
        guidance_scale=3.0,  # image conditioning uses a lower scale than text
        model_kwargs=dict(images=[image]),  # one PIL image per batch element
        progress=True,
        clip_denoised=True,
        use_fp16=(device.type == 'cuda'),  # FP16 only on GPU
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    print("Latents sampled.")

    # Decode with the transmitter (xm); tri_mesh() yields shap-e's TriMesh.
    mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
    # NOTE(review): shap-e's TriMesh documents write_ply/write_obj; confirm an
    # `export` method exists in the installed version, or convert via trimesh
    # if a .glb file is actually required.
    mesh.export(filename)
    print(f"Model saved to {filename}")  # fix: f-string previously had no placeholder
    return filename