# NOTE: removed Hugging Face Space page residue that was pasted above the
# module docstring ("Nekochu's picture" / "init" / "5818455 verified") —
# it was not Python code and made the file unparseable.
"""
TEXTure CPU Lite - Text-Guided 3D Texturing
Single-file implementation with CPU renderer and xatlas UV unwrapping.
"""
import os
import copy
import tempfile
import shutil
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh
import gradio as gr
from PIL import Image
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from dataclasses import dataclass
# =============================================================================
# CONFIGURATION
# =============================================================================
SD_MODEL = "radames/stable-diffusion-2-depth-img2img" # Public copy of SD-2-Depth (original is gated)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32  # fp16 only where CUDA can exploit it
NUM_VIEWS = 4               # default number of camera viewpoints rendered per mesh
RENDER_SIZE = 512           # side length (px) of the per-view renders fed to SD
TEXTURE_RESOLUTION = 1024   # side length (px) of the output UV texture
NUM_INFERENCE_STEPS = 20 # Full quality (slower on CPU but better results)
# =============================================================================
# MESH CLASS (replaces Kaolin)
# =============================================================================
class Mesh:
    """CPU-compatible mesh class using trimesh."""

    def __init__(self, obj_path: str, device: str = "cpu"):
        loaded = trimesh.load(obj_path, force='mesh', process=False)
        if not isinstance(loaded, trimesh.Trimesh):
            raise ValueError(f"Failed to load mesh from {obj_path}")
        self.vertices = torch.tensor(loaded.vertices, dtype=torch.float32, device=device)
        self.faces = torch.tensor(loaded.faces, dtype=torch.long, device=device)
        self.normals, self.face_area = self._calc_normals(self.vertices, self.faces)
        # UV layout (texture coords + per-face UV indices); None unless the
        # file ships its own UVs.
        self.ft = None
        self.vt = None
        uv = getattr(loaded.visual, 'uv', None)
        if uv is not None and len(uv) > 0:
            self.vt = torch.tensor(uv, dtype=torch.float32, device=device)
            # trimesh stores one UV per vertex, so UV faces mirror geometry faces.
            self.ft = self.faces.clone()

    @staticmethod
    def _calc_normals(vertices, faces):
        """Return (unit face normals, face areas) for a triangle list."""
        tris = vertices[faces]                       # (F, 3, 3)
        raw = torch.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0], dim=-1)
        twice_area = torch.norm(raw, dim=-1)         # |cross| = 2 * triangle area
        unit = raw / (twice_area[:, None] + 1e-8)
        return unit, twice_area / 2

    def normalize_mesh(self, inplace=False, target_scale=1.0, dy=0.0):
        """Center the mesh, fit it in a sphere of radius target_scale, lift by dy in Y."""
        target = self if inplace else copy.deepcopy(self)
        centered = target.vertices - target.vertices.mean(dim=0)
        radius = torch.max(torch.norm(centered, p=2, dim=1))
        fitted = centered / (radius + 1e-8) * target_scale
        fitted[:, 1] = fitted[:, 1] + dy
        target.vertices = fitted
        return target
# =============================================================================
# RENDERER (replaces Kaolin render functions)
# =============================================================================
def perspective_projection(fov=np.pi/3, aspect=1.0, near=0.1, far=100.0):
    """Build an OpenGL-style 4x4 perspective projection matrix (column-vector convention)."""
    focal = 1.0 / np.tan(fov / 2)
    depth_span = near - far
    mat = torch.zeros(4, 4)
    mat[0, 0] = focal / aspect
    mat[1, 1] = focal
    mat[2, 2] = (far + near) / depth_span
    mat[2, 3] = (2 * far * near) / depth_span
    mat[3, 2] = -1.0  # perspective divide picks up -z
    return mat
