# depthsplat/src/test_model.py
# Uploaded via huggingface_hub by Yeqing0814 (commit a6dd040, verified).
###### Test additions ######
from src.dataset.data_module import DataModule
import hydra
from omegaconf import DictConfig, OmegaConf
from src.config import load_typed_root_config
import warnings
import torch
from src.model.encoder.encoder_depthsplat_revise import EncoderDepthSplat_test
from src.model.encoder.encoder_depthsplat import EncoderDepthSplat
from src.config import load_typed_root_config
from src.dataset.data_module import DataModule
from src.global_cfg import set_cfg
from src.loss import get_losses
from src.misc.LocalLogger import LocalLogger
from src.misc.step_tracker import StepTracker
from src.misc.wandb_tools import update_checkpoint_path
from src.misc.resume_ckpt import find_latest_ckpt
from src.model.decoder import get_decoder
from src.model.encoder import get_encoder
from src.model.model_wrapper import ModelWrapper
def generate_mock_batch(batch_size=1, context_views=6, target_views=8, image_size=(256, 448)):
    """
    Generate a random mock batch for smoke-testing the encoder.

    Args:
        batch_size: Number of samples in the batch (default 1).
        context_views: Number of views in the "context" group (default 6).
        target_views: Number of views in the "target" group (default 8).
        image_size: Image size as (height, width) (default (256, 448)).

    Returns:
        A dict with "context" and "target" sub-dicts, each containing:
            extrinsics: (B, V, 4, 4) float32
            intrinsics: (B, V, 3, 3) float32
            image:      (B, V, 3, H, W) float32
            near:       (B, V) float32, values >= 1.0
            far:        (B, V) float32, values >= 10.0
            index:      (B, V) int64 in [0, 100)

    Note: values are purely random — extrinsics/intrinsics are NOT valid
    camera matrices, and near < far is likely but not guaranteed per
    element. Sufficient only for shape/dtype smoke tests.
    """
    h, w = image_size

    def _make_view_group(num_views):
        # One group of camera/view tensors; context and target share this
        # exact layout, so build both from a single helper.
        return {
            "extrinsics": torch.randn(batch_size, num_views, 4, 4, dtype=torch.float32),
            "intrinsics": torch.randn(batch_size, num_views, 3, 3, dtype=torch.float32),
            "image": torch.randn(batch_size, num_views, 3, h, w, dtype=torch.float32),
            # Shift |N(0,1)| so depth bounds stay positive.
            "near": torch.abs(torch.randn(batch_size, num_views, dtype=torch.float32)) + 1.0,
            "far": torch.abs(torch.randn(batch_size, num_views, dtype=torch.float32)) + 10.0,
            "index": torch.randint(0, 100, (batch_size, num_views), dtype=torch.int64),
        }

    # Preserve the original RNG call order: context first, then target.
    return {
        "context": _make_view_group(context_views),
        "target": _make_view_group(target_views),
    }
@hydra.main(
    version_base=None,
    config_path="/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/config",
    config_name="main",
)
def test(cfg_dict: DictConfig):
    """Smoke-test EncoderDepthSplat_test on a randomly generated mock batch.

    Loads the typed config via hydra, builds the encoder, runs one forward
    pass on random data, and prints a completion marker. Raises whatever
    the encoder raises; no return value.
    """
    # NOTE(review): the original loaded a hard-coded .torch scene file here
    # and never used the result — removed as dead I/O (it also crashed when
    # the file was absent). Restore with map_location/weights_only if the
    # scene data is actually needed.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Convert the raw DictConfig into the project's typed config.
    cfg = load_typed_root_config(cfg_dict)
    encoder = EncoderDepthSplat_test(cfg.model.encoder).to(device)  # model under test
    # encoder = EncoderDepthSplat(cfg.model.encoder).to(device)
    batch = generate_mock_batch()

    def move_to_device(data):
        # Recursively move every tensor in a nested dict/list structure
        # onto `device`; non-tensor leaves pass through unchanged.
        if isinstance(data, torch.Tensor):
            return data.to(device)
        if isinstance(data, dict):
            return {k: move_to_device(v) for k, v in data.items()}
        if isinstance(data, list):
            return [move_to_device(v) for v in data]
        return data

    batch = move_to_device(batch)
    # Forward pass; positional args (global_step=1, deterministic=False)
    # presumed from the encoder's signature — TODO confirm. The returned
    # gaussians were unused, so the result is discarded.
    encoder(batch["context"], 1, False)
    print("test_over")
if __name__ == "__main__":
    # Suppress all warnings for cleaner test output.
    warnings.filterwarnings("ignore")
    # Allow lower-precision (e.g. TF32) float32 matmuls for speed.
    torch.set_float32_matmul_precision('high')
    # Hydra-decorated entry point; config args come from the CLI/config dir.
    test()