# codel / app.py
# baouws's picture
# Update app.py
# deb5deb verified
# raw
# history blame
# 9.31 kB
import gradio as gr
import requests
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# Use CodeGen instead of StarCoder (no authentication needed)
MODEL_NAME = "Salesforce/codegen-350M-mono"

# Choose precision and placement from the hardware that is actually present.
# float16 is only requested together with a CUDA device: on CPU, half-precision
# inference can fail (or be very slow) at *generation* time, which is after the
# try/except below has already finished and can no longer recover.
_HAS_CUDA = torch.cuda.is_available()

# Initialize model with pipeline for easier usage
try:
    # trust_remote_code is intentionally not passed: the fallback load below
    # loads the same checkpoint without it, so there is no need to allow
    # repository-supplied code to run.
    code_generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        torch_dtype=torch.float16 if _HAS_CUDA else torch.float32,
        device_map="auto" if _HAS_CUDA else None,
    )
except Exception as e:
    # Best-effort recovery: report the failure and retry with library defaults
    # (no dtype override, no device map). If this second load also fails the
    # exception propagates, since the app cannot run without a model.
    print(f"Error loading model: {e}")
    # Fallback to CPU if GPU fails
    code_generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        device_map=None,
    )
# Strudel code examples for few-shot prompting.
# This text is interpolated verbatim into the prompt built by
# generate_strudel_code, so any edit here changes what the model is shown.
# Each example is introduced by a `//` comment naming the kind of pattern.
STRUDEL_EXAMPLES = """
// Basic drum pattern
s("bd hh sn hh").gain(0.8)
// Melodic pattern
n("0 2 4 7").s("sine").octave(4).lpf(2000)
// Techno beat
stack(
s("bd*4").gain(0.8),
s("~ hh ~ hh").gain(0.6),
n("0 ~ 3 ~").s("sawtooth").octave(2).lpf(800)
)
// Ambient pattern
n("[0 2 4]/3").s("sine").octave(3).slow(4).room(0.9).gain(0.5)
"""
def generate_strudel_code(prompt, genre="general", complexity="simple", max_length=150, temperature=0.7):
    """Generate Strudel code for ``prompt`` using the CodeGen pipeline.

    Args:
        prompt: Free-text description of the music to create.
        genre: Genre name; woven into the prompt and used to pick the
            fallback pattern when generation yields nothing usable.
        complexity: Complexity label woven into the prompt.
        max_length: Maximum number of *new* tokens to generate (the prompt
            itself does not count against this budget).
        temperature: Sampling temperature passed to the model.

    Returns:
        The cleaned Strudel code. If generation raises, a string made of an
        error comment followed by the fallback pattern for ``genre``.
    """
    # Build context-aware prompt
    system_prompt = f"""// Strudel live coding language for music
// Generate {complexity} {genre} music patterns
// Examples:
{STRUDEL_EXAMPLES}
// Create: {prompt}
// Code:
"""
    try:
        # max_new_tokens bounds only the continuation. The previous
        # "word count + max_length" limit compared whitespace-separated words
        # with tokens, so the prompt could use up the whole budget.
        outputs = code_generator(
            system_prompt,
            max_new_tokens=int(max_length),
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            num_return_sequences=1,
            # Let the pipeline strip the prompt instead of slicing by
            # len(system_prompt), which assumes the prompt is echoed verbatim.
            return_full_text=False,
            pad_token_id=code_generator.tokenizer.eos_token_id,
        )
        strudel_code = outputs[0]['generated_text'].strip()
        # Pass the genre along so an empty result falls back to a matching pattern.
        return clean_strudel_code(strudel_code, genre)
    except Exception as e:
        # UI boundary: never raise into Gradio; show the error as a code
        # comment and still hand back something playable.
        return f"// Error generating code: {e}\n// Fallback pattern:\n{generate_fallback_pattern(genre)}"


# Line shapes that clean_strudel_code accepts as Strudel code.
_STRUDEL_LINE_PREFIXES = ('s(', 'n(', 'stack(', '$:', 'all(')
_STRUDEL_LINE_SUFFIXES = (')', ',')
_STRUDEL_CALL_MARKERS = ('gain(', 'lpf(')
# Words that mark a `//` comment as the model echoing the prompt scaffold.
_PROMPT_ECHO_WORDS = ('user', 'request', 'example', 'create')
# Substrings that indicate the model drifted into ordinary JavaScript.
_NON_STRUDEL_KEYWORDS = ('function', 'var ', 'let ', 'const ', 'import', 'export')
# Upper bound on the number of code lines returned.
_MAX_CLEANED_LINES = 8