def view_matrix(pos, look_at, up):
    """World-to-camera look-at matrix; returns shape (1, 4, 4)."""
    pos, look_at, up = pos.squeeze(), look_at.squeeze(), up.squeeze()
    gaze = look_at - pos
    gaze = gaze / (torch.norm(gaze) + 1e-8)
    side = torch.linalg.cross(gaze, up)
    side = side / (torch.norm(side) + 1e-8)
    cam_up = torch.linalg.cross(side, gaze)
    mat = torch.eye(4)
    # Rotation rows: camera right, camera up, and -gaze (camera looks down -Z).
    mat[0, :3] = side
    mat[1, :3] = cam_up
    mat[2, :3] = -gaze
    # Translation column moves the eye to the origin.
    mat[0, 3] = -torch.dot(side, pos)
    mat[1, 3] = -torch.dot(cam_up, pos)
    mat[2, 3] = torch.dot(gaze, pos)
    return mat.unsqueeze(0)
def camera_from_angles(elev, azim, r=3.0, look_at_height=0.0):
    """Place a camera on a sphere of radius r via elevation/azimuth and aim it at the Y axis."""
    sin_e = torch.sin(elev)
    eye = torch.tensor([[r * sin_e * torch.sin(azim),
                         r * torch.cos(elev),
                         r * sin_e * torch.cos(azim)]])
    target = torch.zeros_like(eye)
    target[:, 1] = look_at_height
    world_up = torch.tensor([[0.0, 1.0, 0.0]])
    return view_matrix(eye, target, world_up)
def prepare_vertices(vertices, faces, proj, view):
    """Transform mesh triangles into camera space and NDC image coordinates.

    Returns (face_verts_cam (1,F,3,4), face_verts_img (1,F,3,2), normals (1,F,3)).
    """
    dev = vertices.device
    tris = vertices[faces.long()]                                   # (F, 3, 3)
    homog = torch.cat([tris, torch.ones(*tris.shape[:-1], 1, device=dev)], dim=-1)
    cam = torch.einsum('ij,fvj->fvi', view.squeeze(0).to(dev), homog)
    clip = torch.einsum('ij,fvj->fvi', proj.to(dev), cam)
    # Perspective divide (w clamped so degenerate points don't explode).
    ndc = clip[..., :3] / clip[..., 3:4].clamp(min=1e-8)
    img_xy = ndc[..., :2]
    # World-space unit face normals.
    raw_n = torch.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0], dim=-1)
    raw_n = raw_n / (torch.norm(raw_n, dim=-1, keepdim=True) + 1e-8)
    return cam.unsqueeze(0), img_xy.unsqueeze(0), raw_n.unsqueeze(0)
