|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import threading |
|
|
import argparse |
|
|
import tempfile |
|
|
import shutil |
|
|
from typing import Generator, Optional, Tuple |
|
|
import logging |
|
|
|
|
|
import gradio as gr |
|
|
import spaces |
|
|
from huggingface_hub import hf_hub_download |
|
|
import torch |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
|
|
|
|
|
from gradio_models import GradioGaussianSplatting2D, StreamingResults |
|
|
from utils.misc_utils import load_cfg |
|
|
from main import get_log_dir |
|
|
|
|
|
|
|
|
class TrainingState:
    """Manages the state of training sessions.

    A single instance is shared between the Gradio event handlers and the
    background thread that runs ``train_model``.
    """

    def __init__(self):
        # True while a run is active; cleared by reset(), by stop_training()
        # and when train_model() finishes.
        self.is_training = False
        # Worker thread running the current job, if any.
        self.training_thread = None
        # Model instance of the current run, if any.
        self.model = None
        # Scratch directory holding the uploaded input image.
        self.temp_dir = None
        # Progress, images and metrics the model writes while training.
        self.results = StreamingResults()

    def reset(self):
        """Discard the previous run's scratch directory, references and results."""
        self.is_training = False
        if self.temp_dir and os.path.exists(self.temp_dir):
            # Best effort: a failed cleanup must not prevent starting a new run.
            shutil.rmtree(self.temp_dir, ignore_errors=True)
        self.temp_dir = None
        # Drop references to the previous run so its model (and whatever
        # tensors it holds) can be garbage-collected.
        self.model = None
        self.training_thread = None
        self.results = StreamingResults()
|
|
|
|
|
|
|
|
|
|
|
# Process-wide training state, shared by every Gradio handler below and by
# the background thread that runs train_model().
training_state: TrainingState = TrainingState()
|
|
|
|
|
|
|
|
def ensure_models_available(
    repo_id: str = "blanchon/image-gs-models-utils",
    models_dir: str = "models",
) -> bool:
    """Download the pretrained weight files the app needs if any are missing.

    Each missing file is fetched from the Hugging Face Hub model repository
    ``repo_id`` and copied to ``<models_dir>/<relative path>``. This is best
    effort: failures are reported on stdout, not raised, so the app can still
    start. Each file is attempted even if an earlier one failed.

    Args:
        repo_id: Hub model repository that hosts the weight files.
        models_dir: Local directory the files are copied into.

    Returns:
        True if every required file is present afterwards, False otherwise.
    """
    # The same relative path is used inside the Hub repo and under models_dir.
    relative_files = (
        "emlnet/res_decoder.pth",
        "emlnet/res_imagenet.pth",
        "emlnet/res_places.pth",
        "torch/checkpoints/alexnet-owt-7be5be79.pth",
    )

    missing = [
        rel_path
        for rel_path in relative_files
        if not os.path.exists(os.path.join(models_dir, rel_path))
    ]
    if not missing:
        print("✅ Model files are already available locally.")
        return True

    print("📥 Downloading model files from HuggingFace...")
    failures = []
    for remote_file in missing:
        local_file = os.path.join(models_dir, remote_file)
        try:
            os.makedirs(os.path.dirname(local_file), exist_ok=True)
            print(f"📥 Downloading {remote_file} -> {local_file}...")
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=remote_file,
                repo_type="model",
            )
            # hf_hub_download returns the path of the cached file; copy it to
            # the fixed layout the rest of the app loads from.
            shutil.copy2(downloaded_path, local_file)
        except Exception as e:
            failures.append(remote_file)
            print(f"❌ Failed to download {remote_file}: {e}")

    if failures:
        print("⚠️ The app may not work properly without these model files.")
        return False

    print("✅ Model files downloaded successfully!")
    return True