def _looks_like_strudel(line):
    """Return True when a stripped, non-empty line has the shape of Strudel code."""
    return bool(line) and (
        line.startswith(_STRUDEL_LINE_PREFIXES)
        or line.endswith(_STRUDEL_LINE_SUFFIXES)
        or any(marker in line for marker in _STRUDEL_CALL_MARKERS)
    )


def clean_strudel_code(code, genre="general"):
    """Clean and format generated Strudel code.

    Keeps lines that look like Strudel code, and stops at the first comment
    that echoes the prompt scaffold, at the first line containing a plain
    JavaScript keyword, or once ``_MAX_CLEANED_LINES`` lines were collected.

    Args:
        code: Raw text produced by the model (``None`` is treated as empty).
        genre: Genre used to choose the fallback pattern when no line
            survives cleaning. Defaults to ``"general"``.

    Returns:
        The kept lines joined with newlines, or the fallback pattern for
        ``genre`` when nothing usable was found.
    """
    cleaned_lines = []
    for raw_line in (code or '').split('\n'):
        line = raw_line.strip()
        # Stop at comments that indicate end of generation
        if line.startswith('//') and any(word in line.lower() for word in _PROMPT_ECHO_WORDS):
            break
        # Include actual code lines
        if _looks_like_strudel(line):
            cleaned_lines.append(line)
        # Stop if we hit obvious non-Strudel code
        elif any(keyword in line for keyword in _NON_STRUDEL_KEYWORDS):
            break
        if len(cleaned_lines) >= _MAX_CLEANED_LINES:  # Limit output length
            break
    # If no valid code was generated, provide a fallback
    if not cleaned_lines:
        return generate_fallback_pattern(genre)
    return '\n'.join(cleaned_lines)


def generate_fallback_pattern(genre):
    """Return a simple hand-written pattern for ``genre``.

    Genres without a dedicated entry get a basic drum pattern.
    """
    patterns = {
        "techno": 'stack(\n s("bd*4").gain(0.8),\n s("~ hh ~ hh").gain(0.6),\n n("0 ~ 3 ~").s("sawtooth").octave(2)\n)',
        "house": 'stack(\n s("bd ~ ~ ~ bd ~ ~ ~").gain(0.7),\n s("~ hh ~ hh").gain(0.5),\n n("0 2 4 7").s("sine").octave(3)\n)',
        "ambient": 'n("[0 2 4]/3").s("sine").octave(3).slow(4).room(0.9).gain(0.5)',
        "jazz": 'stack(\n s("bd ~ sn ~").gain(0.7),\n s("~ ~ hh ~").gain(0.4),\n n("0 3 5 7").s("triangle").octave(4)\n)',
        "rock": 'stack(\n s("bd sn bd sn").gain(0.8),\n s("hh*8").gain(0.5),\n n("0 0 3 5").s("square").octave(2)\n)',
    }
    return patterns.get(genre, 's("bd hh sn hh").gain(0.8)')
def create_full_strudel_template(generated_code, include_visuals=True):
    """Assemble the final snippet shown to the user.

    The result is, in order: an optional Hydra visuals preamble, a header
    comment, ``generated_code`` itself, and a commented-out hint for global
    effects. Every section ends with a newline.
    """
    sections = []

    # Optional Hydra preamble goes first so it precedes the music code.
    if include_visuals:
        sections.append(
            "// Hydra visuals\n"
            "await initHydra({feedStrudel:5})\n"
            "osc(10, 0.1, 0.8)\n"
            ".kaleid(4)\n"
            ".color(1.5, 0.8, 1.2)\n"
            ".out()\n"
        )

    sections.append("// AI-Generated Strudel Music Code\n")
    sections.append(f"{generated_code}\n")
    sections.append(
        "// Global effects (optional)\n"
        "// all(x => x.fft(5).scope())\n"
    )
    return "".join(sections)