def rasterize(width, height, face_z, face_verts_img, face_attrs):
    """Software triangle rasterizer with a keep-minimum z-buffer.

    width/height:   output raster size in pixels.
    face_z:         (1, F, 3) per-vertex depth values used only for the
                    depth test (smaller value wins a pixel).
    face_verts_img: (1, F, 3, 2) vertex positions in NDC x/y ([-1, 1]).
    face_attrs:     (1, F, 3, A) per-vertex attributes, interpolated with
                    barycentric weights.

    Returns (features (1, H, W, A), face_idx (1, H, W, 1)); face_idx is -1
    where no triangle covers the pixel.
    """
    device = face_verts_img.device
    num_faces = face_verts_img.shape[1]
    num_attrs = face_attrs.shape[-1]
    features = torch.zeros(1, height, width, num_attrs, device=device)
    face_idx = torch.full((1, height, width, 1), -1, dtype=torch.long, device=device)
    depth_buf = torch.full((1, height, width), float('inf'), device=device)
    # NDC -> pixel coordinates; y is flipped so +y in NDC is up on screen.
    verts_pix = face_verts_img.clone()
    verts_pix[..., 0] = (verts_pix[..., 0] + 1) * 0.5 * width
    verts_pix[..., 1] = (1 - verts_pix[..., 1]) * 0.5 * height
    for f in range(num_faces):
        v0, v1, v2 = verts_pix[0, f, 0], verts_pix[0, f, 1], verts_pix[0, f, 2]
        z0, z1, z2 = face_z[0, f, 0], face_z[0, f, 1], face_z[0, f, 2]
        a0, a1, a2 = face_attrs[0, f, 0], face_attrs[0, f, 1], face_attrs[0, f, 2]
        # Triangle bounding box, clamped to the raster.
        min_x = max(0, int(torch.floor(torch.min(torch.stack([v0[0], v1[0], v2[0]]))).item()))
        max_x = min(width - 1, int(torch.ceil(torch.max(torch.stack([v0[0], v1[0], v2[0]]))).item()))
        min_y = max(0, int(torch.floor(torch.min(torch.stack([v0[1], v1[1], v2[1]]))).item()))
        max_y = min(height - 1, int(torch.ceil(torch.max(torch.stack([v0[1], v1[1], v2[1]]))).item()))
        if min_x > max_x or min_y > max_y:
            continue  # triangle entirely off-screen
        # Candidate pixel centers inside the bounding box.
        px = torch.arange(min_x, max_x + 1, device=device).float() + 0.5
        py = torch.arange(min_y, max_y + 1, device=device).float() + 0.5
        px_grid, py_grid = torch.meshgrid(px, py, indexing='xy')
        points = torch.stack([px_grid.flatten(), py_grid.flatten()], dim=-1)
        def edge_fn(va, vb, p):
            # Signed twice-area of (va, vb, p); sign tells which side p is on.
            return (p[..., 0] - va[0]) * (vb[1] - va[1]) - (p[..., 1] - va[1]) * (vb[0] - va[0])
        area = edge_fn(v0, v1, v2)
        if abs(area.item()) < 1e-8:
            continue  # degenerate (zero-area) triangle
        # Barycentric weights; all >= 0 iff the pixel center is inside.
        # NOTE(review): because the weights are divided by the signed area,
        # both windings pass this test — there is no back-face culling here.
        w0 = edge_fn(v1, v2, points) / area
        w1 = edge_fn(v2, v0, points) / area
        w2 = edge_fn(v0, v1, points) / area
        inside = (w0 >= 0) & (w1 >= 0) & (w2 >= 0)
        if not inside.any():
            continue
        idx = torch.where(inside)[0]
        pts, iw0, iw1, iw2 = points[idx], w0[idx], w1[idx], w2[idx]
        interp_z = iw0 * z0 + iw1 * z1 + iw2 * z2
        interp_attr = iw0.unsqueeze(-1) * a0 + iw1.unsqueeze(-1) * a1 + iw2.unsqueeze(-1) * a2
        pix_x, pix_y = pts[:, 0].long(), pts[:, 1].long()
        # Per-pixel depth test: keep the smallest interpolated z seen so far.
        for i in range(len(idx)):
            x, y, z = pix_x[i].item(), pix_y[i].item(), interp_z[i].item()
            if z < depth_buf[0, y, x].item():
                depth_buf[0, y, x] = z
                features[0, y, x] = interp_attr[i]
                face_idx[0, y, x, 0] = f
    return features, face_idx
def texture_sample(uv, texture, mode='bilinear'):
    """Sample `texture` (N, C, H, W) at UV coords in [0, 1]; V is flipped.

    Returns shape (N, 1, H_uv, W_uv, C).
    """
    # Map [0,1] UVs to grid_sample's [-1,1] convention, flipping V.
    gx = uv[..., 0] * 2 - 1
    gy = (1 - uv[..., 1]) * 2 - 1
    grid = torch.stack([gx, gy], dim=-1)
    out = F.grid_sample(texture, grid, mode=mode, padding_mode='border', align_corners=False)
    return out.permute(0, 2, 3, 1).unsqueeze(1)
# =============================================================================
# TEXTURED MESH MODEL
# =============================================================================
@dataclass
class MeshConfig:
    """Settings for loading one mesh and sizing its generated texture."""
    shape_path: str = 'shapes/bunny.obj'   # path to the mesh file to texture
    shape_scale: float = 0.6               # target bounding-sphere radius after normalization
    dy: float = 0.25                       # vertical (Y) offset applied after scaling
    texture_resolution: int = 512          # square UV texture side length in pixels
