b2bomber commited on
Commit
851d856
·
verified ·
1 Parent(s): 83f46b6

Create generate.py

Browse files
Files changed (1) hide show
  1. generate.py +26 -0
generate.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from shap_e.diffusion.sample import sample_latents
4
+ from shap_e.diffusion.eval.image_embedder import CLIPImageEmbedder
5
+ from shap_e.models.download import load_model
6
+ from shap_e.util.notebooks import decode_latent_mesh, create_pan_cameras, render_views
7
+
8
+ device = torch.device('cpu') # Spaces use CPU
9
+
10
+ # Load models
11
+ model = load_model('text300M', device=device)
12
+ image_embedder = CLIPImageEmbedder(device=device)
13
+
14
+ def generate_model_from_text(prompt: str, filename="model.glb"):
15
+ batch = [prompt]
16
+ latents = sample_latents(model, batch, guidance_scale=15.0, device=device)
17
+ mesh = decode_latent_mesh(latents[0], device=device)
18
+ mesh.export(filename)
19
+ return filename
20
+
21
+ def generate_model_from_image(image: Image.Image, filename="model.glb"):
22
+ image_embed = image_embedder.embed_image(image)
23
+ latents = sample_latents(model, [image_embed], guidance_scale=15.0, device=device, is_image=True)
24
+ mesh = decode_latent_mesh(latents[0], device=device)
25
+ mesh.export(filename)
26
+ return filename