import os

import torch

from loss import lossAV
from model.talkNetModel import talkNetModel


class TalkNetCPU(torch.nn.Module):
    """CPU-only wrapper for TalkNet export."""

    def __init__(self, ckpt_path: str):
        super().__init__()
        self.model = talkNetModel()
        self.lossAV = lossAV()
        self.ckpt_path = ckpt_path

    def load_parameters(self) -> None:
        """Load state_dict saved by talkNet.saveParameters (handles module. prefix)."""
        self_state = self.state_dict()
        loaded_state = torch.load(self.ckpt_path, map_location="cpu")

        for name, param in loaded_state.items():
            orig_name = name
            target_name = name
            if target_name not in self_state:
                target_name = target_name.replace("module.", "")
            if target_name not in self_state:
                print(f"{orig_name} is not in the model.")
                continue
            if self_state[target_name].shape != loaded_state[orig_name].shape:
                print(
                    f"Shape mismatch {orig_name}: "
                    f"model {self_state[target_name].shape}, "
                    f"loaded {loaded_state[orig_name].shape}"
                )
                continue
            self_state[target_name].copy_(param)

    def forward(self, audio_mfcc: torch.Tensor, video_gray: torch.Tensor) -> torch.Tensor:
        """

        audio_mfcc: (B, Ta, 13)

        video_gray: (B, Tv, 224, 224)

        returns logits: (B*, 2)

        """
        audio_embed = self.model.forward_audio_frontend(audio_mfcc)
        visual_embed = self.model.forward_visual_frontend(video_gray)
        audio_embed, visual_embed = self.model.forward_cross_attention(
            audio_embed, visual_embed
        )
        av_embed = self.model.forward_audio_visual_backend(audio_embed, visual_embed)
        logits = self.lossAV.FC(av_embed)
        return logits
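

# Optional parity check (a sketch, not part of the original export flow): run
# the exported graph with ONNX Runtime and compare it against the PyTorch
# forward pass. Assumes `onnxruntime` is installed; the imports are local so
# the export script itself does not depend on it.
def check_onnx_output(onnx_path: str, model: "TalkNetCPU") -> None:
    import numpy as np
    import onnxruntime as ort

    # Same shapes as the export dummies; any (Ta, Tv) works via dynamic_axes.
    audio = torch.randn(1, 100, 13)
    video = torch.randn(1, 25, 112, 112)
    with torch.no_grad():
        torch_logits = model(audio, video).numpy()

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (onnx_logits,) = sess.run(
        ["logits"], {"audio_mfcc": audio.numpy(), "video_gray": video.numpy()}
    )
    # Loose tolerances: FP32 kernels differ slightly between the two runtimes.
    np.testing.assert_allclose(torch_logits, onnx_logits, rtol=1e-3, atol=1e-4)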


def main() -> None:
    ckpt_path = os.environ.get("CKPT_PATH", "model/pretrain_TalkSet.model")
    out_path = os.environ.get("OUT_PATH", "talknet_asd_cpu.onnx")

    model = TalkNetCPU(ckpt_path)
    model.load_parameters()
    model.eval()

    # Dummy inputs only to build the graph; real lengths are dynamic via dynamic_axes.
    dummy_audio = torch.randn(1, 100, 13)  # ~1s MFCC (100 frames)
    # The model expects 112x112 inputs: demoTalkNet center-crops the 224x224 face to 112x112.
    dummy_video = torch.randn(1, 25, 112, 112)  # 25 frames of 112x112 gray crops

    torch.onnx.export(
        model,
        (dummy_audio, dummy_video),
        out_path,
        input_names=["audio_mfcc", "video_gray"],
        output_names=["logits"],
        dynamic_axes={
            "audio_mfcc": {0: "batch", 1: "time_audio"},
            "video_gray": {0: "batch", 1: "time_video"},
            "logits": {0: "time_any"},
        },
        opset_version=14,
    )
    print(f"Saved ONNX to {out_path}")


if __name__ == "__main__":
    main()