class TexturedMeshModel(nn.Module):
    """Mesh plus a learnable UV texture, rendered with the CPU rasterizer.

    The texture is a (1, 3, R, R) parameter sampled through per-face UV
    coordinates that come from the mesh file itself or, failing that, from
    an xatlas unwrap (cached under `cache_path`).
    """

    def __init__(self, config: MeshConfig, render_size=256, cache_path=None, device='cpu'):
        super().__init__()
        self.device = device
        self.config = config
        self.dy = config.dy
        self.mesh_scale = config.shape_scale
        self.texture_res = config.texture_resolution
        self.cache_path = cache_path
        self.proj = perspective_projection(np.pi / 3)
        self.mesh = Mesh(config.shape_path, device).normalize_mesh(True, config.shape_scale, config.dy)
        # Flat white starting texture; callers overwrite it during generation.
        texture = torch.ones(1, 3, self.texture_res, self.texture_res, device=device)
        self.texture_img = nn.Parameter(texture)
        self.vt, self.ft = self._init_uv()
        # Per-face UV triangles, shape (1, F, 3, 2), fed to the rasterizer
        # as interpolated attributes.
        self.face_attrs = self.vt.unsqueeze(0)[:, self.ft.long()]

    def _init_uv(self):
        """Return (vt, ft) UVs: disk cache -> mesh-file UVs -> xatlas unwrap."""
        if self.cache_path:
            vt_path = Path(self.cache_path) / 'vt.pth'
            ft_path = Path(self.cache_path) / 'ft.pth'
            if vt_path.exists() and ft_path.exists():
                return torch.load(vt_path).to(self.device), torch.load(ft_path).to(self.device)
        if self.mesh.vt is not None and self.mesh.vt.shape[0] > 0:
            return self.mesh.vt.to(self.device), self.mesh.ft.to(self.device)
        # No UVs available: unwrap with xatlas (slow — hence the cache above).
        import xatlas
        v_np = self.mesh.vertices.cpu().numpy()
        f_np = self.mesh.faces.int().cpu().numpy()
        atlas = xatlas.Atlas()
        atlas.add_mesh(v_np, f_np)
        opts = xatlas.ChartOptions()
        opts.max_iterations = 4
        atlas.generate(chart_options=opts)
        _, ft_np, vt_np = atlas[0]
        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
        ft = torch.from_numpy(ft_np.astype(np.int64)).to(self.device)
        if self.cache_path:
            os.makedirs(self.cache_path, exist_ok=True)
            torch.save(vt.cpu(), Path(self.cache_path) / 'vt.pth')
            torch.save(ft.cpu(), Path(self.cache_path) / 'ft.pth')
        return vt, ft

    def render(self, theta, phi, radius, dims=None):
        """Render the textured mesh from spherical camera angles.

        Returns a dict with 'image' (1,3,H,W) composited over white,
        'mask' (1,1,H,W), 'depth' (1,1,H,W; 0 background, visible surface
        mapped into [0.5, 1]), and 'render_cache' holding the rasterized
        UVs and per-pixel face indices.
        """
        dims = dims or (RENDER_SIZE, RENDER_SIZE)
        cam = camera_from_angles(torch.tensor(theta), torch.tensor(phi), radius, self.dy)
        verts_cam, verts_img, normals = prepare_vertices(
            self.mesh.vertices, self.mesh.faces, self.proj, cam)
        # BUGFIX: the previous code passed verts_cam[..., -1] — the
        # homogeneous w component, which is identically 1 after the affine
        # view transform — so the z-buffer never resolved occlusion and the
        # depth map was a flat silhouette.  The camera looks down -Z, so
        # -z_cam is the positive distance from the camera; smaller is
        # nearer, which matches rasterize()'s keep-minimum depth test.
        face_z = -verts_cam[:, :, :, 2]
        depth, _ = rasterize(dims[1], dims[0], face_z, verts_img, face_z.unsqueeze(-1))
        mask_d = depth != 0
        if mask_d.any():
            d_min, d_max = depth[mask_d].min(), depth[mask_d].max()
            if d_max > d_min:
                # Map visible depths into [0.5, 1] (near = 0.5, far = 1.0).
                # NOTE(review): MiDaS-style depth maps are usually inverted
                # (near = bright) — confirm which convention the consumer
                # of 'depth' expects.
                depth[mask_d] = 0.5 + 0.5 * (depth[mask_d] - d_min) / (d_max - d_min)
        uv_feats, face_idx = rasterize(dims[1], dims[0], face_z, verts_img, self.face_attrs)
        mask = (face_idx > -1).float()
        # Sample the texture at the rasterized UVs and composite over white.
        img_feats = texture_sample(uv_feats, self.texture_img).squeeze(1) * mask
        img_feats = img_feats + (1 - mask)
        return {
            'image': img_feats.permute(0, 3, 1, 2).clamp(0, 1),
            'mask': mask.permute(0, 3, 1, 2),
            'depth': depth.permute(0, 3, 1, 2),
            'render_cache': {'uv_features': uv_feats, 'face_idx': face_idx}
        }

    def export_mesh(self, path, name=''):
        """Write {name}mesh.obj / {name}mesh.mtl / {name}albedo.png into `path`."""
        os.makedirs(path, exist_ok=True)
        v_np = self.mesh.vertices.cpu().numpy()
        f_np = self.mesh.faces.int().cpu().numpy()
        vt_np = self.vt.cpu().numpy()
        ft_np = self.ft.cpu().numpy()
        tex = self.texture_img.permute(0, 2, 3, 1).clamp(0, 1)[0].detach().cpu().numpy()
        Image.fromarray((tex * 255).astype(np.uint8)).save(f'{path}/{name}albedo.png')
        with open(f'{path}/{name}mesh.obj', 'w') as fp:
            fp.write(f'mtllib {name}mesh.mtl\n')
            for v in v_np:
                fp.write(f'v {v[0]} {v[1]} {v[2]}\n')
            for v in vt_np:
                fp.write(f'vt {v[0]} {v[1]}\n')
            fp.write('usemtl mat0\n')
            # OBJ is 1-indexed; faces reference vertex/uv index pairs (v/vt).
            for i in range(len(f_np)):
                fp.write(f"f {f_np[i,0]+1}/{ft_np[i,0]+1} {f_np[i,1]+1}/{ft_np[i,1]+1} {f_np[i,2]+1}/{ft_np[i,2]+1}\n")
        with open(f'{path}/{name}mesh.mtl', 'w') as fp:
            fp.write('newmtl mat0\nKa 1 1 1\nKd 1 1 1\nKs 0 0 0\nillum 1\n')
            fp.write(f'map_Kd {name}albedo.png\n')