# Gradio interface
def generate_interface(prompt, genre, complexity, include_visuals, max_length, temperature):
    """Callback behind the Generate button.

    Returns a short instruction message when the description is blank;
    otherwise generates Strudel code for the request and returns it wrapped
    in the full template (with or without the visuals preamble).
    """
    has_description = bool(prompt.strip())
    if has_description:
        # Generate the core Strudel code, then wrap it in the full template.
        request = {
            "prompt": prompt,
            "genre": genre,
            "complexity": complexity,
            "max_length": max_length,
            "temperature": temperature,
        }
        return create_full_strudel_template(generate_strudel_code(**request), include_visuals)
    return "Please enter a description of the music you want to create."
# Create Gradio app
# NOTE(review): the container nesting below (which components sit inside which
# Row/Column/Accordion) is inferred from the order of the statements — confirm
# it matches the intended layout.
with gr.Blocks(title="Strudel Code Generator", theme=gr.themes.Soft()) as app:
    # Page header and usage instructions.
    gr.Markdown("""
# 🎵 AI Strudel Code Generator
Generate live coding music patterns using AI! Powered by CodeGen-350M.
**How to use:**
1. Describe the music you want (e.g., "techno beat with bass", "ambient soundscape")
2. Choose genre and complexity
3. Click Generate
4. Copy the code to [strudel.cc](https://strudel.cc) to hear it!
*Note: AI-generated code may need tweaking for best results.*
""")
    with gr.Row():
        # Input side: description, options and the generate button.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Describe your music",
                placeholder="e.g., 'Create a techno beat with kick drum and hi-hats'",
                lines=3
            )
            with gr.Row():
                genre_dropdown = gr.Dropdown(
                    choices=["general", "techno", "house", "ambient", "jazz", "rock", "experimental"],
                    value="techno",
                    label="Genre"
                )
                complexity_dropdown = gr.Dropdown(
                    choices=["simple", "moderate", "complex"],
                    value="simple",
                    label="Complexity"
                )
            include_visuals = gr.Checkbox(
                label="Include Hydra visuals",
                value=True
            )
            # Generation parameters, collapsed by default (open=False).
            with gr.Accordion("Advanced Settings", open=False):
                max_length_slider = gr.Slider(
                    minimum=50,
                    maximum=300,
                    value=150,
                    step=25,
                    label="Max code length"
                )
                temperature_slider = gr.Slider(
                    minimum=0.3,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Creativity (temperature)"
                )
            generate_btn = gr.Button("🎵 Generate Strudel Code", variant="primary", size="lg")
        # Output side: generated code plus instructions for using it.
        with gr.Column():
            output_code = gr.Code(
                label="Generated Strudel Code",
                language="javascript",
                lines=15
            )
            gr.Markdown("""
**Next steps:**
1. **Copy** the generated code above
2. **Go to** [strudel.cc](https://strudel.cc)
3. **Paste** the code and **click play**
4. **Enjoy** your AI-generated music! 🎶
*Tip: You can modify the generated code to customize the sound!*
""")
    # Example buttons
    gr.Markdown("### 💡 Quick Examples")
    with gr.Row():
        # Each entry is [prompt text, genre, complexity].
        examples = [
            ["Create a minimal techno beat", "techno", "simple"],
            ["Ambient soundscape with reverb", "ambient", "simple"],
            ["House music with bass line", "house", "moderate"],
            ["Jazz drum pattern", "jazz", "moderate"],
        ]
        for example_text, example_genre, example_complexity in examples:
            btn = gr.Button(f"{example_text}", size="sm")
            # The lambda binds the current example through default arguments,
            # so every button keeps its own values rather than the loop's
            # last ones. Clicking only fills the three inputs; it does not
            # trigger generation.
            btn.click(
                lambda t=example_text, g=example_genre, c=example_complexity: (t, g, c),
                outputs=[prompt_input, genre_dropdown, complexity_dropdown]
            )
    # Connect the generate button
    # The inputs are listed in the same order as generate_interface's parameters.
    generate_btn.click(
        generate_interface,
        inputs=[
            prompt_input,
            genre_dropdown,
            complexity_dropdown,
            include_visuals,
            max_length_slider,
            temperature_slider
        ],
        outputs=output_code
    )
# Launch the app
# Only start the server when this file is run as a script, not when imported.
if __name__ == "__main__":
    app.launch()