|
|
import gradio as gr |
|
|
import requests |
|
|
import json |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
import torch |
|
|
|
|
|
|
|
|
MODEL_NAME = "Salesforce/codegen-350M-mono"

# Half precision only pays off on a GPU. On CPU many PyTorch kernels have no
# float16 implementation (or a very slow one), so a float16 model can load
# successfully here and then fail on every generation call -- outside this
# try block, where the fallback below cannot help. Use float32 without CUDA.
_HAS_CUDA = torch.cuda.is_available()

try:
    # Preferred configuration: automatic device placement on GPU machines.
    code_generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        torch_dtype=torch.float16 if _HAS_CUDA else torch.float32,
        device_map="auto" if _HAS_CUDA else None,
        trust_remote_code=True,
    )
except Exception as e:
    # Top-level boundary: report the failure and retry with the library's
    # defaults (default dtype, no device map, no remote code).
    print(f"Error loading model: {e}")
    code_generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        device_map=None,
    )
|
|
|
|
|
|
|
|
# Few-shot Strudel snippets interpolated verbatim into the generation prompt
# built by generate_strudel_code. Each snippet is a "// <title>" comment
# followed by a Strudel pattern expression.
STRUDEL_EXAMPLES = """
// Basic drum pattern
s("bd hh sn hh").gain(0.8)

// Melodic pattern
n("0 2 4 7").s("sine").octave(4).lpf(2000)

// Techno beat
stack(
s("bd*4").gain(0.8),
s("~ hh ~ hh").gain(0.6),
n("0 ~ 3 ~").s("sawtooth").octave(2).lpf(800)
)

// Ambient pattern
n("[0 2 4]/3").s("sine").octave(3).slow(4).room(0.9).gain(0.5)
"""
|
|
|
|
|
def generate_strudel_code(prompt, genre="general", complexity="simple", max_length=150, temperature=0.7):
    """Generate Strudel code for ``prompt`` using the CodeGen pipeline.

    Args:
        prompt: Free-text description of the desired music.
        genre: Genre word interpolated into the prompt; also selects the
            fallback pattern when the model output contains no usable code.
        complexity: Complexity word interpolated into the prompt.
        max_length: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature.

    Returns:
        Cleaned Strudel code. On any generation error, a comment describing
        the error followed by a basic fallback drum pattern.
    """
    system_prompt = f"""// Strudel live coding language for music
// Generate {complexity} {genre} music patterns
// Examples:
{STRUDEL_EXAMPLES}

// Create: {prompt}
// Code:
"""

    try:
        outputs = code_generator(
            system_prompt,
            # Budget the completion in tokens. The former
            # ``len(system_prompt.split()) + max_length`` counted prompt
            # *words*, which undercounts its token length, so the overall
            # limit could be shorter than the prompt itself and leave no room
            # for generated code. int() guards against float slider values.
            max_new_tokens=int(max_length),
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            num_return_sequences=1,
            pad_token_id=code_generator.tokenizer.eos_token_id,
            # Return only the continuation. Slicing the full text by
            # len(system_prompt) can misalign when decoding does not
            # reproduce the prompt character-for-character.
            return_full_text=False,
        )
        generated_text = outputs[0]["generated_text"]
        return clean_strudel_code(generated_text.strip(), genre=genre)
    except Exception as e:
        return f"// Error generating code: {str(e)}\n// Fallback pattern:\ns(\"bd hh sn hh\").gain(0.8)"


