# imagic/modeling.py — image-to-music model (uploaded by juliagsy, commit e489a9f)
import numpy as np
import torch
import torch.nn as nn
from transformers import MusicgenForConditionalGeneration, AutoModel, PretrainedConfig, PreTrainedModel
class Im2Mu(nn.Module):
    """Image-to-music model.

    Encodes an image with a pretrained ViT (muvis), projects the resulting
    sequence into the conditioning space MusicGen expects, and uses the
    pretrained MusicGen decoder to model/generate EnCodec audio tokens.

    NOTE(review): ``embed_dims`` and ``seq_len`` are accepted for config
    compatibility but are currently unused by this implementation.
    """

    def __init__(self, embed_dims=768, seq_len=64):
        super(Im2Mu, self).__init__()
        # Pretrained text-to-music backbone; its text encoder is repurposed
        # to process image embeddings, and its audio codec tokenizes wavs.
        self.musicgen = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
        # Vision transformer taken out of the muvis checkpoint.
        self.muvis = AutoModel.from_pretrained("juliagsy/muvis", trust_remote_code=True).model.vit
        self.loss_ce = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=-100)
        # Maps the ViT sequence length (197, presumably 196 patches + CLS)
        # to the 256 conditioning positions used downstream — TODO confirm.
        self.img_lin = nn.Linear(197, 256)

    def shift_right(self, input_ids):
        """Shift each row of token ids one position right, writing 0 at t=0.

        Args:
            input_ids: integer tensor of shape (rows, seq_len).

        Returns:
            Tensor of the same shape, shifted right with a 0 start token.
        """
        shifted_input_ids = torch.zeros_like(input_ids)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = 0
        return shifted_input_ids

    def forward(self, img, wav):
        """Teacher-forced training step; returns the cross-entropy loss.

        NOTE(review): here ``img_lin`` is applied AFTER the MusicGen encoder,
        while ``generate`` applies it BEFORE the encoder — one of the two
        orderings is presumably wrong; confirm against the trained weights.
        """
        img_e = self.muvis(**img)["last_hidden_state"]
        img_embeds = self.musicgen.get_encoder()(
            inputs_embeds=img_e
        )["last_hidden_state"]
        # Project along the sequence axis: (B, S, D) -> (B, D, S) -> linear
        # over S -> back to (B, 256, D).
        img_embeds = img_embeds.permute(0, 2, 1)
        img_embeds = self.img_lin(img_embeds)
        img_embeds = img_embeds.permute(0, 2, 1)
        # Tokenize the reference audio into discrete codebook ids.
        wav_tokens = self.musicgen.get_audio_encoder().encode(
            **wav,
        )["audio_codes"]
        wav_size = wav_tokens.size()
        # Flatten batch and codebook dims: (..., B, K, T) -> (B*K, T).
        labels = wav_tokens.view((wav_size[1] * wav_size[2], wav_size[-1]))
        # BUG FIX: the decoder consumes the right-shifted sequence, but the
        # loss target must be the UNSHIFTED tokens. The original used the
        # shifted tensor as both input and target, which trains the decoder
        # to copy its own input instead of predicting the next token.
        decoder_input_ids = self.shift_right(labels)
        ret = self.musicgen(
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=(img_embeds,),
        )
        loss = self.loss_ce(ret.logits.view(-1, self.musicgen.config.audio_encoder.codebook_size), labels.view(-1))
        return loss

    def generate(self, img, wav=None, guidance_scale=3, max_new_tokens=256, device="cpu"):
        """Generate audio tokens conditioned on an image (and optional wav prompt).

        Args:
            img: processor output for the vision model (dict of tensors).
            wav: optional processor output for the audio codec; when given,
                its tokens prime the decoder.
            guidance_scale: classifier-free guidance scale for MusicGen.
            max_new_tokens: number of new audio tokens to sample.
            device: device on which to allocate the primer tensors.
        """
        img_embeds = self.muvis(**img)["last_hidden_state"]
        img_embeds = img_embeds.permute(0, 2, 1)
        img_embeds = self.img_lin(img_embeds)
        img_embeds = img_embeds.permute(0, 2, 1)
        img_embeds = self.musicgen.get_encoder()(
            inputs_embeds=img_embeds
        )["last_hidden_state"]
        if wav is not None:
            input_ids = self.musicgen.get_audio_encoder().encode(
                **wav,
            )["audio_codes"]
            wav_size = input_ids.size()
            input_ids = input_ids.view((wav_size[1] * wav_size[2], wav_size[-1]))
            input_ids = self.shift_right(input_ids)
            ret = self.musicgen.generate(
                decoder_input_ids=input_ids,
                encoder_outputs=(img_embeds,),
                do_sample=True,
                guidance_scale=guidance_scale,
                # BUG FIX: honor the max_new_tokens argument (was hardcoded 256).
                max_new_tokens=max_new_tokens,
            )
        else:
            # No audio prompt: start from a single 0 token per codebook
            # (4 codebooks for musicgen-small).
            input_ids = torch.zeros((4, 1)).long().to(device)
            decoder_attention_mask = torch.ones((img_embeds.size(0), 1)).long().to(device)
            ret = self.musicgen.generate(
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=(img_embeds,),
                do_sample=True,
                guidance_scale=guidance_scale,
                max_new_tokens=max_new_tokens,
            )
        return ret
class ImagicConfig(PretrainedConfig):
    """Configuration for :class:`ImagicModel`.

    Holds the two Im2Mu hyperparameters and forwards everything else to
    the Hugging Face ``PretrainedConfig`` base class.
    """

    model_type = "imagic"

    def __init__(self, embed_dims=768, seq_len=64, **kwargs):
        # Record model hyperparameters before delegating to the base config.
        self.embed_dims = embed_dims
        self.seq_len = seq_len
        super().__init__(**kwargs)
class ImagicModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`Im2Mu`.

    Lets the image-to-music model participate in the transformers
    save/load machinery via :class:`ImagicConfig`.
    """

    config_class = ImagicConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Im2Mu(
            embed_dims=config.embed_dims,
            seq_len=config.seq_len,
        )

    def forward(self, img, wav):
        """Training forward pass; returns the cross-entropy loss tensor."""
        return self.model.forward(img, wav)

    def generate(self, img, wav=None, guidance_scale=3, max_new_tokens=256, device="cpu"):
        """Generate audio tokens conditioned on an image.

        BUG FIX: accept and forward ``max_new_tokens`` (same default the
        inner model uses) — previously callers of this wrapper had no way
        to control the generation length.
        """
        return self.model.generate(
            img,
            wav=wav,
            guidance_scale=guidance_scale,
            max_new_tokens=max_new_tokens,
            device=device,
        )