Spaces:
Runtime error
Runtime error
import json
import os
import subprocess
import time
import urllib.request
import uuid
from pprint import pprint
from tempfile import gettempdir

import cv2
import gradio as gr
import httpx
from PIL import Image
from sentence_transformers import SentenceTransformer

from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
# Sentence-embedding model used to match image captions against Mubert's tag set.
minilm = SentenceTransformer('all-MiniLM-L6-v2')
# Precomputed embeddings for every Mubert tag (built once at startup; see utils).
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
# image_to_text = gr.Interface.load("spaces/doevent/image_to_text", api_key=os.environ['HF_TOKEN'])
# Remote CLIP-Interrogator Space used to caption the uploaded image.
# NOTE(review): gr.Blocks.load is assumed to return a callable proxy for the
# Space (invoked with fn_index below) — confirm against the pinned gradio version.
image_to_text = gr.Blocks.load(name="spaces/banana-dev/demo-clip-interrogator")
def center_crop(img, dim: tuple = (512, 512)):
    """Return the center-cropped region of *img*.

    Args:
        img: image as a numpy array of shape (height, width[, channels]).
        dim: (width, height) of the desired crop; clamped to the image size.

    Returns:
        The cropped array, exactly dim-sized when the image is large enough.
    """
    height, width = img.shape[0], img.shape[1]
    # Clamp the requested crop to what the image can actually supply.
    crop_width = min(dim[0], width)
    crop_height = min(dim[1], height)
    # Anchor at start + size rather than mid +/- half on both ends: the old
    # form truncated both halves, silently shrinking odd crop sizes by one
    # pixel (a requested 5x5 came back 4x4). Even sizes are unchanged.
    x0 = width // 2 - crop_width // 2
    y0 = height // 2 - crop_height // 2
    return img[y0:y0 + crop_height, x0:x0 + crop_width]
def scale_image(img, factor=1):
    """Resize *img* uniformly by *factor*, preserving its aspect ratio.

    Args:
        img: image as an array with shape (height, width[, channels]).
        factor: multiplicative scale applied to both dimensions.

    Returns:
        The resized image.
    """
    new_width = int(img.shape[1] * factor)
    new_height = int(img.shape[0] * factor)
    return cv2.resize(img, (new_width, new_height))
def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
    """Request a generated track from the Mubert TTM API and poll until ready.

    Args:
        tags: list of Mubert tags describing the desired music.
        pat: personal access token obtained via get_pat().
        duration: track length in seconds.
        maxit: maximum number of 1-second polling attempts for the file.
        loop: request a seamless loop instead of a one-shot track.

    Returns:
        The download URL of the rendered track.

    Raises:
        RuntimeError: if the API reports an error or the track is not
            ready after *maxit* attempts.
    """
    mode = "loop" if loop else "track"
    r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
                   json={
                       "method": "RecordTrackTTM",
                       "params": {
                           "pat": pat,
                           "duration": duration,
                           "tags": tags,
                           "mode": mode
                       }
                   })
    pprint(r.text)
    rdata = json.loads(r.text)
    # Explicit raise instead of `assert`: asserts are stripped under -O, and
    # the caller surfaces this message to the user via gr.Error.
    if rdata['status'] != 1:
        raise RuntimeError(rdata['error']['text'])
    trackurl = rdata['data']['tasks'][0]['download_link']
    # Rendering is asynchronous: poll the download link until it goes live.
    for _ in range(maxit):
        if httpx.get(trackurl).status_code == 200:
            return trackurl
        time.sleep(1)
    # Previously fell through returning None, which crashed the caller with an
    # opaque AttributeError on .split(); fail loudly with a clear message.
    raise RuntimeError(f"Track was not ready after {maxit} seconds; please try again.")
