|
|
|
|
|
from src.dataset.data_module import DataModule |
|
|
import hydra |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
from src.config import load_typed_root_config |
|
|
import warnings |
|
|
import torch |
|
|
from src.model.encoder.encoder_depthsplat_revise import EncoderDepthSplat_test |
|
|
from src.model.encoder.encoder_depthsplat import EncoderDepthSplat |
|
|
|
|
|
|
|
|
from src.config import load_typed_root_config |
|
|
from src.dataset.data_module import DataModule |
|
|
from src.global_cfg import set_cfg |
|
|
from src.loss import get_losses |
|
|
from src.misc.LocalLogger import LocalLogger |
|
|
from src.misc.step_tracker import StepTracker |
|
|
from src.misc.wandb_tools import update_checkpoint_path |
|
|
from src.misc.resume_ckpt import find_latest_ckpt |
|
|
from src.model.decoder import get_decoder |
|
|
from src.model.encoder import get_encoder |
|
|
from src.model.model_wrapper import ModelWrapper |
|
|
|
|
|
|
|
|
|
|
|
def _mock_views(batch_size, num_views, h, w, generator=None):
    """Build one mock view dictionary (used for both "context" and "target").

    Every tensor has leading dimensions ``(batch_size, num_views)``. Values are
    plain random draws: extrinsics/intrinsics are NOT valid camera matrices,
    ``near`` is shifted so it is >= 1 and ``far`` so it is >= 10.

    Args:
        batch_size: number of batch elements.
        num_views: number of views per batch element.
        h: image height in pixels.
        w: image width in pixels.
        generator: optional ``torch.Generator``; ``None`` uses the global RNG.

    Returns:
        Dict with keys "extrinsics", "intrinsics", "image", "near", "far",
        "index".
    """
    # Draw order matches the original inline construction so that results
    # from the global RNG are unchanged when no generator is supplied.
    return {
        "extrinsics": torch.randn(
            batch_size, num_views, 4, 4, generator=generator, dtype=torch.float32
        ),
        "intrinsics": torch.randn(
            batch_size, num_views, 3, 3, generator=generator, dtype=torch.float32
        ),
        "image": torch.randn(
            batch_size, num_views, 3, h, w, generator=generator, dtype=torch.float32
        ),
        "near": torch.abs(
            torch.randn(batch_size, num_views, generator=generator, dtype=torch.float32)
        ) + 1.0,
        "far": torch.abs(
            torch.randn(batch_size, num_views, generator=generator, dtype=torch.float32)
        ) + 10.0,
        "index": torch.randint(
            0, 100, (batch_size, num_views), generator=generator, dtype=torch.int64
        ),
    }


def generate_mock_batch(batch_size=1, context_views=6, target_views=8, image_size=(256, 448), seed=None):
    """Generate a synthetic batch with "context" and "target" view dictionaries.

    Args:
        batch_size: batch size, default 1.
        context_views: number of views in "context", default 6.
        target_views: number of views in "target", default 8.
        image_size: image size as (height, width), default (256, 448).
        seed: optional integer seed. When given, a private ``torch.Generator``
            is used so the batch is reproducible and the global RNG is left
            untouched. When ``None`` (default), the global RNG is used.

    Returns:
        ``{"context": {...}, "target": {...}}``. Each inner dict has the keys
        "extrinsics" (B, V, 4, 4), "intrinsics" (B, V, 3, 3),
        "image" (B, V, 3, H, W), "near" (B, V), "far" (B, V) and
        "index" (B, V, int64).
    """
    h, w = image_size

    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)

    # Context is drawn before target, which preserves the original RNG order.
    return {
        "context": _mock_views(batch_size, context_views, h, w, generator),
        "target": _mock_views(batch_size, target_views, h, w, generator),
    }
|
|
|
|
|
|
|
|
@hydra.main(
    version_base=None,
    # NOTE(review): absolute, machine-specific path; this only works on the
    # original author's filesystem.
    config_path="/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/config",
    config_name="main",
)
def test(cfg_dict: DictConfig):
    """Smoke-test ``EncoderDepthSplat_test`` on a synthetic batch.

    Hydra composes ``cfg_dict`` from the ``main`` config. The config is
    converted into the typed root config, and the encoder is built from
    ``cfg.model.encoder``. The encoder is then called once on the "context"
    part of a mock batch from ``generate_mock_batch``.

    The original version also ran ``torch.load`` on a DL3DV scene file and
    never used the result. That dead load is removed: it cost I/O and made
    the script fail when the file was absent.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    cfg = load_typed_root_config(cfg_dict)
    encoder = EncoderDepthSplat_test(cfg.model.encoder).to(device)

    def move_to_device(data):
        """Recursively move tensors inside dicts/lists/tuples to ``device``.

        Non-tensor leaves are returned unchanged.
        """
        if isinstance(data, torch.Tensor):
            return data.to(device)
        if isinstance(data, dict):
            return {k: move_to_device(v) for k, v in data.items()}
        if isinstance(data, list):
            return [move_to_device(v) for v in data]
        if isinstance(data, tuple):
            return tuple(move_to_device(v) for v in data)
        return data

    batch = move_to_device(generate_mock_batch())

    # NOTE(review): the positional arguments (1, False) are forwarded
    # unchanged. Their meaning comes from the encoder's forward signature,
    # which is not visible here. `gs` is kept so it can be inspected in a
    # debugger.
    gs = encoder(batch["context"], 1, False)

    print("test_over")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Suppress all warnings for this smoke-test run. A bare "ignore" filter
    # matches every category, module and message.
    warnings.simplefilter("ignore")
    # Per the PyTorch docs, "high" lets float32 matmuls use faster internal
    # precisions (e.g. TF32) where the hardware supports them.
    torch.set_float32_matmul_precision(precision="high")
    # Hydra's decorator supplies the composed config, so call with no args.
    test()