| | from tabnanny import verbose |
| | import torch |
| | import math |
| | from audiocraft.models import MusicGen |
| | import numpy as np |
| | from PIL import Image, ImageDraw, ImageFont, ImageColor |
| | import string |
| | import tempfile |
| | import os |
| | import textwrap |
| | import requests |
| | from io import BytesIO |
| | from huggingface_hub import hf_hub_download |
| | import librosa |
| |
|
| |
|
# Cooperative cancellation flag. generate_music_segments polls it before
# preparing each melody segment and before generating each output segment, and
# returns early when it is truthy. Nothing visible in this file sets it, so it
# is presumably toggled from outside this module (e.g. a UI handler) - TODO confirm.
INTERRUPTING = False
| |
|
def separate_audio_segments(audio, segment_duration=30, overlap=1):
    """
    Split an audio clip into fixed-length, optionally overlapping segments.

    Parameters:
    - audio: tuple
        A ``(sample_rate, samples)`` pair. ``samples`` may be any sliceable
        sequence whose ``len()`` is the number of samples.
    - segment_duration: int
        Length of every segment in seconds. Default is 30.
    - overlap: int
        Seconds shared by two consecutive segments. Default is 1.

    Returns:
    - list: ``(sample_rate, segment)`` tuples in playback order. At most 25
            segments are produced. Audio that is not longer than one segment
            is returned whole as a single segment. When the windows do not
            end exactly on the last sample, a final segment holding the last
            ``segment_duration`` seconds is appended (it may overlap the
            previous one by more than ``overlap``).
    """
    sr, audio_data = audio[0], audio[1]

    max_segments = 25
    total_samples = len(audio_data)
    segment_samples = int(sr * segment_duration)
    overlap_samples = int(sr * overlap)

    segments = []

    if segment_samples <= 0 or total_samples <= segment_samples:
        # Audio is no longer than one segment: hand it back unchanged.
        segment_samples = total_samples
        segments.append((sr, audio_data))
    else:
        # Distance between two segment starts; kept >= 1 so the loop always
        # advances even when overlap >= segment_duration.
        step_samples = max(segment_samples - overlap_samples, 1)
        start_sample = 0
        end_sample = 0
        while (len(segments) < max_segments
               and start_sample + segment_samples <= total_samples):
            end_sample = start_sample + segment_samples
            segments.append((sr, audio_data[start_sample:end_sample]))
            start_sample += step_samples

        # Collect the final segment so the end of the clip is not dropped.
        if end_sample < total_samples and len(segments) < max_segments:
            segments.append((sr, audio_data[-segment_samples:]))

    segment_seconds = segment_samples // sr if sr else 0
    print(f"separate_audio_segments: {len(segments)} segments of length {segment_seconds} seconds")
    return segments
