b2bomber commited on
Commit
2d2f0b3
·
verified ·
1 Parent(s): 2fb3e38

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +77 -12
generate.py CHANGED
@@ -1,26 +1,91 @@
1
  import torch
2
  from PIL import Image
3
  from shap_e.diffusion.sample import sample_latents
4
- from shap_e.diffusion.eval.image_embedder import CLIPImageEmbedder
5
- from shap_e.models.download import load_model
6
  from shap_e.util.notebooks import decode_latent_mesh, create_pan_cameras, render_views
 
 
 
7
 
8
- device = torch.device('cpu') # Spaces use CPU
 
9
 
10
  # Load models
11
- model = load_model('text300M', device=device)
12
- image_embedder = CLIPImageEmbedder(device=device)
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def generate_model_from_text(prompt: str, filename="model.glb"):
15
- batch = [prompt]
16
- latents = sample_latents(model, batch, guidance_scale=15.0, device=device)
17
- mesh = decode_latent_mesh(latents[0], device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  mesh.export(filename)
 
19
  return filename
20
 
21
  def generate_model_from_image(image: Image.Image, filename="model.glb"):
22
- image_embed = image_embedder.embed_image(image)
23
- latents = sample_latents(model, [image_embed], guidance_scale=15.0, device=device, is_image=True)
24
- mesh = decode_latent_mesh(latents[0], device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  mesh.export(filename)
26
- return filename
 
 
1
  import torch
2
  from PIL import Image
3
  from shap_e.diffusion.sample import sample_latents
4
+ from shap_e.models.download import load_model # Used for loading 'transmitter', 'text300M', 'image300M'
 
5
  from shap_e.util.notebooks import decode_latent_mesh, create_pan_cameras, render_views
6
+ # You might also need this if you are handling raw diffusion setup
7
+ # from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
8
+ # from shap_e.models.configs import model_from_config # for more explicit config loading
9
 
10
+ # Determine device (CPU for Spaces without GPU, or CUDA if available)
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Load models
14
+ # The 'transmitter' model includes the image encoder (CLIP) and is used for image input.
15
+ # The 'image300M' model is the diffusion model for image-conditioned generation.
16
+ # The 'text300M' model is the diffusion model for text-conditioned generation.
17
+ # Load these once globally to avoid repeated loading.
18
+ print(f"Loading models on {device}...")
19
+ xm = load_model('transmitter', device=device) # Contains the image encoder (e.g., CLIP)
20
+ model_text = load_model('text300M', device=device)
21
+ model_image = load_model('image300M', device=device) # This is the diffusion model for image input
22
+
23
+ # Diffusion configuration is often loaded automatically or integrated into the pipeline
24
+ # diffusion = diffusion_from_config(load_config('diffusion')) # If you need explicit diffusion object
25
+
26
+ print("Models loaded successfully.")
27
+
28
 
29
  def generate_model_from_text(prompt: str, filename="model.glb"):
30
+ print(f"Generating 3D model from text: '{prompt}'")
31
+ # `sample_latents` takes the model directly for text-to-3D
32
+ latents = sample_latents(
33
+ batch_size=1, # Generate one model
34
+ model=model_text,
35
+ diffusion=None, # Diffusion is often integrated or defaulted in sample_latents when using models from load_model
36
+ guidance_scale=15.0,
37
+ model_kwargs=dict(texts=[prompt]),
38
+ progress=True,
39
+ clip_denoised=True,
40
+ use_fp16=True if device.type == 'cuda' else False, # Use FP16 only if GPU
41
+ use_karras=True,
42
+ karras_steps=64,
43
+ sigma_min=1e-3,
44
+ sigma_max=160,
45
+ s_churn=0,
46
+ )
47
+ print("Latents sampled.")
48
+
49
+ mesh = decode_latent_mesh(xm, latents[0]).tri_mesh() # Use xm (transmitter) to decode
50
  mesh.export(filename)
51
+ print(f"Model saved to {filename}")
52
  return filename
53
 
54
  def generate_model_from_image(image: Image.Image, filename="model.glb"):
55
+ print("Generating 3D model from image...")
56
+
57
+ # For image input, you need to prepare the image.
58
+ # The 'transmitter' (xm) implicitly handles the CLIP embedding when used in sample_latents
59
+ # with `is_image=True` and `model_kwargs=dict(images=[image_tensor])`.
60
+ # First, resize the image to the expected input size (e.g., 256x256) and convert to tensor.
61
+ # Note: shap-e expects a specific image format (likely 3x256x256 float32 tensor in [0,1])
62
+ # You might need torchvision.transforms or similar for robust image preprocessing.
63
+ from torchvision import transforms
64
+ preprocess = transforms.Compose([
65
+ transforms.Resize(256),
66
+ transforms.CenterCrop(256),
67
+ transforms.ToTensor(), # Converts PIL Image to torch.Tensor (C, H, W) in [0, 1]
68
+ ])
69
+ image_tensor = preprocess(image).unsqueeze(0).to(device) # Add batch dimension (1, C, H, W)
70
+
71
+ latents = sample_latents(
72
+ batch_size=1,
73
+ model=model_image, # Use the image-conditioned model
74
+ diffusion=None, # Same as above, often integrated
75
+ guidance_scale=3.0, # Guidance scale often different for image-to-3D
76
+ model_kwargs=dict(images=[image_tensor]), # Pass the tensor
77
+ progress=True,
78
+ clip_denoised=True,
79
+ use_fp16=True if device.type == 'cuda' else False,
80
+ use_karras=True,
81
+ karras_steps=64,
82
+ sigma_min=1e-3,
83
+ sigma_max=160,
84
+ s_churn=0,
85
+ )
86
+ print("Latents sampled.")
87
+
88
+ mesh = decode_latent_mesh(xm, latents[0]).tri_mesh() # Use xm (transmitter) to decode
89
  mesh.export(filename)
90
+ print(f"Model saved to {filename}")
91
+ return filename