|
|
import data
|
|
|
import torch
|
|
|
from models import imagebind_model
|
|
|
from models.imagebind_model import ModalityType
|
|
|
import torch.nn as nn
|
|
|
from imagen_pytorch import ImagenTrainer
|
|
|
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
|
|
|
from extract.getim import load_image
|
|
|
import torch.optim as optim
|
|
|
import os
|
|
|
from torchvision import transforms
|
|
|
from image2vidimg import cobtwoten, cobtwoten256
|
|
|
import os
|
|
|
|
|
|
# Target compute device; this script assumes a CUDA-capable GPU is present.
device = torch.device("cuda")

# Release any cached GPU memory from a previous run before building models.
torch.cuda.empty_cache()

# Preprocessing: convert a PIL image / ndarray to a float tensor in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# Inverse transform, for turning tensors back into PIL images when inspecting.
unloader = transforms.ToPILImage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def imagebind_out(audio_paths, model):
    """Embed a batch of audio files with a frozen ImageBind model.

    Args:
        audio_paths: list of paths to .wav files to embed.
        model: a pre-loaded ImageBind model (expected to be in eval mode).

    Returns:
        The dict of modality embeddings produced by the model; audio
        embeddings live under the ModalityType.AUDIO key.
    """
    model_inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }
    # Inference only -- no gradients should flow into ImageBind.
    with torch.no_grad():
        out = model(model_inputs)
    return out