# =============================================================================
# SD PIPELINE (PyTorch + INT8 Quantization)
# =============================================================================
# NOTE: ONNX doesn't support Depth2Img pipeline (5 channels vs 4)
# Using PyTorch with INT8 quantization instead
# Module-level cache so the heavyweight pipeline loads only once per process.
sd_pipe = None
def load_pipeline():
    """Load (and cache) the Stable Diffusion 2 depth-to-image pipeline.

    Downloads SD_MODEL on first call, INT8-quantizes the UNet on CPU when
    optimum.quanto is available, keeps the VAE in FP32 on CPU, and stores
    the result in the module-global `sd_pipe`.  Re-raises any loading
    failure after printing a hint for auth errors.
    """
    global sd_pipe
    if sd_pipe is not None:
        return sd_pipe  # already loaded in this process
    print("\n[INFO] Loading SD-2-Depth pipeline (PyTorch + INT8)...")
    print("[INFO] Note: ONNX not supported for Depth2Img (5-channel UNet)")
    from diffusers import StableDiffusionDepth2ImgPipeline
    try:
        print(f"[1/2] Downloading {SD_MODEL}...")
        sd_pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
            SD_MODEL,
            torch_dtype=DTYPE,
        )
        print("[OK] Model downloaded")
        # Quantize on CPU for faster inference
        if DEVICE == "cpu":
            try:
                from optimum.quanto import quantize, freeze, qint8
                print("[2/2] Applying INT8 quantization to UNet...")
                quantize(sd_pipe.unet, weights=qint8)
                freeze(sd_pipe.unet)
                print("[OK] INT8 quantization applied (3-5x faster than FP32)")
            except ImportError:
                # Best effort: quantization is purely an optimization.
                print("[WARN] optimum.quanto not available, using FP32 (slower)")
        sd_pipe = sd_pipe.to(DEVICE)
        # Disable autocast on CPU to avoid color issues
        if DEVICE == "cpu":
            sd_pipe.set_progress_bar_config(disable=False)
            # Force FP32 for VAE to prevent color artifacts
            sd_pipe.vae = sd_pipe.vae.float()
        print("[OK] Pipeline ready!")
        return sd_pipe
    except Exception as e:
        print(f"[ERROR] Pipeline loading failed: {e}")
        if "401" in str(e) or "token" in str(e).lower():
            print("[ERROR] Authentication required. Set HF_TOKEN environment variable")
        raise
