{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ePWjo4hLkSZh"
   },
   "source": [
    "# Orpheus Auto-Continuation Generator Notebook (ver. 3.0)\n",
    "\n",
    "***\n",
    "\n",
    "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
    "\n",
    "***\n",
    "\n",
    "#### Project Los Angeles\n",
    "\n",
    "#### Tegridy Code 2026\n",
    "\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y1H5U8iiAIgD"
   },
   "source": [
    "# Setup Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "8Dt7FYceaCKF",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Install all dependencies (run only once per session)\n",
    "\n",
    "# Clone into the home directory: the import cell changes into ~/tegridy-tools/...,\n",
    "# so the repo must live under ~ regardless of the notebook's starting directory\n",
    "!git clone https://github.com/asigalov61/tegridy-tools ~/tegridy-tools\n",
    "\n",
    "# %pip installs into the environment of the running kernel\n",
    "%pip install tqdm\n",
    "%pip install ipywidgets\n",
    "\n",
    "%pip install einops\n",
    "%pip install einx\n",
    "%pip install scikit-learn\n",
    "%pip install torch-summary\n",
    "\n",
    "%pip install huggingface_hub"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "Lqp3urZyaDAp",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Import all needed modules\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading needed modules. Please wait...')\n",
    "\n",
    "import os\n",
    "\n",
    "# Set before huggingface_hub is imported further down in this cell\n",
    "os.environ[\"HF_XET_HIGH_PERFORMANCE\"] = \"1\"\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading TMIDIX module...')\n",
    "\n",
    "# TMIDIX and x_transformer_2_3_1 are imported straight from the cloned repo,\n",
    "# so change into the directory that holds each module before importing it\n",
    "%cd ~/tegridy-tools/tegridy-tools/\n",
    "\n",
    "import TMIDIX\n",
    "\n",
    "%cd ~/tegridy-tools/tegridy-tools/X-Transformer/\n",
    "\n",
    "from x_transformer_2_3_1 import TransformerWrapper, Decoder, AutoregressiveWrapper, top_p\n",
    "from x_transformer_2_3_1 import build_cls_model, cls_predict\n",
    "\n",
    "# All relative paths used later (./Models/, ./tegridy-tools/...) resolve against home\n",
    "%cd ~\n",
    "\n",
    "import torch\n",
    "from torch.amp import autocast\n",
    "\n",
    "from torchsummary import summary\n",
    "\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "print('=' * 70)\n",
    "print('Done!')\n",
    "print('Enjoy! :)')\n",
    "print('=' * 70)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PcEkAnhyAIgL"
   },
   "source": [
    "# Download and Init Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Orpheus Classifier Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both classifier checkpoints are fetched; the next cell picks one of them\n",
    "hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',\n",
    "                filename='Orpheus_Music_Transformer_Classifier_Trained_Model_23670_steps_0.1837_loss_0.9207_acc.pth',\n",
    "                local_dir='./Models/',\n",
    "                )\n",
    "\n",
    "hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',\n",
    "                filename='Orpheus_Music_Transformer_Classifier_Trained_Model_6698_steps_0.0894_loss_0.9654_acc.pth',\n",
    "                local_dir='./Models/',\n",
    "                )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Init Orpheus Classifier Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_to_load = 'Original' # Or \"Alternative\" (more efficient, trained on a larger training corpus)\n",
    "\n",
    "print('=' * 70)\n",
    "print('Building model...')\n",
    "\n",
    "# Each checkpoint needs the classifier built with its matching options\n",
    "if model_to_load == 'Original':\n",
    "\n",
    "    cls_model = build_cls_model()\n",
    "\n",
    "    full_path_to_trained_model = \"./Models/Orpheus_Music_Transformer_Classifier_Trained_Model_23670_steps_0.1837_loss_0.9207_acc.pth\"\n",
    "\n",
    "else:\n",
    "    cls_model = build_cls_model(use_cls_token=False,\n",
    "                                average_pool_embed=True,\n",
    "                                rotary_pos_emb=True\n",
    "                                )\n",
    "\n",
    "    full_path_to_trained_model = \"./Models/Orpheus_Music_Transformer_Classifier_Trained_Model_6698_steps_0.0894_loss_0.9654_acc.pth\"\n",
    "\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading model...')\n",
    "\n",
    "# weights_only=True restricts unpickling of the downloaded checkpoint to tensors/state dicts\n",
    "cls_model.load_state_dict(torch.load(full_path_to_trained_model, weights_only=True))\n",
    "cls_model.cuda()\n",
    "cls_model.eval()\n",
    "\n",
    "summary(cls_model)\n",
    "\n",
    "print('=' * 70)\n",
    "print('Done!')\n",
    "print('=' * 70)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Orpheus Large Base Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',\n",
    "                filename='Orpheus_Music_Transformer_Large_Trained_Model_43860_steps_0.6682_loss_0.8054_acc.pth',\n",
    "                local_dir='./Models/',\n",
    "                )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Init Orpheus Large Base Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEQ_LEN = 8192\n",
    "\n",
    "# Highest token id; the vocabulary size is PAD_IDX + 1\n",
    "PAD_IDX = 18819\n",
    "\n",
    "model = TransformerWrapper(\n",
    "    num_tokens = PAD_IDX+1,\n",
    "    max_seq_len = SEQ_LEN,\n",
    "    attn_layers = Decoder(dim = 2048,\n",
    "                          depth = 16,\n",
    "                          heads = 16,\n",
    "                          rotary_pos_emb = True,\n",
    "                          attn_flash = True\n",
    "                          )\n",
    "    )\n",
    "\n",
    "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading model checkpoint...')\n",
    "\n",
    "model_path = './Models/Orpheus_Music_Transformer_Large_Trained_Model_43860_steps_0.6682_loss_0.8054_acc.pth'\n",
    "\n",
    "# weights_only=True restricts unpickling of the downloaded checkpoint to tensors/state dicts\n",
    "model.load_state_dict(torch.load(model_path, weights_only=True))\n",
    "\n",
    "print('=' * 70)\n",
    "\n",
    "model.cuda()\n",
    "model.eval()\n",
    "\n",
    "model = torch.compile(model)\n",
    "\n",
    "print('Done!')\n",
    "\n",
    "summary(model)\n",
    "\n",
    "dtype = torch.bfloat16\n",
    "\n",
    "# Autocast context reused by the generation cell\n",
    "ctx = autocast(device_type='cuda', dtype=dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load source MIDI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "midi_file = './tegridy-tools/tegridy-tools/seed-intro.mid'\n",
    "\n",
    "print('=' * 70)\n",
    "print('Loading MIDI File:', midi_file)\n",
    "print('=' * 70)\n",
    "\n",
    "raw_score = TMIDIX.midi2single_track_ms_score(midi_file)\n",
    "\n",
    "escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)\n",
    "\n",
    "escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True)\n",
    "\n",
    "escore_notes = TMIDIX.remove_duplicate_pitches_from_escore_notes(escore_notes)\n",
    "\n",
    "escore_notes = TMIDIX.fix_escore_notes_durations(escore_notes, min_notes_gap=0)\n",
    "\n",
    "dscore = TMIDIX.delta_score_notes(escore_notes)\n",
    "\n",
    "dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])\n",
    "\n",
    "# Token layout produced by the loop below:\n",
    "#   18816          first token of the sequence\n",
    "#   0..255         delta start-time of a chord\n",
    "#   256..16767     (128 * patch) + pitch + 256\n",
    "#   16768..18815   (8 * duration) + velocity bucket + 16768\n",
    "melody_chords = [18816]\n",
    "\n",
    "#=======================================================\n",
    "# MAIN PROCESSING CYCLE\n",
    "#=======================================================\n",
    "\n",
    "for c in dcscore:\n",
    "\n",
    "    # Delta start-times\n",
    "    # Clamped like every other field: a value of 256 or more would be\n",
    "    # indistinguishable from a patch/pitch token when decoding\n",
    "    delta_time = max(0, min(255, c[0][0]))\n",
    "    melody_chords.append(delta_time)\n",
    "\n",
    "    for e in c:\n",
    "\n",
    "        #=======================================================\n",
    "\n",
    "        # Durations\n",
    "        dur = max(1, min(255, e[1]))\n",
    "\n",
    "        # Patches\n",
    "        pat = max(0, min(128, e[5]))\n",
    "\n",
    "        # Pitches\n",
    "        ptc = max(1, min(127, e[3]))\n",
    "\n",
    "        # Velocities\n",
    "        # Calculating octo-velocity (bucket 0..7 for velocities 8..127)\n",
    "\n",
    "        vel = max(8, min(127, e[4]))\n",
    "        velocity = round(vel / 15)-1\n",
    "\n",
    "        #=======================================================\n",
    "        # FINAL NOTE SEQ\n",
    "        #=======================================================\n",
    "\n",
    "        # Writing final note\n",
    "        pat_ptc = (128 * pat) + ptc\n",
    "        dur_vel = (8 * dur) + velocity\n",
    "\n",
    "        melody_chords.extend([pat_ptc+256, dur_vel+16768]) # 18816\n",
    "\n",
    "print('Done!')\n",
    "print('=' * 70)\n",
    "print('Composition has', len(melody_chords), 'tokens')\n",
    "print('=' * 70)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_gen_chunk = 4 # Number of continuation chunks to generate\n",
    "batch_size = 12 # More is better (max it out)\n",
    "temperature = 0.95\n",
    "top_p_value = 0.96\n",
    "\n",
    "#===================================================================\n",
    "\n",
    "# The slice makes a copy, so re-running this cell restarts from the seed\n",
    "# and melody_chords itself is never modified\n",
    "song = melody_chords[:1024]\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "for i in tqdm(range(num_gen_chunk)):\n",
    "\n",
    "    # Same prompt repeated batch_size times; sampling yields different candidates\n",
    "    x = torch.LongTensor([song] * batch_size).cuda()\n",
    "\n",
    "    with ctx:\n",
    "        out = model.generate(x,\n",
    "                             512,\n",
    "                             temperature=temperature,\n",
    "                             filter_logits_fn=top_p,\n",
    "                             filter_kwargs={'thres': top_p_value},\n",
    "                             return_prime=True,\n",
    "                             verbose=False)\n",
    "\n",
    "    outs = out.tolist()\n",
    "\n",
    "    # Keep only candidates that contain neither token 18817 nor 18818\n",
    "    # and whose time tokens (values below 256) are all below 128.\n",
    "    # NOTE(review): the check runs over the whole returned sequence, which\n",
    "    # presumably includes the prompt because return_prime=True - confirm\n",
    "    outputs = []\n",
    "\n",
    "    for o in outs:\n",
    "        if 18818 not in o and 18817 not in o:\n",
    "            times = [oo for oo in o if oo < 256]\n",
    "            if all(oo < 128 for oo in times):\n",
    "                outputs.append(o)\n",
    "\n",
    "    # The filter can reject every candidate; max() on an empty list would\n",
    "    # raise, so skip this chunk and leave the song unchanged instead\n",
    "    if not outputs:\n",
    "        print('No valid candidates for chunk', i, '- skipping')\n",
    "        continue\n",
    "\n",
    "    # The classifier scores the last 1024 tokens of each surviving candidate\n",
    "    cls_outputs = [o[-1024:] for o in outputs]\n",
    "\n",
    "    preds, probs = cls_predict(cls_model, cls_outputs)\n",
    "\n",
    "    max_prob = max(probs)\n",
    "\n",
    "    print(sorted(probs, reverse=True))\n",
    "\n",
    "    best_idx = probs.index(max_prob)\n",
    "\n",
    "    # Only the last 512 tokens (the requested generation length) are appended\n",
    "    best_output = outputs[best_idx][-512:]\n",
    "\n",
    "    song.extend(best_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert to MIDI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Sample INTs', song[:15])\n",
    "\n",
    "if len(song) != 0:\n",
    "\n",
    "    song_f = []\n",
    "\n",
    "    # Running decoder state; each field keeps its last value until a token updates it\n",
    "    time = 0\n",
    "    dur = 1\n",
    "    vel = 90\n",
    "    pitch = 60\n",
    "    channel = 0\n",
    "    patch = 0\n",
    "\n",
    "    # patches[ch] holds the patch assigned to channel ch (-1 = unassigned)\n",
    "    patches = [-1] * 16\n",
    "\n",
    "    # channels[ch] == 1 marks a channel as taken; channel 9 is reserved up front\n",
    "    channels = [0] * 16\n",
    "    channels[9] = 1\n",
    "\n",
    "    for ss in song:\n",
    "\n",
    "        # Delta start-time token\n",
    "        if 0 <= ss < 256:\n",
    "\n",
    "            time += ss * 16\n",
    "\n",
    "        # Patch/pitch token\n",
    "        if 256 <= ss < 16768:\n",
    "\n",
    "            patch = (ss-256) // 128\n",
    "\n",
    "            if patch < 128:\n",
    "\n",
    "                if patch not in patches:\n",
    "                    # First use of this patch: take the first free channel,\n",
    "                    # or fall back to channel 15 when none is left\n",
    "                    if 0 in channels:\n",
    "                        cha = channels.index(0)\n",
    "                        channels[cha] = 1\n",
    "                    else:\n",
    "                        cha = 15\n",
    "\n",
    "                    patches[cha] = patch\n",
    "\n",
    "                channel = patches.index(patch)\n",
    "\n",
    "            # Patch 128 always goes to the reserved channel 9\n",
    "            if patch == 128:\n",
    "                channel = 9\n",
    "\n",
    "            pitch = (ss-256) % 128\n",
    "\n",
    "        # Duration/velocity token completes a note\n",
    "        if 16768 <= ss < 18816:\n",
    "\n",
    "            dur = ((ss-16768) // 8) * 16\n",
    "            vel = (((ss-16768) % 8)+1) * 15\n",
    "\n",
    "            song_f.append(['note', time, dur, channel, pitch, vel, patch])\n",
    "\n",
    "    # The patch list written to the MIDI file is the one returned here\n",
    "    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)\n",
    "\n",
    "    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,\n",
    "                                                              output_signature = 'Orpheus Music Transformer',\n",
    "                                                              output_file_name = './Orpheus-Music-Transformer-Composition',\n",
    "                                                              track_name='Project Los Angeles',\n",
    "                                                              list_of_MIDI_patches=patches\n",
    "                                                              )\n",
    "\n",
    "    print('Done!')\n",
    "\n",
    "else:\n",
    "    # Nothing was decoded, so there is no score to export\n",
    "    print('Song is empty. Nothing to convert.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congrats! You did it! :)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}