| |
|
def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False):
    """
    Generate music as a series of segments, each conditioned on a melody
    segment and on a prompt segment.

    Parameters:
    - text: str
        Description handed to the model; ", seed=<seed>" is appended to it.
    - melody: tuple
        Passed unchanged to separate_audio_segments, which reads it as
        (sample_rate, samples). The samples are fed to torch.from_numpy, so
        they must be a NumPy array.
    - seed: int
        Value given to torch.manual_seed and appended to ``text``.
    - MODEL:
        Model wrapper. This function uses its ``device``, ``sample_rate``,
        ``generation_params`` (keys "top_k", "top_p", "temp", "cfg_coef"),
        ``lm.cfg.dataset.segment_duration``, ``set_generation_params()`` and
        ``generate_with_all()``.
    - duration: int
        Requested length in seconds, capped at 720. Default is 10.
    - overlap: int
        Overlap in seconds, capped at 15; used only to compute the duration
        lost to overlapping. Default is 1.
    - segment_duration: int
        Length in seconds of each generated segment. Default is 30.
    - prompt_index: int
        > 0: the verse at that index (clamped to the last one) seeds the
        prompt segment; 0: the first verse seeds it; < 0: the first verse
        seeds it and each generated output then becomes the prompt for the
        next segment.
    - harmony_only: bool
        When True, each melody segment is run through librosa.effects.hpss
        and only the first returned (harmonic) component is kept.

    Returns:
    - tuple: ``(output_segments, excess_duration)`` on completion. If
             INTERRUPTING becomes truthy the function returns early with
             ``([], duration)`` or ``(output_segments, duration)``.
             NOTE(review): the early returns carry the remaining duration
             while the normal return carries excess_duration - confirm
             which one callers expect.
    """
    # generate audio segments
    melody_segments = separate_audio_segments(melody, segment_duration, 0)

    # Create lists to store the melody tensors for each segment
    melodys = []
    output_segments = []
    text += ", seed=" + str(seed)
    prompt_segment = None
    # prevent hacking
    duration = min(duration, 720)
    overlap = min(overlap, 15)

    # Calculate the total number of segments
    total_segments = max(math.ceil(duration / segment_duration),1)
    # calculate duration loss from segment overlap
    duration_loss = max(total_segments - 1,0) * math.ceil(overlap / 2)
    # calc excess duration
    excess_duration = segment_duration - (total_segments * segment_duration - duration)
    print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration} Overlap Loss {duration_loss}")
    duration += duration_loss
    # Add segments until the excess plus the overlap loss fits into one segment
    while excess_duration + duration_loss > segment_duration:
        total_segments += 1
        # calculate duration loss from segment overlap
        duration_loss += math.ceil(overlap / 2)
        # calc excess duration
        excess_duration = segment_duration - (total_segments * segment_duration - duration)
        print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration} Overlap Loss {duration_loss}")
        if excess_duration + duration_loss > segment_duration:
            duration += duration_loss
            duration_loss = 0
    total_segments = min(total_segments, (720 // segment_duration))

    # If melody_segments is shorter than total_segments, repeat the segments until total_segments is reached
    if len(melody_segments) < total_segments:
        # The range is fixed up front; the list grows while it is indexed, so
        # the existing segments are repeated in order.
        for i in range(total_segments - len(melody_segments)):
            segment = melody_segments[i]
            melody_segments.append(segment)
        print(f"melody_segments: {len(melody_segments)} fixed")

    # Iterate over the segments to create the list of melody tensors
    for segment_idx in range(total_segments):
        if INTERRUPTING:
            return [], duration
        print(f"segment {segment_idx + 1} of {total_segments} \r")

        if harmony_only:
            # Keep only the harmonic component; the percussive one is discarded
            verse_harmonic, _ = librosa.effects.hpss(melody_segments[segment_idx][1])
            sr, verse = melody_segments[segment_idx][0], torch.from_numpy(verse_harmonic).to(MODEL.device).float().t().unsqueeze(0)
        else:
            sr, verse = melody_segments[segment_idx][0], torch.from_numpy(melody_segments[segment_idx][1]).to(MODEL.device).float().t().unsqueeze(0)

        print(f"shape:{verse.shape} dim:{verse.dim()}")
        if verse.dim() == 2:
            verse = verse[None]
        # Trim the last axis to the model's configured segment length
        verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]

        # Append the segment to the melodys list
        melodys.append(verse)

    torch.manual_seed(seed)

    # Default to the first segment for prompt conditioning
    prompt_verse = melodys[0]
    if prompt_index > 0:
        # Get a prompt segment from the selected verse, clamped to the last verse
        prompt_verse = melodys[prompt_index if prompt_index <= (total_segments - 1) else (total_segments -1)]

    # Set the generation params for the prompt segment
    MODEL.set_generation_params(
        use_sampling=True,
        top_k=MODEL.generation_params["top_k"],
        top_p=MODEL.generation_params["top_p"],
        temperature=MODEL.generation_params["temp"],
        cfg_coef=MODEL.generation_params["cfg_coef"],
        duration=segment_duration,
        two_step_cfg=False,
        rep_penalty=0.5
    )
    # Generate the prompt segment that conditions the segments below
    print(f"Generating New Prompt Segment: {text} from verse {prompt_index}\r")
    prompt_segment = MODEL.generate_with_all(
        descriptions=[text],
        melody_wavs=prompt_verse,
        sample_rate=sr,
        progress=False,
        prompt=None,
    )

    for idx, verse in enumerate(melodys):
        if INTERRUPTING:
            return output_segments, duration

        print(f'Segment duration: {segment_duration}, duration: {duration}, overlap: {overlap} Overlap Loss: {duration_loss}')
        # Compensate for the length of the final segment
        if ((idx + 1) == len(melodys)) or (duration < segment_duration):
            mod_duration = max(min(duration, segment_duration),1)
            print(f'Modify verse length, duration: {duration}, overlap: {overlap} Overlap Loss: {duration_loss} to mod duration: {mod_duration}')
            MODEL.set_generation_params(
                use_sampling=True,
                top_k=MODEL.generation_params["top_k"],
                top_p=MODEL.generation_params["top_p"],
                temperature=MODEL.generation_params["temp"],
                cfg_coef=MODEL.generation_params["cfg_coef"],
                duration=mod_duration,
                two_step_cfg=False,
                rep_penalty=0.5
            )
            try:
                # get last chunk
                verse = verse[:, :, -mod_duration*MODEL.sample_rate:]
                prompt_segment = prompt_segment[:, :, -mod_duration*MODEL.sample_rate:]
            except Exception:
                # Best-effort fallback: get first chunk. Narrowed from a bare
                # except so KeyboardInterrupt/SystemExit are not swallowed.
                verse = verse[:, :, :mod_duration*MODEL.sample_rate]
                prompt_segment = prompt_segment[:, :, :mod_duration*MODEL.sample_rate]

        print(f"Generating New Melody Segment {idx + 1}: {text}\r")
        output = MODEL.generate_with_all(
            descriptions=[text],
            melody_wavs=verse,
            sample_rate=sr,
            progress=False,
            prompt=prompt_segment,
        )
        # With a negative prompt_index, the previous output becomes the prompt
        # for the next segment; otherwise the same prompt segment is reused.
        if prompt_index < 0:
            prompt_segment = output

        # Append the generated output to the list of segments
        output_segments.append(output)
        print(f"output_segments: {len(output_segments)}: shape: {output.shape} dim {output.dim()}")
        # track remaining duration
        if duration > segment_duration:
            duration -= segment_duration
    return output_segments, excess_duration
| |
|
def save_image(image):
    """
    Saves a PIL image to a temporary file and returns the file path.

    Parameters:
    - image: PIL.Image
        The PIL image object to be saved. Only its ``save(path)`` method is used.

    Returns:
    - str or None: The file path where the image was saved,
                   or None if there was an error saving the image.

    """
    temp_dir = tempfile.gettempdir()
    # delete=False keeps the file on disk after closing so the path can be
    # handed to the caller; the handle is closed before saving to it.
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", dir=temp_dir, delete=False)
    temp_file.close()
    file_path = temp_file.name

    try:
        image.save(file_path)
    except Exception as e:
        print("Unable to save image:", str(e))
        # Do not leave an empty/partial temp file behind on failure.
        try:
            os.remove(file_path)
        except OSError:
            pass
        return None
    return file_path
| |
|
def hex_to_rgba(hex_color):
    """
    Convert a color string to an RGBA tuple.

    Parameters:
    - hex_color: str
        Color specification handed to ImageColor.getcolor.

    Returns:
    - tuple: The value ImageColor.getcolor returns for mode "RGBA", or opaque
             yellow (255, 255, 0, 255) when getcolor raises ValueError.
    """
    fallback = (255,255,0,255)
    try:
        return ImageColor.getcolor(hex_color, "RGBA")
    except ValueError:
        # Color string could not be converted: use the fallback color
        return fallback
| |
|
def load_font(font_name, font_size=16):
    """
    Load a font using the provided font name and font size.

    Parameters:
    font_name (str): The name of the font to load. Can be a font name recognized by the system, a URL to download the font file,
        a local file path, or a file name under "assets/" in the current Hugging Face Space.
    font_size (int, optional): The size of the font. Default is 16.

    Returns:
    The loaded font object, or the result of ImageFont.load_default() when every attempt fails.

    Notes:
    Sources are tried in this order until one succeeds:
    1. (name without "http") ImageFont.truetype(font_name) - system font or local path.
    2. (name without "http") "assets/<font_name>" downloaded with hf_hub_download from the Space named by the
       SPACE_ID environment variable.
    3. (name without "http") the local file "assets/<font_name>".
    4. font_name fetched as a URL with requests.
    5. ImageFont.load_default().

    Example:
    font = load_font("Arial.ttf", font_size=20)
    """
    font = None
    if "http" not in font_name:
        try:
            font = ImageFont.truetype(font_name, font_size)
        except (FileNotFoundError, OSError):
            print("Font not found. Using Hugging Face download..\n")

        if font is None:
            try:
                # hf_hub_download returns a local file path; load the font from it directly.
                font_path = hf_hub_download(repo_id=os.environ.get('SPACE_ID', ''), filename="assets/" + font_name, repo_type="space")
                font = ImageFont.truetype(font_path, font_size)
            except (OSError, ValueError):
                # ValueError covers a rejected repo id (e.g. SPACE_ID unset);
                # OSError covers download and font-loading failures.
                print("Font not found. Trying to download from local assets folder...\n")
        if font is None:
            try:
                font = ImageFont.truetype("assets/" + font_name, font_size)
            except (FileNotFoundError, OSError):
                print("Font not found. Trying to download from URL...\n")

    if font is None:
        try:
            # Bounded wait so an unresponsive host cannot hang the caller
            req = requests.get(font_name, timeout=30)
            # Turn HTTP error pages into an exception instead of parsing them as a font
            req.raise_for_status()
            font = ImageFont.truetype(BytesIO(req.content), font_size)
        except (OSError, ValueError):
            # requests' exceptions derive from OSError; a name that is not a URL also raises ValueError
            print(f"Font not found: {font_name} Using default font\n")
    if font is not None:
        print(f"Font loaded {font.getname()}")
    else:
        font = ImageFont.load_default()
    return font
| |
|
| |
|
def add_settings_to_image(title: str = "title", description: str = "", width: int = 768, height: int = 512, background_path: str = "", font: str = "arial.ttf", font_color: str = "#ffffff"):
    """
    Render a title and a description onto a background and save the result.

    Parameters:
    - title: str
        Title text, wrapped to ``width // 12`` characters per line and drawn near the top.
    - description: str
        Description text, wrapped the same way and drawn below the title.
    - width: int
        Width of the text layer in pixels. Default is 768.
    - height: int
        Height of the text layer in pixels. Default is 512.
    - background_path: str
        Path of the background image; an empty string selects a plain white background.
    - font: str
        Font name passed to load_font (size 26 for the title, 16 for the description).
    - font_color: str
        Color string passed to hex_to_rgba.

    Returns:
    - The value returned by save_image for the composed image.
    """
    # Create a transparent layer that receives the text
    image = Image.new("RGBA", (width, height), (255, 255, 255, 0))
    # If no background image is specified, create a white one
    if background_path == "":
        background = Image.new("RGBA", (width, height), (255, 255, 255, 255))
    else:
        # convert() returns a new image, so the source file can be closed right away
        with Image.open(background_path) as source:
            background = source.convert("RGBA")

    # Convert the color string to RGBA
    font_color = hex_to_rgba(font_color)

    # Center of the text layer
    text_x = width // 2
    text_y = height // 2
    draw = ImageDraw.Draw(image)

    # Draw the title
    title_font = load_font(font, 26)

    title_text = '\n'.join(textwrap.wrap(title, width // 12))
    # getbbox yields (left, top, right, bottom); right/bottom are used here as width/height
    title_x, title_y, title_text_width, title_text_height = title_font.getbbox(title_text)
    title_x = max(text_x - (title_text_width // 2), title_x, 0)
    title_y = text_y - (height // 2) + 10
    # Draw the wrapped text that was measured above, not the raw title
    draw.multiline_text((title_x, title_y), title_text, fill=font_color, font=title_font, align="center")

    # Draw the description below the title
    description_font = load_font(font, 16)
    description_text = '\n'.join(textwrap.wrap(description, width // 12))
    description_x, description_y, description_text_width, description_text_height = description_font.getbbox(description_text)
    description_x = max(text_x - (description_text_width // 2), description_x, 0)
    description_y = title_y + title_text_height + 20
    draw.multiline_text((description_x, description_y), description_text, fill=font_color, font=description_font, align="center")

    # Center the text layer on the background
    bg_w, bg_h = background.size
    offset = ((bg_w - width) // 2, (bg_h - height) // 2)
    # Paste using the layer's own alpha channel as the mask
    background.paste(image, offset, mask=image)

    # Save the image and return the file path
    return save_image(background)