# =============================================================================
# MAIN PIPELINE
# =============================================================================
def dilate_texture(tex, mask, iterations=50):
    """Grow filled texture regions outward to cover unfilled texels.

    tex:  (3, H, W) float tensor of colors.
    mask: (H, W) tensor marking texels that already hold valid color.
    Each iteration paints the one-texel boundary ring with the mean color
    of its filled 8-neighbors (scipy morphology keeps this fast).
    Returns (dilated color tensor, updated fill mask tensor).
    """
    from scipy import ndimage
    colors = tex.clone().detach().numpy()
    have = mask.clone().detach().numpy().astype(bool)
    # 8-neighborhood kernel with the center excluded.
    ring_kernel = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=np.float32)
    for _ in range(iterations):
        if have.all():
            break
        # Boundary ring: unfilled texels adjacent to at least one filled one.
        ring = ndimage.binary_dilation(have) & ~have
        if not ring.any():
            break
        counts = ndimage.convolve(have.astype(np.float32), ring_kernel, mode='constant')
        paintable = ring & (counts > 0)
        for ch in range(3):
            sums = ndimage.convolve(colors[ch] * have, ring_kernel, mode='constant')
            colors[ch][paintable] = sums[paintable] / counts[paintable]
        have = have | ring
    return torch.from_numpy(colors).float(), torch.from_numpy(have)
def project_to_texture(tex, gen_img, uv, mask, blend=0.7, uv_mask=None):
    """Splat a rendered view's generated colors back into UV texture space.

    tex:     (1, 3, TH, TW) current texture.
    gen_img: (1, 3, H, W) image generated for this view.
    uv:      (1, H, W, 2) per-pixel UV coordinates from the rasterizer.
    mask:    (1, 1, H, W) view visibility mask.
    uv_mask: (TH, TW) bool tensor of texels covered by earlier views
             (None on the first view).
    Texels covered for the first time are overwritten; already-covered
    texels are blended toward the new view with weight `blend`.
    Returns (updated texture, updated uv_mask).
    """
    from scipy.interpolate import griddata
    _, _, TH, TW = tex.shape
    out_tex = tex.clone()
    visible = mask[0, 0].reshape(-1).detach().numpy() > 0.5
    if visible.sum() < 10:
        # Too few samples for a stable interpolation; leave the texture alone.
        return out_tex, uv_mask
    pix_uv = uv[0].reshape(-1, 2).detach().numpy()[visible]
    pix_rgb = gen_img[0].permute(1, 2, 0).reshape(-1, 3).detach().numpy()[visible]
    # Destination texture-space grid; source V is flipped to match image rows.
    grid_x, grid_y = np.meshgrid(np.linspace(0, 1, TW), np.linspace(0, 1, TH))
    src = pix_uv.copy()
    src[:, 1] = 1 - src[:, 1]
    # Scattered -> regular-grid interpolation, one channel at a time.
    splat = np.zeros((TH, TW, 3), dtype=np.float32)
    for c in range(3):
        splat[:, :, c] = griddata(
            src, pix_rgb[:, c],
            (grid_x, grid_y), method='linear', fill_value=np.nan
        )
    # Texels outside the sample hull come back NaN -> not covered.
    covered = ~np.isnan(splat[:, :, 0])
    splat = np.nan_to_num(splat, nan=0.5)
    if uv_mask is None:
        uv_mask = torch.zeros(TH, TW, dtype=torch.bool)
    covered_t = torch.from_numpy(covered)
    splat_t = torch.from_numpy(splat).permute(2, 0, 1).float()
    fresh = covered_t & ~uv_mask
    seen = covered_t & uv_mask
    for c in range(3):
        out_tex[0, c][fresh] = splat_t[c][fresh]
        out_tex[0, c][seen] = blend * splat_t[c][seen] + (1 - blend) * out_tex[0, c][seen]
    return out_tex, uv_mask | covered_t
