# myklipers-models / talknet-asd / export_onnx_cpu.py
# Uploaded via huggingface_hub by DimasMP3 (commit 5c69097, verified).
import os
import torch
from loss import lossAV
from model.talkNetModel import talkNetModel
class TalkNetCPU(torch.nn.Module):
    """CPU-only wrapper for TalkNet export.

    Bundles the TalkNet audio/visual backbone with the lossAV classifier
    head so the whole scoring path can be traced as a single ONNX graph.
    """

    def __init__(self, ckpt_path: str):
        super().__init__()
        # Project-provided backbone and loss/classifier head.
        self.model = talkNetModel()
        self.lossAV = lossAV()
        # Checkpoint is read lazily by load_parameters(), not here.
        self.ckpt_path = ckpt_path

    def load_parameters(self) -> None:
        """Load state_dict saved by talkNet.saveParameters (handles module. prefix)."""
        own_state = self.state_dict()
        checkpoint = torch.load(self.ckpt_path, map_location="cpu")
        for key, tensor in checkpoint.items():
            # Checkpoints saved from DataParallel carry a "module." prefix;
            # fall back to the stripped name when the raw key is absent.
            name = key if key in own_state else key.replace("module.", "")
            if name not in own_state:
                print(f"{key} is not in the model.")
                continue
            if own_state[name].shape != checkpoint[key].shape:
                print(
                    f"Shape mismatch {key}: "
                    f"model {own_state[name].shape}, "
                    f"loaded {checkpoint[key].shape}"
                )
                continue
            # In-place copy into the (detached) state-dict view updates the
            # underlying parameter storage.
            own_state[name].copy_(tensor)

    def forward(self, audio_mfcc: torch.Tensor, video_gray: torch.Tensor) -> torch.Tensor:
        """
        audio_mfcc: (B, Ta, 13) MFCC features
        video_gray: (B, Tv, H, W) grayscale face crops
            NOTE(review): the export driver feeds 112x112 crops
            (demoTalkNet crops 224 -> center 112) — confirm expected size.
        returns logits: (B*, 2)
        """
        a_feat = self.model.forward_audio_frontend(audio_mfcc)
        v_feat = self.model.forward_visual_frontend(video_gray)
        a_feat, v_feat = self.model.forward_cross_attention(a_feat, v_feat)
        fused = self.model.forward_audio_visual_backend(a_feat, v_feat)
        return self.lossAV.FC(fused)
def main() -> None:
    """Export the CPU TalkNet model to ONNX with dynamic batch/time axes.

    Reads CKPT_PATH / OUT_PATH from the environment (with defaults),
    loads the checkpoint, and traces the model with dummy inputs.
    """
    ckpt_path = os.environ.get("CKPT_PATH", "model/pretrain_TalkSet.model")
    out_path = os.environ.get("OUT_PATH", "talknet_asd_cpu.onnx")

    model = TalkNetCPU(ckpt_path)
    model.load_parameters()
    model.eval()

    # Dummy inputs only to build the graph; real lengths are dynamic via dynamic_axes.
    dummy_audio = torch.randn(1, 100, 13)  # ~1s MFCC (100 frames)
    # Model expects 112x112 (demoTalkNet crops 224->center 112)
    dummy_video = torch.randn(1, 25, 112, 112)  # 25 frames of 112x112 gray crops

    export_options = {
        "input_names": ["audio_mfcc", "video_gray"],
        "output_names": ["logits"],
        "dynamic_axes": {
            "audio_mfcc": {0: "batch", 1: "time_audio"},
            "video_gray": {0: "batch", 1: "time_video"},
            "logits": {0: "time_any"},
        },
        "opset_version": 14,
    }
    torch.onnx.export(model, (dummy_audio, dummy_video), out_path, **export_options)
    print(f"Saved ONNX to {out_path}")


if __name__ == "__main__":
    main()