|
|
|
|
|
|
|
|
|
|
|
# Import-time setup. First fetch any missing weight files into models/. Then
# point torch.hub's directory at models/torch. torch.hub looks for downloaded
# checkpoints under <hub_dir>/checkpoints, which is where the AlexNet weights
# were copied. Both paths are relative to the current working directory.
ensure_models_available()
torch.hub.set_dir("models/torch")
|
|
|
|
|
|
|
|
def create_args_from_config(
    image_path: str,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> argparse.Namespace:
    """Build the training Namespace from the default config plus UI values.

    Defaults are loaded from ``cfgs/default.yaml`` and parsed with an empty
    argv. Every Gradio input is then written over its matching attribute,
    ``eval`` is forced to False, and ``log_dir`` is derived from the result.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cfg_parser = load_cfg(cfg_path="cfgs/default.yaml", parser=argparse.ArgumentParser())
    namespace = cfg_parser.parse_args([])

    # Applied in this order; get_log_dir() below sees the final values.
    overrides = {
        "input_path": image_path,
        "exp_name": exp_name,
        "num_gaussians": num_gaussians,
        "quantize": quantize,
        "pos_bits": pos_bits,
        "scale_bits": scale_bits,
        "rot_bits": rot_bits,
        "feat_bits": feat_bits,
        "init_mode": init_mode,
        "init_random_ratio": init_random_ratio,
        "max_steps": max_steps,
        "vis_gaussians": vis_gaussians,
        "save_image_steps": save_image_steps,
        "l1_loss_ratio": l1_loss_ratio,
        "l2_loss_ratio": l2_loss_ratio,
        "ssim_loss_ratio": ssim_loss_ratio,
        "pos_lr": pos_lr,
        "scale_lr": scale_lr,
        "rot_lr": rot_lr,
        "feat_lr": feat_lr,
        "disable_lr_schedule": disable_lr_schedule,
        "disable_prog_optim": disable_prog_optim,
        "eval": False,
    }
    for attr_name, attr_value in overrides.items():
        setattr(namespace, attr_name, attr_value)

    namespace.log_dir = get_log_dir(namespace)
    return namespace
|
|
|
|
|
|
|
|
@spaces.GPU(duration=300)
def train_model(args: argparse.Namespace) -> None:
    """Build the Gaussian-splatting model and run its optimization.

    Runs on the worker thread started by ``start_training_and_stream``. The
    ``spaces.GPU`` decorator requests a ZeroGPU allocation for up to 300 s.
    Progress reaches the UI through ``training_state.results``. Exceptions are
    recorded in ``training_state.results.training_logs`` and logged instead of
    raised, because nothing joins this thread to observe them.
    ``training_state.is_training`` is always cleared on exit.

    Args:
        args: Namespace produced by ``create_args_from_config``.
    """
    try:
        model = GradioGaussianSplatting2D(args, training_state.results)
        training_state.model = model
        # stop_training() may have run while the model was still being built.
        # It clears is_training but had no model to flag, so honor it here.
        if not training_state.is_training:
            training_state.results.training_logs.append(
                "Training cancelled before optimization started."
            )
            return
        model.optimize()
    except Exception as e:
        import traceback

        tb_text = traceback.format_exc()
        training_state.results.training_logs.append(f"ERROR: {str(e)}")
        training_state.results.training_logs.append(f"TRACEBACK: {tb_text}")
        # logging.exception records the message plus the active traceback.
        logging.exception("Training failed: %s", e)
    finally:
        training_state.is_training = False
|
|
|
|
|
|
|
|
# One streamed UI update: (status, logs, initialization map, current render,
# current Gaussian-ID image, Start-button update, Stop-button update).
_StreamUpdate = Tuple[
    str,
    str,
    Optional[Image.Image],
    Optional[Image.Image],
    Optional[Image.Image],
    dict,
    dict,
]


def _button_updates(training_active: bool) -> Tuple[dict, dict]:
    """Return (Start, Stop) button updates: Start only when idle, Stop only when training."""
    # A bare bool yielded to a Button output would replace its value (label).
    # gr.update(interactive=...) toggles whether the button can be clicked.
    return (
        gr.update(interactive=not training_active),
        gr.update(interactive=training_active),
    )


def _metric_value(metrics, key: str) -> float:
    """Read one metric as a float, or NaN if it has not been reported yet."""
    try:
        return float(metrics[key])
    except (KeyError, IndexError, TypeError, ValueError):
        return float("nan")


def _format_training_logs(results) -> str:
    """Join the accumulated log lines and append a one-line metrics summary."""
    if not results.training_logs:
        return "Waiting for training to start..."
    logs_text = "\n".join(results.training_logs)
    if results.step > 0:
        metrics = results.metrics
        logs_text += (
            f"\nCurrent: Step {results.step}/{results.total_steps} | "
            f"PSNR: {_metric_value(metrics, 'psnr'):.2f} | "
            f"SSIM: {_metric_value(metrics, 'ssim'):.4f} | "
            f"Loss: {_metric_value(metrics, 'loss'):.4f}"
        )
    return logs_text


def _training_snapshot() -> _StreamUpdate:
    """Build one UI update tuple from the current contents of training_state."""
    results = training_state.results
    if results.is_complete:
        status = "✅ Training completed successfully!"
        active = False
    elif not training_state.is_training:
        status = "⏹️ Training stopped."
        active = False
    else:
        status = (
            f"🔄 Training in progress... Step {results.step}/{results.total_steps}"
        )
        active = True
    start_update, stop_update = _button_updates(active)
    return (
        status,
        _format_training_logs(results),
        results.initialization_map,
        results.current_render,
        results.current_gaussian_id,
        start_update,
        stop_update,
    )


def start_training_and_stream(
    image_file,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> Generator[_StreamUpdate, None, None]:
    """Start a training run on a background thread and stream its progress.

    Steps:
    1. Save the uploaded image (``image_file``, expected to be a PIL image or
       None) into a fresh temporary directory.
    2. Build the training arguments from the remaining UI values.
    3. Start ``train_model`` on a new thread.
    4. Poll ``training_state`` every 0.5 s, yielding one update per poll and
       a final update when the run ends.

    Yields:
        (status, logs, initialization map, current render, Gaussian-ID image,
        Start-button update, Stop-button update) tuples, matching the
        ``outputs`` the Start button is wired to.
    """
    previous_worker = training_state.training_thread
    if training_state.is_training:
        yield ("Training is already in progress!", "", None, None, None, *_button_updates(True))
        return
    if previous_worker is not None and previous_worker.is_alive():
        # A stop was requested but the old worker has not exited yet. reset()
        # would delete the temp directory that worker may still be reading.
        yield (
            "The previous training run is still stopping. Please try again in a moment.",
            "",
            None,
            None,
            None,
            *_button_updates(False),
        )
        return

    if image_file is None:
        yield ("Please upload an image first!", "", None, None, None, *_button_updates(False))
        return

    try:
        training_state.reset()
        training_state.temp_dir = tempfile.mkdtemp()
        image_path = os.path.join(training_state.temp_dir, "input_image.png")
        image_file.save(image_path)

        args = create_args_from_config(
            image_path=image_path,
            exp_name=exp_name,
            num_gaussians=num_gaussians,
            quantize=quantize,
            pos_bits=pos_bits,
            scale_bits=scale_bits,
            rot_bits=rot_bits,
            feat_bits=feat_bits,
            init_mode=init_mode,
            init_random_ratio=init_random_ratio,
            max_steps=max_steps,
            vis_gaussians=vis_gaussians,
            save_image_steps=save_image_steps,
            l1_loss_ratio=l1_loss_ratio,
            l2_loss_ratio=l2_loss_ratio,
            ssim_loss_ratio=ssim_loss_ratio,
            pos_lr=pos_lr,
            scale_lr=scale_lr,
            rot_lr=rot_lr,
            feat_lr=feat_lr,
            disable_lr_schedule=disable_lr_schedule,
            disable_prog_optim=disable_prog_optim,
        )
        # NOTE(review): the trainer presumably reads the image from
        # data_root/input_path. log_dir was already derived inside
        # create_args_from_config from the absolute image_path; confirm that
        # is intended.
        args.data_root = training_state.temp_dir
        args.input_path = "input_image.png"

        training_state.is_training = True
        worker = threading.Thread(target=train_model, args=(args,))
        training_state.training_thread = worker
        worker.start()
    except Exception as e:
        # No worker is running yet, so discarding the scratch directory is safe.
        training_state.reset()
        yield (
            f"Failed to start training: {str(e)}",
            "",
            None,
            None,
            None,
            *_button_updates(False),
        )
        return

    yield (
        "Training started! Check the progress below.",
        "Initializing training...",
        None,
        None,
        None,
        *_button_updates(True),
    )

    try:
        while True:
            results = training_state.results
            if results.is_complete:
                break
            if not training_state.is_training:
                if worker.is_alive():
                    # stop_training() cleared the flag while the worker is still running.
                    results.training_logs.append("🛑 Training stopped by user request")
                break
            if not worker.is_alive():
                # The worker exited without clearing the flag (e.g. the
                # spaces.GPU wrapper raised before train_model's own
                # try/finally ran). Make the shared state agree, and stop
                # polling instead of looping forever.
                results.training_logs.append("ERROR: training thread exited unexpectedly")
                training_state.is_training = False
                break
            yield _training_snapshot()
            time.sleep(0.5)

        # Always finish with a terminal snapshot so the buttons return to idle.
        yield _training_snapshot()
    except Exception as e:
        # The worker may still be using temp_dir. Report the problem without
        # resetting the shared state.
        logging.exception("Failed while streaming training progress")
        yield (
            f"Error while streaming training progress: {str(e)}",
            "",
            None,
            None,
            None,
            *_button_updates(training_state.is_training),
        )
|
|
|
|
|
|
|
|
def stop_training() -> str:
    """Request that the running training job stop.

    Clears ``training_state.is_training``, which the progress stream treats
    as a stop signal. Logs the request, and sets ``stop_requested`` on the
    model if one has been created. The model is expected to poll that flag;
    confirm in GradioGaussianSplatting2D.

    Returns:
        A status message for the Status textbox.
    """
    if not training_state.is_training:
        return "No training in progress."

    training_state.is_training = False
    training_state.results.training_logs.append(
        "🛑 STOP: Training stop requested by user..."
    )

    # The model does not exist yet if the worker is still constructing it.
    # Use an explicit None check rather than relying on the model's truthiness.
    if training_state.model is not None:
        training_state.model.stop_requested = True

    return "Training stop requested. Will complete current step and stop."
|
|
|
|
|
|
|
|
def get_final_results() -> Tuple[Optional[Image.Image], Optional[str]]:
    """Return the final rendered image and the saved checkpoint path.

    Both values are read directly from ``training_state.results``. They are
    presumably None until a run has produced them.
    """
    results = training_state.results
    return results.final_render, results.final_checkpoint_path
|
|
|
|
|
|
|
|
def browse_step_results(
    step: int,
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Return the render and Gaussian-ID image stored nearest to ``step``.

    Browsing is only enabled once training has completed. Before that, or
    when no intermediate renders were stored, ``(None, None)`` is returned.
    On ties, the earliest-inserted step wins.
    """
    results = training_state.results
    if not results.is_complete:
        return None, None

    recorded_steps = [*results.step_renders.keys()]
    if not recorded_steps:
        return None, None

    def distance(candidate):
        return abs(candidate - step)

    nearest = min(recorded_steps, key=distance)
    return results.step_renders.get(nearest), results.step_gaussian_ids.get(nearest)
|
|
|
|
|
|
|
|
def update_step_slider_after_training() -> gr.Slider:
    """Return a Step slider configured for browsing stored intermediate results.

    Before training completes, or when no intermediate renders were stored,
    a disabled slider with the default 0-10000 range is returned. Otherwise
    the slider spans the first to last stored step, starts on the last one,
    and is interactive.
    """
    from functools import reduce
    from math import gcd

    def _disabled_slider(info: str) -> gr.Slider:
        # Placeholder shown while there is nothing to browse.
        return gr.Slider(
            minimum=0,
            maximum=10000,
            value=0,
            step=100,
            label="Browse Training Steps",
            info=info,
            interactive=False,
        )

    results = training_state.results
    if not results.is_complete:
        return _disabled_slider("Training not complete yet")

    # Sort explicitly instead of relying on dict insertion order.
    available_steps = sorted(results.step_renders.keys())
    if not available_steps:
        return _disabled_slider("No training steps available")

    min_step = available_steps[0]
    max_step = available_steps[-1]

    # Slider stops sit at min_step + k * step_size. Using the GCD of the
    # offsets puts a stop on every stored step even when the steps are
    # unevenly spaced (e.g. 1, 200, 400). With a single stored step the GCD
    # is 0, so fall back to 100. Assumes integer step keys.
    step_size = reduce(gcd, (s - min_step for s in available_steps[1:]), 0) or 100

    return gr.Slider(
        minimum=min_step,
        maximum=max_step,
        value=max_step,
        step=step_size,
        label="Browse Training Steps",
        info=f"Browse results from steps {min_step}-{max_step} (interactive)",
        interactive=True,
    )
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="Image-GS: 2D Gaussian Splatting", theme=gr.themes.Soft()) as demo:
    # Layout follows the order in which components are created inside these
    # context managers, so keep statement order when editing.
    gr.Markdown("""
    # Image-GS: Content-Adaptive Image Representation via 2D Gaussians

    Upload an image and configure parameters to train a 2D Gaussian Splatting representation.
    """)

    with gr.Row():
        # Left column: input image, hyper-parameters and run controls.
        with gr.Column(scale=1):
            gr.Markdown("## Configuration")

            image_input = gr.Image(
                label="Input Image",
                type="pil",
                height=300,
                sources=["upload"],
                show_label=True,
            )

            with gr.Group():
                gr.Markdown("### Basic Parameters")
                exp_name = gr.Textbox(
                    label="Experiment Name",
                    value="gradio_experiment",
                    info="Name for this training run",
                )
                # step=100 (was 1000): with minimum=100 a step of 1000 only
                # allows 100, 1100, 2100, ..., so neither the default 10000 nor
                # the maximum 50000 was a reachable slider position.
                num_gaussians = gr.Slider(
                    minimum=100,
                    maximum=50000,
                    value=10000,
                    step=100,
                    label="Number of Gaussians",
                    info="Number of Gaussians (for compression rate control). More = higher quality but slower training",
                )
                max_steps = gr.Slider(
                    minimum=100,
                    maximum=20000,
                    value=10000,
                    step=100,
                    label="Maximum Training Steps",
                    info="Maximum number of optimization steps. Default: 10000",
                )

            with gr.Group():
                gr.Markdown("### Quantization")
                quantize = gr.Checkbox(
                    label="Enable Quantization",
                    value=False,
                    info="Enable bit precision control of Gaussian parameters. Reduces memory usage.",
                )
                with gr.Row():
                    pos_bits = gr.Slider(
                        4,
                        32,
                        16,
                        step=1,
                        label="Position Bits",
                        info="Bit precision of individual coordinate dimension",
                    )
                    scale_bits = gr.Slider(
                        4,
                        32,
                        16,
                        step=1,
                        label="Scale Bits",
                        info="Bit precision of individual scale dimension",
                    )
                with gr.Row():
                    rot_bits = gr.Slider(
                        4,
                        32,
                        16,
                        step=1,
                        label="Rotation Bits",
                        info="Bit precision of Gaussian orientation angle",
                    )
                    feat_bits = gr.Slider(
                        4,
                        32,
                        16,
                        step=1,
                        label="Feature Bits",
                        info="Bit precision of individual feature dimension",
                    )

            with gr.Group():
                gr.Markdown("### Initialization")
                init_mode = gr.Radio(
                    choices=["gradient", "saliency", "random"],
                    value="saliency",
                    label="Initialization Mode",
                    info="Gaussian position initialization mode. Gradient uses image gradients, saliency uses attention maps.",
                )
                init_random_ratio = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Random Ratio",
                    info="Ratio of Gaussians with randomly initialized position (default: 0.3)",
                )

            with gr.Accordion("Advanced Parameters", open=False):
                gr.Markdown("#### Loss Weights")
                with gr.Row():
                    l1_loss_ratio = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="L1 Loss")
                    l2_loss_ratio = gr.Slider(0.0, 2.0, 0.0, step=0.1, label="L2 Loss")
                    ssim_loss_ratio = gr.Slider(
                        0.0, 1.0, 0.1, step=0.01, label="SSIM Loss"
                    )

                gr.Markdown("#### Learning Rates")
                with gr.Row():
                    pos_lr = gr.Number(value=5e-4, label="Position LR", precision=6)
                    scale_lr = gr.Number(value=2e-3, label="Scale LR", precision=6)
                with gr.Row():
                    rot_lr = gr.Number(value=2e-3, label="Rotation LR", precision=6)
                    feat_lr = gr.Number(value=5e-3, label="Feature LR", precision=6)

                gr.Markdown("#### Optimization")
                disable_lr_schedule = gr.Checkbox(
                    label="Disable LR Schedule",
                    value=False,
                    info="Keep learning rate constant",
                )
                disable_prog_optim = gr.Checkbox(
                    label="Disable Progressive Optimization",
                    value=False,
                    info="Don't add Gaussians during training",
                )

            with gr.Group():
                gr.Markdown("### Visualization")
                vis_gaussians = gr.Checkbox(
                    label="Visualize Gaussians",
                    value=True,
                    info="Visualize Gaussians during optimization (default: True)",
                )
                # The info text previously claimed "default: 100", which is
                # below this slider's minimum and differs from its value.
                save_image_steps = gr.Slider(
                    minimum=200,
                    maximum=10000,
                    value=200,
                    step=100,
                    label="Save Image Every N Steps",
                    info="Frequency of rendering intermediate results during optimization (default: 200)",
                )

            with gr.Row():
                start_btn = gr.Button("Start Training", variant="primary", size="lg")
                stop_btn = gr.Button("Stop Training", variant="stop", size="lg")

            status_text = gr.Textbox(label="Status", interactive=False, lines=2)

        # Right column: live progress, per-step browsing and final outputs.
        with gr.Column(scale=2):
            gr.Markdown("## Training Progress")

            progress_logs = gr.Textbox(
                label="Training Logs",
                lines=10,
                max_lines=15,
                interactive=False,
                autoscroll=True,
            )

            gr.Markdown("### Initialization Map")
            initialization_map = gr.Image(
                label="Initialization Map",
                type="pil",
                height=200,
            )

            gr.Markdown("### Current Training Results")
            with gr.Row():
                current_render = gr.Image(
                    label="Current Render",
                    type="pil",
                    height=300,
                    show_label=True,
                    show_download_button=True,
                )
                current_gaussian_id = gr.Image(
                    label="Gaussian ID",
                    type="pil",
                    height=300,
                    show_label=True,
                    show_download_button=True,
                )

            # Starts disabled; "Enable Step Browsing" reconfigures it after a run.
            step_slider = gr.Slider(
                minimum=0,
                maximum=10000,
                value=0,
                step=100,
                label="Browse Training Steps",
                info="Slide to view results from different training steps (disabled during training)",
                interactive=False,
            )

            gr.Markdown("## Final Results")
            with gr.Row():
                final_render = gr.Image(label="Final Render", type="pil", height=300)
                final_checkpoint = gr.File(label="Download Final Checkpoint (.pt)")

            with gr.Row():
                results_btn = gr.Button("Load Final Results", size="lg")
                enable_slider_btn = gr.Button(
                    "Enable Step Browsing", size="lg", variant="secondary"
                )

    # Event wiring. The input order must match start_training_and_stream's
    # parameters, and the output order must match the tuples it yields.
    start_btn.click(
        fn=start_training_and_stream,
        inputs=[
            image_input,
            exp_name,
            num_gaussians,
            quantize,
            pos_bits,
            scale_bits,
            rot_bits,
            feat_bits,
            init_mode,
            init_random_ratio,
            max_steps,
            vis_gaussians,
            save_image_steps,
            l1_loss_ratio,
            l2_loss_ratio,
            ssim_loss_ratio,
            pos_lr,
            scale_lr,
            rot_lr,
            feat_lr,
            disable_lr_schedule,
            disable_prog_optim,
        ],
        outputs=[
            status_text,
            progress_logs,
            initialization_map,
            current_render,
            current_gaussian_id,
            start_btn,
            stop_btn,
        ],
    )

    stop_btn.click(fn=stop_training, outputs=status_text)

    results_btn.click(fn=get_final_results, outputs=[final_render, final_checkpoint])

    enable_slider_btn.click(fn=update_step_slider_after_training, outputs=[step_slider])

    step_slider.change(
        fn=browse_step_results,
        inputs=[step_slider],
        outputs=[current_render, current_gaussian_id],
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow at most 20 pending requests in the queue, then serve on all
    # interfaces at port 7860 without a public share link.
    demo.queue(max_size=20)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|