def finalize_texture(tex, uv_mask, iterations=100):
    """Return a copy of `tex` (1,3,H,W) with texels missing from `uv_mask` filled by dilation."""
    filled_colors, _ = dilate_texture(tex[0].clone(), uv_mask, iterations=iterations)
    result = tex.clone()
    result[0] = filled_colors
    return result
def generate_texture(mesh_file, prompt, num_views, num_steps, seed, progress=gr.Progress()):
    """Full pipeline: mesh -> UV unwrap -> per-view SD generation -> UV texture.

    Returns (texture PIL image, list of preview images, path to result ZIP).
    The temp dir is deliberately left on disk because the returned file
    paths point into it.
    """
    if mesh_file is None:
        raise gr.Error("Please upload a mesh file!")
    if not prompt.strip():
        raise gr.Error("Please enter a text prompt!")
    temp_dir = tempfile.mkdtemp()
    try:
        mesh_ext = os.path.splitext(mesh_file)[1].lower()
        mesh_path = os.path.join(temp_dir, f"mesh{mesh_ext}")
        shutil.copy(mesh_file, mesh_path)
        progress(0.1, desc="Creating UV map...")
        config = MeshConfig(shape_path=mesh_path, texture_resolution=TEXTURE_RESOLUTION)
        model = TexturedMeshModel(config, RENDER_SIZE, Path(temp_dir) / 'cache', 'cpu')
        progress(0.2, desc="Loading SD-2-Depth...")
        pipe = load_pipeline()
        # (elevation, azimuth) pairs: front/right/back/left, then low/high front.
        viewpoints = [(0.5, 0.0), (0.5, np.pi/2), (0.5, np.pi), (0.5, -np.pi/2), (0.2, 0.0), (0.8, 0.0)][:num_views]
        with torch.no_grad():
            model.texture_img.fill_(0.5) # Start with neutral gray instead of white
        previews = []
        for i, (theta, phi) in enumerate(viewpoints):
            progress(0.3 + 0.5 * i / len(viewpoints), desc=f"View {i+1}/{len(viewpoints)}...")
            result = model.render(theta, phi, 2.0, (RENDER_SIZE, RENDER_SIZE))
            depth = result['depth'][0, 0].cpu().numpy()
            mask = result['mask'][0, 0].cpu().numpy()
            # Re-normalize depth over the visible region only, then zero the background.
            if mask.sum() > 0:
                d_vis = depth[mask > 0]
                d_min, d_max = d_vis.min(), d_vis.max()
                if d_max > d_min:
                    depth = (depth - d_min) / (d_max - d_min)
                depth = depth * mask
            depth_img = Image.fromarray((np.clip(depth, 0, 1) * 255).astype(np.uint8)).convert('RGB')
            gen = torch.Generator(device=DEVICE).manual_seed(int(seed)) # Same seed for consistency
            # SD-2-Depth: native depth conditioning (same as original TEXTure)
            steps = int(num_steps) if num_steps else NUM_INFERENCE_STEPS
            direction = ["front", "right side", "back", "left side"][i % 4]
            # NOTE(review): the depth render is passed as the *init image*; the
            # pipeline derives its own depth_map from it — confirm this is
            # intended rather than passing depth_map= explicitly.
            textured = pipe(
                prompt=f"{prompt}, {direction} view, consistent style",
                image=depth_img,
                strength=0.85, # Slightly less strength for more depth adherence
                num_inference_steps=steps,
                guidance_scale=7.5,
                generator=gen
            ).images[0]
            previews.append(textured)
            uv = result['render_cache']['uv_features']
            gen_t = torch.tensor(np.array(textured)).float().permute(2, 0, 1).unsqueeze(0) / 255.0
            with torch.no_grad():
                # Track UV coverage across views
                if i == 0:
                    uv_mask = None
                model.texture_img.data, uv_mask = project_to_texture(
                    model.texture_img, gen_t, uv, result['mask'],
                    blend=0.5, uv_mask=uv_mask
                )
        progress(0.85, desc="Filling gaps...")
        # Final dilation to fill any remaining gaps
        with torch.no_grad():
            model.texture_img.data = finalize_texture(model.texture_img, uv_mask, iterations=150)
        progress(0.9, desc="Saving...")
        tex_np = model.texture_img[0].permute(1, 2, 0).clamp(0, 1).detach().numpy()
        tex_img = Image.fromarray((tex_np * 255).astype(np.uint8))
        tex_img.save(f'{temp_dir}/uv_texture.png')
        # Render 3D preview with texture
        preview_result = model.render(0.4, 0.3, 2.5, (512, 512))
        preview_np = preview_result['image'][0].permute(1, 2, 0).clamp(0, 1).detach().cpu().numpy()
        preview_img = Image.fromarray((preview_np * 255).astype(np.uint8))
        previews.insert(0, preview_img) # Add 3D preview as first image
        model.export_mesh(f'{temp_dir}/mesh', '')
        zip_path = f'{temp_dir}/textured_mesh.zip'
        with zipfile.ZipFile(zip_path, 'w') as zf:
            for f in ['mesh/albedo.png', 'mesh/mesh.obj', 'mesh/mesh.mtl', 'uv_texture.png']:
                if os.path.exists(f'{temp_dir}/{f}'):
                    zf.write(f'{temp_dir}/{f}', os.path.basename(f))
        progress(1.0, desc="Done!")
        return tex_img, previews, zip_path
    except Exception as e:
        # NOTE(review): re-wrapping loses the original traceback and leaves
        # temp_dir on disk on failure — consider logging before raising.
        raise gr.Error(f"Error: {str(e)}")
