# ###### Test harness: smoke-test the encoder on mock data ######
import warnings
from typing import Optional

import hydra
import torch
from omegaconf import DictConfig, OmegaConf

# Local imports keep their original relative order (duplicates removed) so
# that any import-time side effects / cycles behave exactly as before.
from src.dataset.data_module import DataModule
from src.config import load_typed_root_config
from src.model.encoder.encoder_depthsplat_revise import EncoderDepthSplat_test
from src.model.encoder.encoder_depthsplat import EncoderDepthSplat
from src.global_cfg import set_cfg
from src.loss import get_losses
from src.misc.LocalLogger import LocalLogger
from src.misc.step_tracker import StepTracker
from src.misc.wandb_tools import update_checkpoint_path
from src.misc.resume_ckpt import find_latest_ckpt
from src.model.decoder import get_decoder
from src.model.encoder import get_encoder
from src.model.model_wrapper import ModelWrapper


def _mock_views(batch_size, num_views, h, w, generator=None):
    """Build one mock view group (the "context" or "target" dict) of random tensors.

    Tensors are drawn in a fixed key order (extrinsics, intrinsics, image,
    near, far, index) so a given RNG state always yields the same batch.

    Args:
        batch_size: Leading batch dimension.
        num_views: Number of views in this group.
        h: Image height in pixels.
        w: Image width in pixels.
        generator: Optional ``torch.Generator``; ``None`` uses the global RNG.

    Returns:
        Dict with keys ``extrinsics`` (B, V, 4, 4), ``intrinsics`` (B, V, 3, 3),
        ``image`` (B, V, 3, h, w), ``near`` / ``far`` (B, V) and ``index``
        (B, V, int64 in [0, 100)).
    """
    shape = (batch_size, num_views)
    # Dict literals evaluate left to right, preserving the RNG draw order.
    return {
        "extrinsics": torch.randn(*shape, 4, 4, dtype=torch.float32, generator=generator),
        "intrinsics": torch.randn(*shape, 3, 3, dtype=torch.float32, generator=generator),
        "image": torch.randn(*shape, 3, h, w, dtype=torch.float32, generator=generator),
        # near >= 1.0 and far >= 10.0, keeping the depth range positive.
        "near": torch.abs(torch.randn(*shape, dtype=torch.float32, generator=generator)) + 1.0,
        "far": torch.abs(torch.randn(*shape, dtype=torch.float32, generator=generator)) + 10.0,
        "index": torch.randint(0, 100, shape, dtype=torch.int64, generator=generator),
    }


def generate_mock_batch(
    batch_size=1,
    context_views=6,
    target_views=8,
    image_size=(256, 448),
    seed: Optional[int] = None,
):
    """Generate a mock ``{"context": ..., "target": ...}`` batch of random tensors.

    NOTE: camera matrices are unconstrained Gaussian noise, not valid poses or
    intrinsics; the batch is meant for shape / plumbing checks only.

    Args:
        batch_size: Batch size. Default 1.
        context_views: Number of views in ``context``. Default 6.
        target_views: Number of views in ``target``. Default 8.
        image_size: Image size as (height, width). Default (256, 448).
        seed: Optional seed for a private ``torch.Generator``, making the batch
            reproducible without touching the global RNG. ``None`` (default)
            draws from the global RNG, as before.

    Returns:
        Dict with ``context`` and ``target`` entries, each as built by
        ``_mock_views``.
    """
    h, w = image_size
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    return {
        "context": _mock_views(batch_size, context_views, h, w, generator),
        "target": _mock_views(batch_size, target_views, h, w, generator),
    }


def _move_to_device(data, device):
    """Recursively move tensors in nested dicts / lists / tuples to ``device``.

    Non-tensor leaves are returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: _move_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [_move_to_device(v, device) for v in data]
    if isinstance(data, tuple):
        return tuple(_move_to_device(v, device) for v in data)
    return data


@hydra.main(
    version_base=None,
    config_path="/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/config",
    config_name="main",
)
def test(cfg_dict: DictConfig):
    """Build ``EncoderDepthSplat_test`` from the Hydra config and run one
    forward pass on a mock batch, printing ``test_over`` on success."""
    scene_path = "/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/datasets/dl3dv/train/000000.torch"
    # Load data. map_location="cpu" lets a file saved from GPU load on
    # CPU-only hosts.
    # NOTE(review): scene_data is not used below (the forward pass runs on
    # mock data); drop this load if it is not needed as a readability check.
    scene_data = torch.load(scene_path, map_location="cpu")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the typed config and build the model under test.
    cfg = load_typed_root_config(cfg_dict)
    encoder = EncoderDepthSplat_test(cfg.model.encoder).to(device)
    # Swap in the baseline encoder to compare:
    # encoder = EncoderDepthSplat(cfg.model.encoder).to(device)

    batch = _move_to_device(generate_mock_batch(), device)

    # Forward-only smoke test: skip building the autograd graph to save memory.
    with torch.no_grad():
        gs = encoder(batch["context"], 1, False)
    print("test_over")


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    torch.set_float32_matmul_precision('high')
    test()