Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from modules.models import * | |
| from util import get_prompt_template | |
| from torchvision import transforms as vt | |
| import torchaudio | |
| from PIL import Image | |
# Models already built, keyed by device string, so weights are loaded once
# rather than on every Gradio request.
_MODEL_CACHE = {}

# Audio fed to the encoder is fixed to this sample rate (Hz) and length (s).
_DESIRED_SAMPLE_RATE = 16000
_SET_LENGTH = 10

# Images are resized to this square side (px); the same value is passed to
# the model call as its third argument (presumably the heatmap resolution).
_IMAGE_SIZE = 352

_IMAGE_TRANSFORM = vt.Compose([
    vt.Resize((_IMAGE_SIZE, _IMAGE_SIZE), vt.InterpolationMode.BICUBIC),
    vt.ToTensor(),
    vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
])


def _get_model(device):
    """Return the ACL model for *device*, building and loading it only once."""
    model = _MODEL_CACHE.get(device)
    if model is None:
        model = ACL('./config/model/ACL_ViT16.yaml', device)
        model.train(False)
        model.load('./pretrain/Param_best.pth')
        _MODEL_CACHE[device] = model
    return model


def _preprocess_audio(sample_rate, audio):
    """Turn a raw PCM array into a ``(1, 16000 * 10)`` float32 waveform tensor.

    Integer PCM is scaled to [-1, 1); the clip is resampled to 16 kHz, all
    channels are concatenated end to end (mono, duration x channels), then
    truncated or zero-padded to exactly 10 s.
    """
    audio = np.asarray(audio)
    if np.issubdtype(audio.dtype, np.integer):
        # Centre and scale by half the integer range; for int16 this is
        # exactly the original "/ 32768.0", and it also handles int32/uint8.
        info = np.iinfo(audio.dtype)
        center = (info.max + info.min + 1) / 2
        half_range = (info.max - info.min + 1) / 2
        audio = (audio.astype(np.float64) - center) / half_range
    audio = audio.astype(np.float32)
    if audio.ndim == 2:
        # Gradio's numpy audio is (samples, channels); make it channels-first
        # so resampling runs along time and channels concatenate in order.
        audio = audio.T
    waveform = torch.from_numpy(np.ascontiguousarray(audio))
    if sample_rate != _DESIRED_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, _DESIRED_SAMPLE_RATE)
    # Multi-channel -> mono by concatenating the channels (x channels duration),
    # as the original stereo branch intended; a 1-D clip is unchanged.
    waveform = waveform.reshape(-1)

    target_len = _DESIRED_SAMPLE_RATE * _SET_LENGTH
    if waveform.shape[0] > target_len:
        waveform = waveform[:target_len]
    elif waveform.shape[0] < target_len:
        pad_val = torch.zeros(target_len - waveform.shape[0], dtype=waveform.dtype)
        waveform = torch.cat((waveform, pad_val), dim=0)
    return waveform.unsqueeze(0)


def _overlay_heatmap(image, seg):
    """Colour *seg* with the JET colormap and blend it 50/50 over RGB *image*.

    Returns an RGB uint8 array with the same height/width as *image*.
    """
    import cv2  # already required by the original code; imported here explicitly

    # NOTE(review): heatmap values are presumably in [0, 1]; clipping prevents
    # uint8 wrap-around if they are not.
    seg_np = np.clip(seg.squeeze().detach().cpu().numpy(), 0.0, 1.0)
    # Inversion kept from the original. NOTE(review): with COLORMAP_JET this
    # paints high-activation regions blue — confirm that is intended.
    seg_image = ((1 - seg_np) * 255).astype(np.uint8)
    heatmap_bgr = cv2.applyColorMap(seg_image, cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; the PIL image is RGB.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    base = np.array(image)
    # addWeighted needs equal sizes: stretch the heatmap to the input image.
    heatmap_rgb = cv2.resize(heatmap_rgb, (base.shape[1], base.shape[0]),
                             interpolation=cv2.INTER_LINEAR)
    return cv2.addWeighted(base, 0.5, heatmap_rgb, 0.5, 0)


def greet(audio, image):
    """Localize the sound source of *audio* in *image* and return an overlay.

    Args:
        audio: Gradio numpy audio, a ``(sample_rate, ndarray)`` tuple.
        image: PIL image of any mode/size.

    Returns:
        RGB uint8 ndarray of the input image blended with the ACL heatmap.

    Raises:
        gr.Error: if either input is missing.
    """
    if audio is None or image is None:
        raise gr.Error('Please provide both an audio clip and an image.')

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = _get_model(device)
    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre-processing
    sample_rate, samples = audio
    audio_file = _preprocess_audio(sample_rate, samples)
    # Force 3 channels so RGBA / grayscale uploads match the 3-channel Normalize.
    image = image.convert('RGB')
    # Add a batch dimension of 1, matching the audio tensor.
    image_file = _IMAGE_TRANSFORM(image).unsqueeze(0)

    # Inference
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens,
                                                    text_pos_at_prompt, prompt_length)
        out_dict = model(image_file.to(model.device), audio_driven_embedding, _IMAGE_SIZE)

    # Localization result for the single batch element.
    seg = out_dict['heatmap'][0:1]
    return _overlay_heatmap(image, seg)
description = 'hello world'

# Gradio passes inputs to fn positionally, so this order must match
# greet(audio, image).
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Audio(), gr.Image(type='pil')],
    outputs=gr.Image(type="pil"),
    title='AudioToken',
    description=description,
)

if __name__ == "__main__":
    demo.launch(debug=True)