# =============================================================================
# GRADIO UI
# =============================================================================
# ---- Gradio UI wiring (built at import time; launched under __main__) ----
with gr.Blocks(title="TEXTure CPU Lite") as demo:
    gr.Markdown("""# TEXTure CPU Lite
Generate UV texture maps for 3D meshes using text prompts.
⚠️ **Quality Notice:** This is a simplified CPU-only demo. Results are significantly worse than the [original TEXTure paper](https://texturepaper.github.io/TEXTurePaper/).
**Why it looks bad:**
- No Kaolin GPU rasterizer → using slow software renderer with lower precision
- No proper view weighting → seams between views are visible
- No texture inpainting → blotchy patches instead of smooth transitions
- No refinement passes → single-pass projection loses detail
- INT8 quantization on CPU → color artifacts possible
**For production quality:** Use the [original TEXTure repo](https://github.com/TEXTurePaper/TEXTurePaper) with a GPU.
""")
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            mesh_in = gr.File(label="3D Mesh (.obj, .stl, .ply, .glb)", file_types=[".obj", ".stl", ".ply", ".glb", ".off"])
            prompt_in = gr.Textbox(label="Texture Prompt", placeholder="ceramic with blue and white pattern", lines=2)
            with gr.Row():
                views_in = gr.Slider(2, 6, value=4, step=1, label="Views")
                steps_in = gr.Slider(5, 25, value=20, step=1, label="Steps (5=fast, 20=quality)")
            with gr.Row():
                seed_in = gr.Number(value=42, label="Seed", precision=0)
            btn = gr.Button("Generate", variant="primary")
            gr.Markdown("**CPU Time:** ~1.5 min/view @ 10 steps, ~3 min/view @ 20 steps")
        # Right column: outputs.
        with gr.Column():
            tex_out = gr.Image(label="UV Texture", type="pil")
            gallery_out = gr.Gallery(label="3D Preview + Generated Views", columns=2, height=250)
            zip_out = gr.File(label="Download (ZIP)")
    btn.click(generate_texture, [mesh_in, prompt_in, views_in, steps_in, seed_in], [tex_out, gallery_out, zip_out])
    gr.Markdown("**Credits:** [TEXTure Paper](https://texturepaper.github.io/TEXTurePaper/), [SD-2-Depth](https://huggingface.co/radames/stable-diffusion-2-depth-img2img), [xatlas](https://github.com/jpcy/xatlas)")
if __name__ == "__main__":
    # Small queue: each job monopolizes the CPU for minutes.
    demo.queue(max_size=2).launch(ssr_mode=False)