guigonzalez's picture
Upload app.py
693c0f5 verified
"""
TripoSG 3D Generator - Hugging Face Spaces
Gera modelos 3D (GLB/STL) a partir de imagens usando TripoSG.
"""
import os
import sys
import subprocess
import tempfile
from pathlib import Path
# Clonar TripoSG se não existir
TRIPOSG_REPO = Path("TripoSG")
if not TRIPOSG_REPO.exists():
print("Clonando repositório TripoSG...")
subprocess.run([
"git", "clone", "--depth", "1",
"https://github.com/VAST-AI-Research/TripoSG.git",
str(TRIPOSG_REPO)
], check=True)
print("TripoSG clonado!")
# Patch: Remover import do diso que não está disponível
# O diso (DiffDMC) é usado para extração de mesh rápida, mas não é essencial
# Usamos use_flash_decoder=False para evitar essa dependência
inference_utils_path = TRIPOSG_REPO / "triposg" / "inference_utils.py"
if inference_utils_path.exists():
print("Aplicando patch para remover dependência diso...")
content = inference_utils_path.read_text()
# Comentar o import do diso
content = content.replace(
"from diso import DiffDMC",
"# from diso import DiffDMC # Removido - não disponível"
)
inference_utils_path.write_text(content)
print("Patch aplicado!")
# Adicionar ao path
sys.path.insert(0, str(TRIPOSG_REPO))
sys.path.insert(0, str(TRIPOSG_REPO / "scripts"))
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
# Configurar caminhos
WEIGHTS_DIR = Path("pretrained_weights")
WEIGHTS_DIR.mkdir(exist_ok=True)
def download_models():
"""Baixa modelos se necessário."""
triposg_dir = WEIGHTS_DIR / "TripoSG"
rmbg_dir = WEIGHTS_DIR / "RMBG-1.4"
if not (triposg_dir / "model_index.json").exists():
print("Baixando TripoSG weights...")
snapshot_download(
repo_id="VAST-AI/TripoSG",
local_dir=str(triposg_dir),
)
if not (rmbg_dir / "model.safetensors").exists():
print("Baixando RMBG-1.4...")
snapshot_download(
repo_id="briaai/RMBG-1.4",
local_dir=str(rmbg_dir),
)
return triposg_dir, rmbg_dir
# Baixar modelos na inicialização
print("Verificando modelos...")
TRIPOSG_WEIGHTS_DIR, RMBG_DIR = download_models()
# Importar TripoSG após download
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
from briarmbg import BriaRMBG
# Determinar dispositivo
# IMPORTANTE: Usar float32 mesmo em CUDA para evitar erro de dtype mismatch
# na função hierarchical_extract_geometry (mat1 Float vs mat2 Half)
if torch.cuda.is_available():
DEVICE = "cuda"
DTYPE = torch.float32 # float16 causa erro de dtype mismatch
elif torch.backends.mps.is_available():
DEVICE = "mps"
DTYPE = torch.float32
else:
DEVICE = "cpu"
DTYPE = torch.float32
print(f"Usando dispositivo: {DEVICE}")
# Carregar modelos
print("Carregando RMBG...")
rmbg_net = BriaRMBG.from_pretrained(str(RMBG_DIR)).to(DEVICE)
rmbg_net.eval()
print("Carregando TripoSG...")
pipe = TripoSGPipeline.from_pretrained(
str(TRIPOSG_WEIGHTS_DIR),
torch_dtype=DTYPE,
)
pipe = pipe.to(DEVICE)
if hasattr(pipe, "enable_attention_slicing"):
pipe.enable_attention_slicing("max")
print("Modelos carregados!")
def remove_background(image: Image.Image) -> Image.Image:
"""Remove fundo da imagem."""
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
input_tensor = transform(image).unsqueeze(0).to(DEVICE)
with torch.no_grad():
result = rmbg_net(input_tensor)
mask = result[0][0].squeeze().cpu().numpy()
mask = (mask * 255).astype(np.uint8)
mask = Image.fromarray(mask).resize(image.size)
# Aplicar máscara
image_rgba = image.convert("RGBA")
image_rgba.putalpha(mask)
# Criar fundo branco
background = Image.new("RGBA", image.size, (255, 255, 255, 255))
composite = Image.alpha_composite(background, image_rgba)
return composite.convert("RGB")
def generate_3d(
image: Image.Image,
num_steps: int = 50,
guidance_scale: float = 7.0,
seed: int = 42,
remove_bg: bool = True,
output_format: str = "glb",
progress=gr.Progress(),
):
"""Gera modelo 3D a partir de imagem."""
if image is None:
raise gr.Error("Por favor, faça upload de uma imagem.")
progress(0.1, desc="Preparando imagem...")
# Remover fundo se necessário
if remove_bg:
progress(0.2, desc="Removendo fundo...")
image = remove_background(image)
# Redimensionar para 512x512
image = image.resize((512, 512), Image.LANCZOS)
progress(0.3, desc="Gerando modelo 3D...")
# Gerar
generator = torch.Generator(device=DEVICE).manual_seed(seed)
with torch.no_grad():
outputs = pipe(
image=image,
generator=generator,
num_inference_steps=num_steps,
guidance_scale=guidance_scale,
use_flash_decoder=False, # Evita dependência diso
)
progress(0.9, desc="Salvando modelo...")
# Salvar mesh
mesh = outputs.meshes[0]
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", delete=False) as f:
output_path = f.name
if output_format == "glb":
mesh.export(output_path)
else: # stl
mesh.export(output_path, file_type="stl")
progress(1.0, desc="Concluído!")
return output_path
# Interface Gradio - usando Interface simples para evitar bugs de schema
def generate_wrapper(image, num_steps, guidance_scale, seed, remove_bg, output_format):
"""Wrapper para a função generate_3d sem progress."""
if image is None:
raise gr.Error("Por favor, faça upload de uma imagem.")
# Remover fundo se necessário
if remove_bg:
image = remove_background(image)
# Redimensionar para 512x512
image = image.resize((512, 512), Image.LANCZOS)
# Gerar
generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
with torch.no_grad():
outputs = pipe(
image=image,
generator=generator,
num_inference_steps=int(num_steps),
guidance_scale=float(guidance_scale),
use_flash_decoder=False,
)
# Salvar mesh
mesh = outputs.meshes[0]
with tempfile.NamedTemporaryFile(suffix=f".{output_format}", delete=False) as f:
output_path = f.name
if output_format == "glb":
mesh.export(output_path)
else:
mesh.export(output_path, file_type="stl")
return output_path
demo = gr.Interface(
fn=generate_wrapper,
inputs=[
gr.Image(label="Imagem de Entrada", type="pil"),
gr.Slider(label="Passos de inferência", minimum=20, maximum=100, value=50, step=5),
gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, value=7.0, step=0.5),
gr.Number(label="Seed", value=42, precision=0),
gr.Checkbox(label="Remover fundo automaticamente", value=True),
gr.Radio(label="Formato de saída", choices=["glb", "stl"], value="glb"),
],
outputs=gr.File(label="Modelo 3D Gerado"),
title="🎨 TripoSG 3D Generator",
description="""
Gere modelos 3D a partir de imagens usando [TripoSG](https://github.com/VAST-AI-Research/TripoSG).
**Dicas:** Use imagens com objeto centralizado. Mais passos = melhor qualidade (mais lento).
Baixe o arquivo e visualize em [glTF Viewer](https://gltf-viewer.donmccurdy.com/)
""",
allow_flagging="never",
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)