| | import os
|
| |
|
| | import torch
|
| |
|
| | from loss import lossAV
|
| | from model.talkNetModel import talkNetModel
|
| |
|
| |
|
class TalkNetCPU(torch.nn.Module):
    """CPU-only wrapper around TalkNet for ONNX export.

    Combines the TalkNet backbone with the final FC layer of ``lossAV`` so
    the exported graph maps (audio MFCC, grayscale face crops) directly to
    per-frame active-speaker logits, with no loss computation in the graph.
    """

    def __init__(self, ckpt_path: str):
        super().__init__()
        self.model = talkNetModel()
        self.lossAV = lossAV()
        # Path is only stored here; call load_parameters() explicitly so
        # construction stays cheap and side-effect free.
        self.ckpt_path = ckpt_path

    def load_parameters(self) -> None:
        """Load a state_dict saved by talkNet.saveParameters.

        Strips the ``module.`` prefix left by DataParallel checkpoints and,
        mirroring the upstream TalkNet loader, skips (with a warning)
        entries that are missing or shape-mismatched instead of raising.
        """
        self_state = self.state_dict()
        # weights_only=True: the checkpoint is a plain state_dict, so
        # restrict unpickling to tensors instead of arbitrary objects
        # (avoids pickle code execution on an untrusted file).
        # NOTE(review): requires torch >= 1.13; drop the kwarg if older
        # torch must be supported.
        loaded_state = torch.load(
            self.ckpt_path, map_location="cpu", weights_only=True
        )

        for name, param in loaded_state.items():
            target_name = name
            if target_name not in self_state:
                # DataParallel saves keys as "module.<name>".
                target_name = target_name.replace("module.", "")
                if target_name not in self_state:
                    print(f"{name} is not in the model.")
                    continue
            if self_state[target_name].shape != param.shape:
                print(
                    f"Shape mismatch {name}: "
                    f"model {self_state[target_name].shape}, "
                    f"loaded {param.shape}"
                )
                continue
            # In-place copy so the module's own tensors are updated.
            self_state[target_name].copy_(param)

    def forward(self, audio_mfcc: torch.Tensor, video_gray: torch.Tensor) -> torch.Tensor:
        """Map raw features to active-speaker logits.

        audio_mfcc: (B, Ta, 13) MFCC features.
        video_gray: (B, Tv, H, W) grayscale face crops.
            NOTE(review): the export dummy in main() uses 112x112 crops,
            not 224x224 as the original comment claimed — confirm the
            expected crop size against the visual frontend.
        returns: (B*T, 2) logits from the lossAV classifier head.
        """
        audio_embed = self.model.forward_audio_frontend(audio_mfcc)
        visual_embed = self.model.forward_visual_frontend(video_gray)
        audio_embed, visual_embed = self.model.forward_cross_attention(
            audio_embed, visual_embed
        )
        av_embed = self.model.forward_audio_visual_backend(audio_embed, visual_embed)
        # Use only the FC head of lossAV: the ONNX graph should end at the
        # raw scores, not at a training-time loss value.
        logits = self.lossAV.FC(av_embed)
        return logits
|
| |
|
| |
|
def main() -> None:
    """Export the CPU TalkNet model to ONNX.

    Checkpoint and output paths come from the CKPT_PATH / OUT_PATH
    environment variables, with defaults matching the repo layout.
    """
    ckpt_path = os.environ.get("CKPT_PATH", "model/pretrain_TalkSet.model")
    out_path = os.environ.get("OUT_PATH", "talknet_asd_cpu.onnx")

    model = TalkNetCPU(ckpt_path)
    model.load_parameters()
    model.eval()

    # Dummy tracing inputs: 100 MFCC frames and 25 grayscale 112x112 face
    # crops — presumably one second of 25 fps video with 4 audio frames per
    # video frame (TODO confirm the 4:1 audio/video frame ratio).
    dummy_audio = torch.randn(1, 100, 13)
    dummy_video = torch.randn(1, 25, 112, 112)

    # no_grad keeps autograd bookkeeping out of the traced graph.
    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy_audio, dummy_video),
            out_path,
            input_names=["audio_mfcc", "video_gray"],
            output_names=["logits"],
            # Batch and time axes stay dynamic so real clips of any length
            # can be fed to the exported model.
            dynamic_axes={
                "audio_mfcc": {0: "batch", 1: "time_audio"},
                "video_gray": {0: "batch", 1: "time_video"},
                "logits": {0: "time_any"},
            },
            opset_version=14,
        )
    print(f"Saved ONNX to {out_path}")
|
| |
|
| |
|
# Script entry point: run the export only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
| |
|
| |
|