def clean_strudel_code(code, genre="general"):
    """Extract the leading Strudel-looking lines from raw model output.

    Lines are stripped and scanned in order. A line is kept when it starts
    with a known Strudel call (``s(``, ``n(``, ``stack(``, ``$:``, ``all(``),
    ends with ``)`` or ``,``, or mentions ``gain(``/``lpf(``. Scanning stops
    at a ``//`` comment mentioning "user", "request", "example" or "create"
    (the model starting a new prompt section), at a non-kept line containing
    a JavaScript declaration keyword, or once 8 lines have been kept. All
    other lines are skipped.

    Args:
        code: Raw generated text.
        genre: Genre whose canned pattern is returned when nothing usable is
            found. Previously this name was read from an undefined global,
            raising NameError; it now defaults to "general".

    Returns:
        The kept lines joined with newlines, or
        ``generate_fallback_pattern(genre)`` if no line was kept.
    """
    code_prefixes = ("s(", "n(", "stack(", "$:", "all(")
    code_suffixes = (")", ",")
    prompt_markers = ("user", "request", "example", "create")
    stop_keywords = ("function", "var ", "let ", "const ", "import", "export")
    max_lines = 8

    cleaned_lines = []
    for raw_line in code.split("\n"):
        line = raw_line.strip()

        # A comment like "// Create: ..." means the model began echoing a
        # new prompt; everything after it is not part of this answer.
        if line.startswith("//") and any(marker in line.lower() for marker in prompt_markers):
            break

        looks_like_code = bool(line) and (
            line.startswith(code_prefixes)
            or line.endswith(code_suffixes)
            or "gain(" in line
            or "lpf(" in line
        )
        if looks_like_code:
            cleaned_lines.append(line)
        elif any(keyword in line for keyword in stop_keywords):
            # Plain JavaScript declarations signal the model drifted away
            # from Strudel pattern code.
            break

        if len(cleaned_lines) >= max_lines:
            break

    if not cleaned_lines:
        return generate_fallback_pattern(genre)

    return "\n".join(cleaned_lines)
|
|
|
|
|
def generate_fallback_pattern(genre):
    """Return a hard-coded Strudel pattern for ``genre``.

    Genres with no dedicated entry (e.g. "general") get a basic
    bass-drum / hi-hat / snare loop.
    """
    default_pattern = 's("bd hh sn hh").gain(0.8)'
    canned_by_genre = {
        "techno": 'stack(\n s("bd*4").gain(0.8),\n s("~ hh ~ hh").gain(0.6),\n n("0 ~ 3 ~").s("sawtooth").octave(2)\n)',
        "house": 'stack(\n s("bd ~ ~ ~ bd ~ ~ ~").gain(0.7),\n s("~ hh ~ hh").gain(0.5),\n n("0 2 4 7").s("sine").octave(3)\n)',
        "ambient": 'n("[0 2 4]/3").s("sine").octave(3).slow(4).room(0.9).gain(0.5)',
        "jazz": 'stack(\n s("bd ~ sn ~").gain(0.7),\n s("~ ~ hh ~").gain(0.4),\n n("0 3 5 7").s("triangle").octave(4)\n)',
        "rock": 'stack(\n s("bd sn bd sn").gain(0.8),\n s("hh*8").gain(0.5),\n n("0 0 3 5").s("square").octave(2)\n)',
    }

    try:
        return canned_by_genre[genre]
    except KeyError:
        return default_pattern
|
|
|
|
|
def create_full_strudel_template(generated_code, include_visuals=True):
    """Embed ``generated_code`` in a ready-to-paste Strudel script.

    Args:
        generated_code: Strudel pattern code placed in the script body.
        include_visuals: When truthy, prepend a Hydra visuals preamble.

    Returns:
        The script text: the optional Hydra preamble, a header comment, the
        code, and a commented-out global-effects hint.
    """
    sections = []

    if include_visuals:
        sections.append("""// Hydra visuals
await initHydra({feedStrudel:5})

osc(10, 0.1, 0.8)
.kaleid(4)
.color(1.5, 0.8, 1.2)
.out()

""")

    sections.append(f"""// AI-Generated Strudel Music Code
{generated_code}

// Global effects (optional)
// all(x => x.fft(5).scope())
""")

    return "".join(sections)