|
|
|
|
|
|
class encode_audio(nn.Module):
    """Adapter from ImageBind audio embeddings to Imagen text embeddings.

    Maps a (B, 1, 1024) audio embedding to a (B, 344, 1024) pseudo
    text-embedding sequence: the raw embedding is kept as the first "token"
    and 343 further tokens are formed by scaling it with a learned
    1024 -> 343 projection (an outer-product-style expansion).
    """

    def __init__(self):
        super().__init__()
        # Kept for parity with the original implementation; not read in forward.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Learned projection producing one scalar weight per output token.
        self.link2 = nn.Linear(1024, 343)

    def forward(self, embeddings):
        """embeddings: (B, 1, 1024) tensor -> (B, 344, 1024) tensor."""
        base = embeddings                     # (B, 1, 1024)
        token_weights = self.link2(embeddings)  # (B, 1, 343)
        # (B, 343, 1) @ (B, 1, 1024) -> (B, 343, 1024)
        expanded = torch.matmul(token_weights.transpose(1, 2), base)
        # Prepend the raw embedding as the first token of the sequence.
        return torch.cat([base, expanded], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
def getAllFiles(targetDir):
    """Return the names of the entries directly inside *targetDir*.

    The listing is sorted so iteration order is deterministic across runs
    and filesystems (``os.listdir`` order is arbitrary), which keeps the
    training-epoch ordering reproducible.

    Args:
        targetDir: path of the directory to list.

    Returns:
        Sorted list of entry names (non-recursive; includes subdirectories).
    """
    return sorted(os.listdir(targetDir))
|
|
|
|
|
|
|
|
|
# Two-stage video U-Nets for the Elucidated Imagen cascade.
# max_text_len=344 matches encode_audio's output length (1 + 343 tokens);
# text_embed_dim=1024 matches the ImageBind audio embedding width.
unet1 = Unet3D(max_text_len=344,text_embed_dim=1024,dim = 64, dim_mults = (1, 2, 4, 8)).to(device)

# Second (super-resolution) stage: same layout, wider base channel count.
unet2 = Unet3D(max_text_len=344,text_embed_dim=1024,dim = 128, dim_mults = (1, 2, 4, 8)).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
# Elucidated (EDM-style) two-stage cascaded video diffusion model,
# conditioned on audio embeddings passed through the text_embeds pathway.
# Tuple-valued settings are per-unet: (base stage, super-resolution stage).
imagen = ElucidatedImagen(
    text_embed_dim=1024,  # width of the conditioning embeddings (ImageBind audio)
    unets = (unet1, unet2),
    image_sizes = (64, 128),  # stage outputs: 64x64 base, 128x128 upsampled
    random_crop_sizes = (None, 64),  # SR stage trains on random 64px crops
    temporal_downsample_factor = (2, 1),  # base stage trains at half frame rate
    num_sample_steps = 10,  # sampling steps per unet
    cond_drop_prob = 0.1,  # conditioning dropout for classifier-free guidance
    # Noise-schedule hyperparameters for the elucidated formulation.
    sigma_min = 0.002,
    sigma_max = (80, 160),  # per-unet; larger max noise for the upsampler
    sigma_data = 0.5,
    rho = 7,
    P_mean = -1.2,
    P_std = 1.2,
    # Stochastic-sampler (churn) settings.
    S_churn = 80,
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).to(device)
|
|
|
|
|
|
# Wrap the cascade in imagen-pytorch's trainer, which manages the unets'
# optimizers and EMA copies internally.
trainer = ImagenTrainer(imagen)

trainer = trainer.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
# Pretrained ImageBind encoder, used only as a frozen audio feature extractor.
model_imageb = imagebind_model.imagebind_huge(pretrained=True)
model_imageb=model_imageb.to(device)
# Inference mode: ImageBind is never trained in this script.
model_imageb.eval()
|
|
|
|
|
|
|
|
|
|
|
|
# Number of training epochs.
epo=31
# Per-step batch size: audio/image pairs consumed per iteration.
p=1
# Training set: one .wav per sample, paired with ./extract/image/<name>.jpg.
files = getAllFiles("./extract/audio")

# Running loss accumulator for the current epoch.
outloss=0
# Trainable adapter mapping ImageBind audio embeddings to Imagen text embeds.
model1=(encode_audio()).to(device)

# This optimizer covers only the adapter; the unets are stepped by trainer.update().
optimizer = optim.Adam(model1.parameters(), lr=1e-5,
    betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

model1.train()
torch.cuda.empty_cache()
|
|
|
|
|
|
# Main training loop: for each audio/image pair, embed the audio with the
# frozen ImageBind model, adapt it with model1, and train Imagen's second
# unet on the pseudo-video clip built from the paired still image.
for k in range(epo):
    # Walk the file list in steps of p (with p=1 this is one pair per step).
    for nm in range(0, len(files) + 1 - p, p):
        file_ext0 = os.path.splitext(files[nm])
        front0, ext0 = file_ext0
        audio_pat=[]
        audio_pat.append("./extract/audio/" + str(front0) + ".wav")

        # Build a video-shaped tensor from the paired still image.
        fcontent = cobtwoten("./extract/image/" + str(front0) + ".jpg")

        # Accumulate the rest of the batch (no-op while p == 1).
        for ni in range(1,p):
            file_ext = os.path.splitext(files[nm+ni])
            front, ext = file_ext

            content = cobtwoten("./extract/image/" + str(front) + ".jpg")
            # NOTE(review): dim -5 presumably stacks along the batch axis of a
            # 5-D (B, C, T, H, W) video tensor -- confirm against cobtwoten.
            fcontent = torch.cat((fcontent, content), -5)
            audio_pat.append("./extract/audio/" + str(front) + ".wav")

        # Frozen ImageBind embeddings; the "audio" entry is used below.
        imageb_out = imagebind_out(audio_pat,model_imageb)
        # unsqueeze(1) -> (B, 1, 1024) so the adapter can emit a token sequence.
        fmusic = model1(imageb_out["audio"].unsqueeze(1))

        fmusic=fmusic.to(device)
        fcontent=fcontent.to(device)
        # Diffusion loss for unet 2, conditioned on the adapted audio tokens.
        loss = trainer(fcontent, text_embeds=fmusic, unet_number = 2,ignore_time = False, max_batch_size = p)
        trainer.update(unet_number = 2)
        # NOTE(review): this step assumes gradients reached model1 through
        # text_embeds during the trainer's backward pass -- verify that
        # model1.parameters() actually receive non-None grads here.
        optimizer.step()

        optimizer.zero_grad()
        print(loss)
        outloss=outloss+loss

    # NOTE(review): no-op self-assignment, kept byte-identical.
    outloss=outloss

    print("epoch"+str(k)+" "+" loss: "+str(outloss))

    # Reset the accumulator for the next epoch.
    outloss=0
    # Checkpoint every third epoch (k = 2, 5, 8, ...).
    if k % 3 == 2:
        torch.save(model1, "wlc.pt")
        trainer.save('./checkpoint.pt')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|