def generate_track_by_prompt(image, email, duration, loop=False):
    """Gradio handler: caption an image, generate matching Mubert music,
    and render a waveform video over the image.

    Args:
        image: filesystem path of the uploaded image (gr.Image type="filepath").
        email: user e-mail, exchanged for a Mubert access token.
        duration: track length in seconds (from the slider).
        loop: request a seamless loop instead of a one-shot track.

    Returns:
        Tuple of (video file path, audio URL, caption prompt, Mubert tags).

    Raises:
        gr.Error: on any failure, wrapping the underlying message.
    """
    try:
        # Normalize the upload to a PNG copy in the temp dir; very large
        # images are downscaled so captioning and ffmpeg stay responsive.
        filename_png = f"{uuid.uuid4().hex}.png"
        filepath_png = f"{gettempdir()}/{filename_png}"
        with Image.open(image) as im:
            ratio_width, ratio_height = im.size
            im.convert("RGB").save(filepath_png)
        if ratio_width > 3501 or ratio_height > 3501:
            # Message now reflects the actual limit (it wrongly said 1024 px).
            raise gr.Error("Image width and height must not exceed 3500 px.")
        elif ratio_width > 3500 or ratio_height > 3500:
            cv2.imwrite(filepath_png, scale_image(cv2.imread(image), factor=0.2))
        elif ratio_width > 1800 or ratio_height > 1800:
            cv2.imwrite(filepath_png, scale_image(cv2.imread(image), factor=0.3))
        elif ratio_width > 900 or ratio_height > 900:
            cv2.imwrite(filepath_png, scale_image(cv2.imread(image), factor=0.5))
        # prompt = image_to_text(filepath_png, "Image Captioning", "", "Nucleus sampling")
        # Caption the (possibly downscaled) image via the CLIP-Interrogator Space.
        prompt = image_to_text(filepath_png, "ViT-L (best for Stable Diffusion 1.*)", "Fast", fn_index=1)[0]
        print(f"PROMPT: {prompt}")
        pat = get_pat(email)
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
        trackurl = get_track_by_tags(tags, pat, int(duration), loop=loop)
        filename_mp3 = trackurl.split("/")[-1]
        filepath_mp3 = f"{gettempdir()}/{filename_mp3}"
        filename_mp4 = f"{uuid.uuid4().hex}.mp4"
        filepath_mp4 = f"{gettempdir()}/{filename_mp4}"
        # Download in-process instead of `os.system(f"wget {url}")`: no shell
        # interpolation of an externally supplied URL, no wget dependency.
        urllib.request.urlretrieve(trackurl, filepath_mp3)
        # Waveform video sized to the (processed) image dimensions.
        with Image.open(filepath_png) as im:
            width, height = im.size
        print(f"{width}x{height}")
        # Arguments passed as a list so file paths never go through a shell;
        # check=True surfaces ffmpeg failures instead of silently continuing.
        subprocess.run([
            "ffmpeg", "-hide_banner", "-loglevel", "warning", "-y",
            "-i", filepath_mp3,
            "-loop", "1", "-i", filepath_png,
            "-filter_complex",
            f"[0:a]showwaves=s={width}x{height}:colors=0xffffff:mode=cline,format=rgba[v];[1:v][v]overlay[outv]",
            "-map", "[outv]", "-map", "0:a",
            "-c:v", "libx264", "-r", "15", "-c:a", "copy",
            "-pix_fmt", "yuv420p", "-shortest", filepath_mp4,
        ], check=True)
        os.remove(filepath_png)
        os.remove(filepath_mp3)
        return filepath_mp4, trackurl, prompt, tags
    except Exception as e:
        raise gr.Error(str(e))
# Gradio UI wiring: image + e-mail + duration in; waveform video, audio URL,
# caption prompt, and Mubert tags out. queue() enables request queuing for
# the long-running handler before launching the app.
iface = gr.Interface(fn=generate_track_by_prompt,
                     inputs=[gr.Image(type="filepath"),
                             "text",  # e-mail address, exchanged for a Mubert access token
                             gr.Slider(label="duration (seconds)", value=30, minimum=10, maximum=60)],
                     outputs=[gr.Video(label="Video"),
                              gr.Audio(label="Audio"),
                              gr.Text(label="Prompt"),
                              gr.Text(label="Tags")])
iface.queue().launch()