File size: 3,757 Bytes
845600c
 
 
 
 
 
 
 
 
 
 
e489a9f
845600c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import torch
import torch.nn as nn
from transformers import MusicgenForConditionalGeneration, AutoModel, PretrainedConfig, PreTrainedModel


class Im2Mu(nn.Module):
    """Image-to-music model: MuVis ViT image features conditioning MusicGen-small.

    Pipeline (shared by ``forward`` and ``generate`` via ``_encode_image``):
    MuVis ViT ``last_hidden_state`` -> MusicGen's text encoder (fed through
    ``inputs_embeds``) -> ``img_lin`` applied over the token axis (197 -> 256).
    The result is handed to MusicGen as ``encoder_outputs``.

    Args:
        embed_dims: Accepted for config compatibility; currently unused.
        seq_len: Accepted for config compatibility; currently unused.
    """

    def __init__(self, embed_dims=768, seq_len=64):
        super().__init__()

        self.musicgen = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
        self.muvis = AutoModel.from_pretrained("juliagsy/muvis", trust_remote_code=True).model.vit

        self.loss_ce = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=-100)
        # Projects the image token axis from 197 to 256 tokens.
        # NOTE(review): 197 presumably = 196 patches + [CLS] of a 224px/16 ViT — confirm.
        self.img_lin = nn.Linear(197, 256)

    def shift_right(self, input_ids):
        """Shift ids one step right along dim 1, filling position 0 with 0.

        Produces teacher-forcing decoder inputs: ``out[:, t] == input_ids[:, t - 1]``.
        Token 0 is the start token used consistently in training and generation.
        """
        shifted = torch.zeros_like(input_ids)
        shifted[:, 1:] = input_ids[:, :-1]
        return shifted

    @staticmethod
    def _flatten_codes(audio_codes):
        """Collapse audio codes to ``(dim1 * dim2, last_dim)`` for MusicGen's decoder.

        NOTE(review): assumes ``audio_codes`` is (chunks=1, batch, codebooks, seq) as
        produced by the audio encoder — ``view`` raises if the leading dim is not 1.
        """
        size = audio_codes.size()
        return audio_codes.view(size[1] * size[2], size[-1])

    def _teacher_forcing_inputs(self, codes):
        """Return ``(decoder_input_ids, labels)`` for next-token training.

        The labels are the unshifted codes; the decoder sees them shifted right by one,
        so position t must predict ``codes[:, t]`` from ``codes[:, :t]``.
        """
        return self.shift_right(codes), codes

    def _encode_image(self, img):
        """Encode a processed image batch into MusicGen conditioning states.

        Args:
            img: Keyword arguments for the MuVis ViT (presumably image-processor
                output such as ``pixel_values`` — confirm with callers).

        Returns:
            Tensor with 256 tokens on dim 1 (``img_lin`` output), presumably
            (batch, 256, hidden) given the usual HF (batch, seq, hidden) layout.
        """
        vit_states = self.muvis(**img)["last_hidden_state"]
        enc_states = self.musicgen.get_encoder()(inputs_embeds=vit_states)["last_hidden_state"]
        # nn.Linear acts on the last dim, so move the token axis there and back.
        projected = self.img_lin(enc_states.permute(0, 2, 1))
        return projected.permute(0, 2, 1)

    def _encode_audio(self, wav):
        """Encode audio with MusicGen's audio encoder and flatten the codes to 2-D."""
        audio_codes = self.musicgen.get_audio_encoder().encode(**wav)["audio_codes"]
        return self._flatten_codes(audio_codes)

    def forward(self, img, wav):
        """Compute the label-smoothed cross-entropy training loss.

        Args:
            img: Keyword arguments for the MuVis ViT.
            wav: Keyword arguments for MusicGen's audio encoder ``encode``
                (presumably audio-processor output — confirm with callers).

        Returns:
            Scalar loss tensor (mean reduction).
        """
        img_embeds = self._encode_image(img)
        decoder_input_ids, labels = self._teacher_forcing_inputs(self._encode_audio(wav))

        ret = self.musicgen(
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=(img_embeds,),
        )
        logits = ret.logits
        # Score against the unshifted codes (next-token targets), not the decoder inputs.
        return self.loss_ce(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

    @torch.no_grad()
    def generate(self, img, wav=None, guidance_scale=3, max_new_tokens=256, device="cpu"):
        """Sample audio tokens conditioned on an image, optionally continuing a prompt.

        Args:
            img: Keyword arguments for the MuVis ViT.
            wav: Optional audio-encoder kwargs; when given, its codes (shifted right)
                are used as the decoder prompt.
            guidance_scale: Forwarded to ``MusicgenForConditionalGeneration.generate``.
            max_new_tokens: Maximum number of tokens to generate (both branches).
            device: Device for the unconditional start tokens and attention mask.

        Returns:
            Whatever ``MusicgenForConditionalGeneration.generate`` returns.
        """
        # Same image pipeline as training so img_lin sees the features it was trained on.
        img_embeds = self._encode_image(img)

        gen_kwargs = {
            "encoder_outputs": (img_embeds,),
            "do_sample": True,
            "guidance_scale": guidance_scale,
            "max_new_tokens": max_new_tokens,
        }
        if wav is not None:
            decoder_input_ids = self.shift_right(self._encode_audio(wav))
        else:
            batch_size = img_embeds.size(0)
            num_codebooks = self.musicgen.decoder.num_codebooks
            # One start token (0, as in shift_right) per (batch item, codebook) row.
            decoder_input_ids = torch.zeros(
                (batch_size * num_codebooks, 1), dtype=torch.long, device=device
            )
            gen_kwargs["decoder_attention_mask"] = torch.ones(
                (batch_size, 1), dtype=torch.long, device=device
            )
        return self.musicgen.generate(decoder_input_ids=decoder_input_ids, **gen_kwargs)


class ImagicConfig(PretrainedConfig):
    """Configuration for ``ImagicModel``.

    Args:
        embed_dims: Stored as ``self.embed_dims`` and passed on to ``Im2Mu``.
        seq_len: Stored as ``self.seq_len`` and passed on to ``Im2Mu``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "imagic"

    def __init__(self, embed_dims=768, seq_len=64, **kwargs):
        # Model-specific fields are set before the base initializer runs.
        self.embed_dims, self.seq_len = embed_dims, seq_len
        super().__init__(**kwargs)


class ImagicModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper exposing ``Im2Mu`` through the HF config API.

    Constructing it builds ``Im2Mu``, which loads the MusicGen-small and MuVis
    checkpoints with ``from_pretrained``.
    """

    config_class = ImagicConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Im2Mu(
            embed_dims=config.embed_dims,
            seq_len=config.seq_len,
        )

    def forward(self, img, wav):
        """Return the training loss computed by the wrapped ``Im2Mu``."""
        # Invoke the module rather than .forward() so registered nn.Module hooks run.
        return self.model(img, wav)

    def generate(self, img, wav=None, guidance_scale=3, device="cpu", max_new_tokens=256):
        """Generate audio tokens for ``img``; see ``Im2Mu.generate``.

        ``max_new_tokens`` is appended after the original parameters so existing
        positional calls keep working; its default matches ``Im2Mu.generate``.
        """
        return self.model.generate(
            img,
            wav=wav,
            guidance_scale=guidance_scale,
            max_new_tokens=max_new_tokens,
            device=device,
        )