|
|
import logging |
|
|
import os |
|
|
import random |
|
|
import time |
|
|
import warnings |
|
|
from dataclasses import dataclass, field |
|
|
from shutil import copy, rmtree |
|
|
|
|
|
import torch |
|
|
import tyro |
|
|
from huggingface_hub import snapshot_download |
|
|
from packaging import version |
|
|
|
|
|
|
|
|
# Keep console output readable: drop every FutureWarning and raise the log
# threshold of the chatty third-party loggers to ERROR.
warnings.simplefilter("ignore", FutureWarning)
for _noisy_logger in ("transformers", "diffusers"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
|
|
|
|
|
|
|
|
# Compatibility shim for ``torchvision.transforms.functional_tensor``.
#
# The gate must look at torchvision's own version: "0.16" is a torchvision
# milestone, and comparing it against ``torch.__version__`` is true for every
# modern torch release, which made the check meaningless.
# NOTE(review): the 0.16 threshold is kept as-is; confirm it against the
# torchvision release that dropped ``functional_tensor``.
import torchvision

if version.parse(torchvision.__version__) >= version.parse("0.16"):
    import sys
    import types

    import torchvision.transforms.functional as TF

    # Register a stand-in module that exposes only ``rgb_to_grayscale`` under
    # the legacy import path (presumably for a dependency imported below that
    # still imports it from there -- verify which one).
    functional_tensor = types.ModuleType(
        "torchvision.transforms.functional_tensor"
    )
    functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
|
|
|
|
|
from gsplat.distributed import cli |
|
|
from txt2panoimg import Text2360PanoramaImagePipeline |
|
|
from embodied_gen.trainer.gsplat_trainer import ( |
|
|
DefaultStrategy, |
|
|
GsplatTrainConfig, |
|
|
) |
|
|
from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint |
|
|
from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline |
|
|
from embodied_gen.utils.config import Pano2MeshSRConfig |
|
|
from embodied_gen.utils.gaussian import restore_scene_scale_and_position |
|
|
from embodied_gen.utils.gpt_clients import GPT_CLIENT |
|
|
from embodied_gen.utils.log import logger |
|
|
from embodied_gen.utils.process_media import is_image_file, parse_text_prompts |
|
|
from embodied_gen.validators.quality_checkers import ( |
|
|
PanoHeightEstimator, |
|
|
PanoImageOccChecker, |
|
|
) |
|
|
|
|
|
# Public API of this module.
__all__ = ["generate_pano_image", "entrypoint"]
|
|
|
|
|
|
|
|
def _default_gs3d_config() -> GsplatTrainConfig:
    """Build the default 3DGS training config used by ``Scene3DGenConfig``."""
    return GsplatTrainConfig(
        strategy=DefaultStrategy(verbose=True),
        max_steps=4000,
        init_opa=0.9,
        opacity_reg=2e-3,
        sh_degree=0,
        means_lr=1e-4,
        scales_lr=1e-3,
    )


@dataclass
class Scene3DGenConfig:
    """Command-line configuration for text-to-3D scene generation."""

    # Prompt inputs; expanded with ``parse_text_prompts``, one scene per
    # resulting prompt.
    prompts: list[str]
    # Root output directory; each scene is written to ``scene_XXXX`` below it.
    output_dir: str
    # Seed of the first panorama attempt; when None a random seed in
    # [0, 100000] is drawn for every prompt.
    seed: int | None = None
    # Height handed to ``restore_scene_scale_and_position``; when None it is
    # estimated from the panorama by ``PanoHeightEstimator``.
    real_height: float | None = None
    # Stop after the panorama image (skip mesh and 3DGS stages).
    pano_image_only: bool = False
    # Skip the panorama quality check, so no regeneration is triggered.
    disable_pano_check: bool = False
    # Keep the 3DGS result directory and the intermediate gs data file.
    keep_middle_result: bool = False
    # Maximum number of panorama generation attempts per prompt.
    n_retry: int = 7
    # 3DGS training configuration.
    gs3d: GsplatTrainConfig = field(default_factory=_default_gs3d_config)
|
|
|
|
|
|
|
|
def generate_pano_image(
    prompt: str,
    output_path: str,
    pipeline,
    seed: int,
    n_retry: int,
    checker=None,
    num_inference_steps: int = 40,
) -> None:
    """Generate a panorama image for ``prompt`` and save it to ``output_path``.

    The image is regenerated with a fresh random seed until ``checker``
    accepts it or ``n_retry`` attempts have been made. Every attempt
    overwrites ``output_path``, so the file always holds the latest image,
    even when the final attempt was rejected.

    Args:
        prompt: Text description of the scene. Image-file prompts are not
            supported yet.
        output_path: Destination passed to the generated image's ``save``.
        pipeline: Callable taking a dict with the keys ``prompt``,
            ``num_inference_steps``, ``upscale`` and ``seed`` and returning
            an object with a ``save(path)`` method.
        seed: Seed of the first attempt; later attempts draw a random seed
            in [0, 100000].
        n_retry: Maximum number of generation attempts; must be >= 1.
        checker: Optional callable returning ``(flag, response)`` for an
            image. A ``flag`` of ``True`` or ``None`` accepts the image.
            When ``checker`` is None the first image is accepted.
        num_inference_steps: Forwarded to ``pipeline``.

    Raises:
        ValueError: If ``n_retry`` is smaller than 1 (no image could be
            produced).
        NotImplementedError: If ``prompt`` refers to an image file.
    """
    if n_retry < 1:
        raise ValueError(f"n_retry must be at least 1, got {n_retry}.")
    # The mode check does not depend on the attempt, so do it once up front.
    if is_image_file(prompt):
        raise NotImplementedError("Image mode not implemented yet.")

    txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
    for attempt in range(1, n_retry + 1):
        logger.info(
            f"GEN Panorama: Retry {attempt}/{n_retry} for prompt: {prompt}, seed: {seed}"
        )
        pano_image = pipeline(
            {
                "prompt": txt_prompt,
                "num_inference_steps": num_inference_steps,
                "upscale": False,
                "seed": seed,
            }
        )
        pano_image.save(output_path)
        if checker is None:
            return

        flag, response = checker(pano_image)
        logger.warning(f"{response}, image saved in {output_path}")
        # Identity checks on purpose: only a literal True/None accepts.
        if flag is True or flag is None:
            return

        seed = random.randint(0, 100000)

    # Every attempt was rejected; make that visible instead of returning
    # silently with an image the checker did not accept.
    logger.warning(
        f"Panorama check still failing after {n_retry} attempts; "
        f"keeping the last image in {output_path}"
    )
|
|
|
|
|
|
|
|
def entrypoint(*args, **kwargs):
    """Generate 3D scenes from text prompts via the command line.

    The configuration is parsed by ``tyro.cli`` into a ``Scene3DGenConfig``;
    ``*args`` and ``**kwargs`` are accepted but not used.

    For every prompt a ``scene_XXXX`` folder is created under
    ``cfg.output_dir`` and the following stages run:

    1. Write the prompt to ``prompt.txt`` and generate ``pano_image.png``.
    2. Unless ``cfg.pano_image_only`` is set: run the pano-to-mesh pipeline,
       train 3DGS, copy ``gs_model.ply``, ``video.mp4`` and
       ``gsplat_cfg.yml`` into the scene folder, optionally delete the
       intermediate results, and rescale the scene with
       ``restore_scene_scale_and_position``.
    """
    cfg = tyro.cli(Scene3DGenConfig)

    # Build every model once; they are shared by all prompts.
    model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    pano_pipe = Text2360PanoramaImagePipeline(
        model_path, torch_dtype=torch.float16, device="cuda"
    )
    panomesh_cfg = Pano2MeshSRConfig()
    pano2mesh_pipe = Pano2MeshSRPipeline(panomesh_cfg)
    pano_checker = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
    height_estimator = PanoHeightEstimator(GPT_CLIENT)

    # Scale the training schedule exactly once. Calling this inside the
    # per-prompt loop would re-apply the factor for every prompt.
    # NOTE(review): assumes ``adjust_steps`` rescales ``cfg.gs3d`` in place
    # (gsplat convention) -- confirm against GsplatTrainConfig.
    if not cfg.pano_image_only:
        cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)

    prompts = parse_text_prompts(cfg.prompts)
    for idx, prompt in enumerate(prompts):
        start_time = time.time()
        output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
        os.makedirs(output_dir, exist_ok=True)
        pano_path = os.path.join(output_dir, "pano_image.png")
        # Explicit UTF-8 so non-ASCII prompts do not depend on the locale.
        with open(f"{output_dir}/prompt.txt", "w", encoding="utf-8") as f:
            f.write(prompt)

        generate_pano_image(
            prompt,
            pano_path,
            pano_pipe,
            cfg.seed if cfg.seed is not None else random.randint(0, 100000),
            cfg.n_retry,
            checker=None if cfg.disable_pano_check else pano_checker,
        )

        if cfg.pano_image_only:
            continue

        logger.info("GEN and REPAIR Mesh from Panorama...")
        pano2mesh_pipe(pano_path, output_dir)

        logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
        cfg.gs3d.data_dir = output_dir
        cfg.gs3d.result_dir = f"{output_dir}/gaussian"
        # Reset the default device before training (presumably an earlier
        # stage may have changed it -- verify).
        torch.set_default_device("cpu")
        cli(gsplat_entrypoint, cfg.gs3d, verbose=True)

        # Promote the final-step artifacts into the scene folder.
        result_dir = cfg.gs3d.result_dir
        last_step = cfg.gs3d.max_steps - 1
        copy(
            f"{result_dir}/ply/point_cloud_{last_step}.ply",
            f"{output_dir}/gs_model.ply",
        )
        copy(
            f"{result_dir}/renders/video_step{last_step}.mp4",
            f"{output_dir}/video.mp4",
        )
        copy(f"{result_dir}/cfg.yml", f"{output_dir}/gsplat_cfg.yml")

        if not cfg.keep_middle_result:
            # Best-effort cleanup: a missing intermediate file must not
            # abort a run whose outputs were already produced.
            rmtree(result_dir, ignore_errors=True)
            try:
                os.remove(f"{output_dir}/{panomesh_cfg.gs_data_file}")
            except FileNotFoundError:
                pass

        if cfg.real_height is None:
            real_height = height_estimator(pano_path)
        else:
            real_height = cfg.real_height
        gs_path = os.path.join(output_dir, "gs_model.ply")
        # mesh_model.ply is presumably written by the pano-to-mesh stage.
        mesh_path = os.path.join(output_dir, "mesh_model.ply")
        restore_scene_scale_and_position(real_height, mesh_path, gs_path)

        elapsed_time = (time.time() - start_time) / 60
        logger.info(
            f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
        )
|
|
|
|
|
|
|
|
# Run the scene-generation CLI when this file is executed as a script.
if __name__ == "__main__":
    entrypoint()
|
|
|