import os

import torch

from loss import lossAV
from model.talkNetModel import talkNetModel


class TalkNetCPU(torch.nn.Module):
    """CPU-only wrapper for TalkNet export.

    Bundles the TalkNet backbone with the final FC head of ``lossAV`` so the
    exported ONNX graph maps (audio MFCC, grayscale video crops) directly to
    active-speaker logits.
    """

    def __init__(self, ckpt_path: str):
        super().__init__()
        self.model = talkNetModel()
        self.lossAV = lossAV()
        # Path is stored only; the checkpoint is read lazily in load_parameters().
        self.ckpt_path = ckpt_path

    def load_parameters(self) -> None:
        """Load state_dict saved by talkNet.saveParameters (handles module. prefix).

        Checkpoints saved from a DataParallel model carry a ``module.`` prefix,
        which is stripped when the raw key is absent. Keys missing from this
        model, or whose shapes mismatch, are skipped with a warning.
        """
        self_state = self.state_dict()
        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (weights_only=True is safer where
        # the installed torch supports it).
        loaded_state = torch.load(self.ckpt_path, map_location="cpu")
        for name, param in loaded_state.items():
            # Try the raw key first, then with any DataParallel prefix removed.
            target_name = name if name in self_state else name.replace("module.", "")
            if target_name not in self_state:
                print(f"{name} is not in the model.")
                continue
            if self_state[target_name].shape != param.shape:
                print(
                    f"Shape mismatch {name}: "
                    f"model {self_state[target_name].shape}, "
                    f"loaded {param.shape}"
                )
                continue
            # state_dict tensors share storage with the module parameters, so
            # this in-place copy updates the weights directly.
            self_state[target_name].copy_(param)

    def forward(self, audio_mfcc: torch.Tensor, video_gray: torch.Tensor) -> torch.Tensor:
        """Run the full audio-visual pipeline and return ASD logits.

        audio_mfcc: (B, Ta, 13) MFCC features.
        video_gray: (B, Tv, H, W) grayscale face crops — the export path below
            feeds 112x112 (demoTalkNet crops 224 -> center 112); the original
            docstring said 224x224, which contradicted that. TODO confirm the
            exact size against forward_visual_frontend.
        returns logits: (B*, 2) from the lossAV FC head.
        """
        audio_embed = self.model.forward_audio_frontend(audio_mfcc)
        visual_embed = self.model.forward_visual_frontend(video_gray)
        audio_embed, visual_embed = self.model.forward_cross_attention(
            audio_embed, visual_embed
        )
        av_embed = self.model.forward_audio_visual_backend(audio_embed, visual_embed)
        return self.lossAV.FC(av_embed)


def main() -> None:
    """Export the pretrained TalkNet checkpoint to ONNX with dynamic time axes."""
    ckpt_path = os.environ.get("CKPT_PATH", "model/pretrain_TalkSet.model")
    out_path = os.environ.get("OUT_PATH", "talknet_asd_cpu.onnx")

    model = TalkNetCPU(ckpt_path)
    model.load_parameters()
    model.eval()

    # Dummy inputs only to build the graph; real lengths are dynamic via dynamic_axes.
    dummy_audio = torch.randn(1, 100, 13)  # ~1s MFCC (100 frames)
    # Model expects 112x112 (demoTalkNet crops 224->center 112)
    dummy_video = torch.randn(1, 25, 112, 112)  # 25 frames of 112x112 gray crops

    torch.onnx.export(
        model,
        (dummy_audio, dummy_video),
        out_path,
        input_names=["audio_mfcc", "video_gray"],
        output_names=["logits"],
        dynamic_axes={
            "audio_mfcc": {0: "batch", 1: "time_audio"},
            "video_gray": {0: "batch", 1: "time_video"},
            "logits": {0: "time_any"},
        },
        opset_version=14,
    )
    print(f"Saved ONNX to {out_path}")


if __name__ == "__main__":
    main()