|
|
|
|
|
|
|
|
def generate_interface(prompt, genre, complexity, include_visuals, max_length, temperature):
    """Gradio callback: turn the form values into a complete Strudel script.

    A blank or whitespace-only ``prompt`` yields a short instruction message
    instead of generated code.
    """
    if prompt.strip():
        pattern = generate_strudel_code(
            prompt=prompt,
            genre=genre,
            complexity=complexity,
            max_length=max_length,
            temperature=temperature,
        )
        return create_full_strudel_template(pattern, include_visuals)

    return "Please enter a description of the music you want to create."
|
|
|
|
|
|
|
|
# Gradio UI. Layout follows the order in which components are created inside
# the nested ``with`` contexts, so the statements below are order-sensitive.
with gr.Blocks(title="Strudel Code Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🎵 AI Strudel Code Generator

    Generate live coding music patterns using AI! Powered by CodeGen-350M.

    **How to use:**
    1. Describe the music you want (e.g., "techno beat with bass", "ambient soundscape")
    2. Choose genre and complexity
    3. Click Generate
    4. Copy the code to [strudel.cc](https://strudel.cc) to hear it!

    *Note: AI-generated code may need tweaking for best results.*
    """)

    with gr.Row():
        # Left column: input controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Describe your music",
                placeholder="e.g., 'Create a techno beat with kick drum and hi-hats'",
                lines=3
            )

            with gr.Row():
                # "general" and "experimental" have no dedicated fallback
                # pattern in generate_fallback_pattern and use its default.
                genre_dropdown = gr.Dropdown(
                    choices=["general", "techno", "house", "ambient", "jazz", "rock", "experimental"],
                    value="techno",
                    label="Genre"
                )

                complexity_dropdown = gr.Dropdown(
                    choices=["simple", "moderate", "complex"],
                    value="simple",
                    label="Complexity"
                )

            include_visuals = gr.Checkbox(
                label="Include Hydra visuals",
                value=True
            )

            with gr.Accordion("Advanced Settings", open=False):
                # Passed to generate_interface as ``max_length``.
                max_length_slider = gr.Slider(
                    minimum=50,
                    maximum=300,
                    value=150,
                    step=25,
                    label="Max code length"
                )

                # Passed to generate_interface as ``temperature``.
                temperature_slider = gr.Slider(
                    minimum=0.3,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Creativity (temperature)"
                )

            generate_btn = gr.Button("🎵 Generate Strudel Code", variant="primary", size="lg")

        # Right column: generated code and usage hints.
        with gr.Column():
            output_code = gr.Code(
                label="Generated Strudel Code",
                language="javascript",
                lines=15
            )

            gr.Markdown("""
            **Next steps:**
            1. **Copy** the generated code above
            2. **Go to** [strudel.cc](https://strudel.cc)
            3. **Paste** the code and **click play**
            4. **Enjoy** your AI-generated music! 🎶

            *Tip: You can modify the generated code to customize the sound!*
            """)

    gr.Markdown("### 💡 Quick Examples")
    with gr.Row():
        # Each entry: (prompt text, genre, complexity).
        examples = [
            ["Create a minimal techno beat", "techno", "simple"],
            ["Ambient soundscape with reverb", "ambient", "simple"],
            ["House music with bass line", "house", "moderate"],
            ["Jazz drum pattern", "jazz", "moderate"],
        ]

        for example_text, example_genre, example_complexity in examples:
            btn = gr.Button(f"{example_text}", size="sm")
            # The lambda's default arguments capture this iteration's values;
            # a plain closure would see only the last example. Its returned
            # tuple fills the prompt, genre and complexity widgets.
            btn.click(
                lambda t=example_text, g=example_genre, c=example_complexity: (t, g, c),
                outputs=[prompt_input, genre_dropdown, complexity_dropdown]
            )

    # Input order must match generate_interface's positional parameters.
    generate_btn.click(
        generate_interface,
        inputs=[
            prompt_input,
            genre_dropdown,
            complexity_dropdown,
            include_visuals,
            max_length_slider,
            temperature_slider
        ],
        outputs=output_code
    )


if __name__ == "__main__":
    app.launch()