vitruvius / app /main.py
andreagalle
Does this fix the "AttributeError: 'Response' object has no attribute 'get'" issue?
2520a67
import io
import os
import zipfile

import torch
from fastapi import FastAPI, Form, Request, Response
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, decode_latent_mesh
# FastAPI application; every route below is registered on it.
app = FastAPI()

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Shap-E components, loaded once at import time and reused by every call:
#   xm        -- 'transmitter' model; used to decode sampled latents into meshes
#   model     -- 'text300M' model; passed to sample_latents with text prompts
#   diffusion -- diffusion process built from the 'diffusion' config
xm = load_model('transmitter', device=device)
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))

# Sampling settings read by generate_images_and_meshes().
batch_size = 4          # latents sampled per run (one mesh pair is written per latent)
guidance_scale = 15.0   # forwarded to sample_latents
prompt = "a shark"      # default text prompt
def generate_images_and_meshes(text=None):
    """Sample Shap-E latents for a text prompt and write one mesh pair per latent.

    For each of the ``batch_size`` sampled latents, ``example_mesh_{i}.ply`` and
    ``example_mesh_{i}.obj`` are written into the current working directory.
    Files with the same names from earlier runs are overwritten.

    Args:
        text: Prompt to condition on. When omitted or ``None``, the module-level
            ``prompt`` is used (looked up at call time).

    Returns:
        The current working directory, i.e. where the mesh files were written.
    """
    if text is None:
        text = prompt
    latents = sample_latents(
        batch_size=batch_size,
        model=model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(texts=[text] * batch_size),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    # Only meshes are produced. Image rendering (decode_latent_images over
    # create_pan_cameras) is intentionally not performed, so no cameras are built.
    for i, latent in enumerate(latents):
        mesh = decode_latent_mesh(xm, latent).tri_mesh()
        # PLY is written through a binary handle, OBJ through a text handle.
        with open(f'example_mesh_{i}.ply', 'wb') as f:
            mesh.write_ply(f)
        with open(f'example_mesh_{i}.obj', 'w') as f:
            mesh.write_obj(f)
    return os.getcwd()
def list_obj_files():
    """Return the names of entries in the current working directory ending in ``.obj``.

    Names are relative, exactly as produced by ``os.listdir()``, and keep its order.
    """
    return [entry for entry in os.listdir() if entry.endswith(".obj")]
# # HTML form
# html_form = """
# <!DOCTYPE html>
# <html>
# <head>
# <title>Simple Input Form</title>
# </head>
# <body>
# <h1>Simple Input Form</h1>
# <form method="post" action="/submit">
# <label for="input_text">Enter something:</label>
# <input type="text" id="input_text" name="user_input" required>
# <br><br>
# <input type="submit" value="Submit">
# </form>
# </body>
# </html>
# """
# Locate the "templates" folder next to this source file (not relative to the
# process's working directory) and hand it to Jinja2Templates.
current_directory = os.path.split(os.path.abspath(__file__))[0]
templates_directory = os.path.join(current_directory, "templates")
templates = Jinja2Templates(directory=templates_directory)
@app.get("/", response_class=HTMLResponse)
async def get_input_form(request: Request):
    """Render the input form page (``input_form.html``).

    ``request`` must be annotated as ``Request`` so that FastAPI injects the
    incoming HTTP request, which the template response needs in its context.
    Annotating it as ``HTMLResponse`` (a ``Response`` subclass) makes FastAPI
    inject a fresh response object instead. Rendering then fails with
    ``AttributeError: 'Response' object has no attribute 'get'``.
    """
    return templates.TemplateResponse("input_form.html", {"request": request})
@app.post("/submit", response_class=HTMLResponse)
async def submit_input(request: Request, user_input: str = Form(...)):
    """Handle the form POST and re-render the form showing the submitted text.

    ``user_input`` is read from the form body via ``Form(...)``. A bare ``str``
    parameter would instead be a required *query* parameter, so a normal HTML form
    submission would be rejected with a 422 error. Form parsing needs the
    ``python-multipart`` package installed.

    ``request`` must be a ``Request`` (not a ``Response`` subclass) for template
    rendering to work. It is listed first only because a parameter without a default
    cannot follow one with a default; FastAPI passes handler arguments by keyword.
    """
    return templates.TemplateResponse(
        "input_form.html", {"request": request, "submitted_text": user_input}
    )
# @app.get("/")
# def read_root():
# return {"message": "Hello, Dariowsky!"}
@app.get("/generate/me")
async def generate_me():
    """Return a fixed placeholder payload for the literal ``/generate/me`` path."""
    payload = {"output": "this is you"}
    return payload
@app.get("/generate/{prompt}")
async def generate(prompt: str):
    """Echo the ``prompt`` path segment back as a placeholder; no generation is run."""
    message = f"this is {prompt}"
    return {"output": message}
@app.get("/make")
def make():
    """Run mesh generation and report the directory the files were written to.

    Declared with plain ``def`` (not ``async def``), so FastAPI runs this blocking
    call in its threadpool rather than on the event loop.
    """
    output_dir = generate_images_and_meshes()
    return {"status": f"Images generated within {output_dir}"}
@app.get("/download")
def download():
    """Bundle every ``.obj`` file in the current working directory into a ZIP download.

    The archive is built in memory with DEFLATE compression. Each entry is stored
    under the name returned by ``list_obj_files()``. The archive is sent as an
    attachment named ``images.zip``.
    """
    mesh_names = list_obj_files()

    archive_buffer = io.BytesIO()
    with zipfile.ZipFile(archive_buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
        for mesh_name in mesh_names:
            archive.write(mesh_name)

    # getvalue() returns the whole buffer regardless of the current position.
    response = Response(archive_buffer.getvalue(), media_type="application/zip")
    response.headers["Content-Disposition"] = "attachment; filename=images.zip"
    return response