# File size: 3,956 Bytes
# a6dd040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
######测试添加#######
from src.dataset.data_module import DataModule
import hydra
from omegaconf import DictConfig, OmegaConf
from src.config import load_typed_root_config
import warnings
import torch
from src.model.encoder.encoder_depthsplat_revise import EncoderDepthSplat_test 
from src.model.encoder.encoder_depthsplat import EncoderDepthSplat 


from src.config import load_typed_root_config
from src.dataset.data_module import DataModule
from src.global_cfg import set_cfg
from src.loss import get_losses
from src.misc.LocalLogger import LocalLogger
from src.misc.step_tracker import StepTracker
from src.misc.wandb_tools import update_checkpoint_path
from src.misc.resume_ckpt import find_latest_ckpt
from src.model.decoder import get_decoder
from src.model.encoder import get_encoder
from src.model.model_wrapper import ModelWrapper



def generate_mock_batch(batch_size=1, context_views=6, target_views=8, image_size=(256, 448)):
    """
    Generate a mock batch of random camera/image data for encoder smoke tests.

    Args:
        batch_size: number of samples per batch (default 1).
        context_views: number of views in the "context" entry (default 6).
        target_views: number of views in the "target" entry (default 8).
        image_size: image (height, width); default (256, 448).
            (The previous docstring said (270, 480), which did not match
            the actual signature default — fixed here.)

    Returns:
        A dict with "context" and "target" keys; each maps to a dict of:
            extrinsics: (b, v, 4, 4) float32
            intrinsics: (b, v, 3, 3) float32
            image:      (b, v, 3, h, w) float32
            near:       (b, v) float32, values >= 1.0
            far:        (b, v) float32, values >= 10.0
            index:      (b, v) int64 in [0, 100)
    """
    h, w = image_size

    def _make_views(num_views):
        # One set of mock tensors for `num_views` views; shared by
        # context and target so the two stay structurally identical.
        return {
            "extrinsics": torch.randn(batch_size, num_views, 4, 4, dtype=torch.float32),
            "intrinsics": torch.randn(batch_size, num_views, 3, 3, dtype=torch.float32),
            "image": torch.randn(batch_size, num_views, 3, h, w, dtype=torch.float32),
            # abs() + offset keeps the depth bounds positive, with
            # near (>= 1.0) below far (>= 10.0) on average.
            "near": torch.abs(torch.randn(batch_size, num_views, dtype=torch.float32)) + 1.0,
            "far": torch.abs(torch.randn(batch_size, num_views, dtype=torch.float32)) + 10.0,
            "index": torch.randint(0, 100, (batch_size, num_views), dtype=torch.int64),
        }

    return {
        "context": _make_views(context_views),
        "target": _make_views(target_views),
    }


@hydra.main(
    version_base=None,
    config_path="/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/config",
    config_name="main",
)
def test(cfg_dict: DictConfig):
    """Smoke-test: build the encoder from the hydra config and run one
    forward pass on a randomly generated mock batch.

    Args:
        cfg_dict: raw hydra config; converted to the typed root config below.
    """
    
    scene_path  = "/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/datasets/dl3dv/train/000000.torch"
    # Load sample scene data.
    # NOTE(review): `scene_data` is never used afterwards — the mock batch
    # below is what actually feeds the encoder; confirm whether this load
    # can be dropped (it also pins the run to a hardcoded absolute path).
    scene_data = torch.load(scene_path)
    
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Convert the raw DictConfig into the project's typed config object.
    cfg = load_typed_root_config(cfg_dict)
    
    encoder = EncoderDepthSplat_test(cfg.model.encoder).to(device)  # model under test
    # encoder = EncoderDepthSplat(cfg.model.encoder).to(device) 
    
    batch = generate_mock_batch()
    
    
    
    
    # Recursively move every tensor in the (nested) batch onto `device`;
    # non-tensor leaves are returned unchanged.
    def move_to_device(data):
        if isinstance(data, torch.Tensor):
            return data.to(device)
        elif isinstance(data, dict):
            return {k: move_to_device(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [move_to_device(v) for v in data]
        else:
            return data
    batch = move_to_device(batch)
    
    # Forward pass over the context views only.
    # NOTE(review): the meaning of the positional args `1` and `False` is
    # not visible here — confirm against EncoderDepthSplat_test.forward.
    gs = encoder(batch["context"], 1, False)
    
    print("test_over")
    

if __name__ == "__main__":
    # Quiet library deprecation noise so the smoke-test output stays readable.
    warnings.filterwarnings("ignore")
    # Trade a little float32 matmul precision for speed (enables TF32 on
    # supporting GPUs, per the torch documentation).
    torch.set_float32_matmul_precision('high')
    test()