| import gradio as gr |
| from gradio_client import Client |
| import os |
| import json |
| import random |
| from datetime import datetime |
| import numpy as np |
| from pydub import AudioSegment |
| import logging |
| import configparser |
|
|
| |
# Root logging: timestamped, level-tagged lines at INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Filesystem layout: everything lives under BASE_DIR.
BASE_DIR = "/home/pi5/muzax"
MP3_DIR = os.path.join(BASE_DIR, "mp3")            # rendered songs
JSON_DIR = os.path.join(BASE_DIR, "json")          # bookkeeping files
JSON_LOG = os.path.join(JSON_DIR, "render_log.json")  # list of render entries
INI_FILE = os.path.join(BASE_DIR, "band_styles.ini")  # one section per band

# Remote generation service queried through gradio_client.
API_URL = "http://192.168.0.155:9999/"

# Length requested from the generator ("total_duration") and used for the
# placeholder tone, in seconds.
SONG_DURATION = 120
# Length every saved song is looped/trimmed to, in milliseconds.
TARGET_DURATION_MS = 180_000
|
|
| |
# Load the band style definitions. The app cannot do anything useful without
# them, so a missing INI file aborts the import after logging the reason.
if os.path.exists(INI_FILE):
    config = configparser.ConfigParser()
    config.read(INI_FILE)
else:
    logger.error(f"INI file not found: {INI_FILE}")
    raise FileNotFoundError(f"INI file not found: {INI_FILE}")

# Every INI section name is treated as a selectable band.
ALLOWED_BANDS = config.sections()
|
|
| |
# Make sure the working directories exist, logging only the ones created now.
for directory in (BASE_DIR, MP3_DIR, JSON_DIR):
    if os.path.exists(directory):
        continue
    os.makedirs(directory)
    logger.info(f"Created directory: {directory}")

# Seed the render log with an empty JSON list on first run so readers can
# always parse it.
if not os.path.exists(JSON_LOG):
    with open(JSON_LOG, "w") as f:
        f.write(json.dumps([]))
    logger.info(f"Initialized JSON log: {JSON_LOG}")
|
|
def generate_random_params(band):
    """Pick a random parameter set for ``band`` from the INI configuration.

    The band's section must provide integer ``bpm_min``/``bpm_max`` and
    comma-separated option lists for ``drum_beat``, ``synthesizer``,
    ``rhythmic_steps``, ``bass_style`` and ``guitar_style``.

    Args:
        band: Section name in the loaded INI configuration.

    Returns:
        Tuple ``(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style,
        guitar_style)`` where ``bpm`` is an int in ``[bpm_min, bpm_max]`` and
        the rest are single options with surrounding whitespace removed.

    Raises:
        KeyError: If the section or one of its keys is missing.
        ValueError: If the BPM bounds are not integers or ``bpm_min`` exceeds
            ``bpm_max``.
    """
    style = config[band]

    def pick(option_key):
        # INI lists are usually written "a, b, c": strip the padding so the
        # chosen value is the exact option text, and skip empty items left by
        # trailing or doubled commas.
        options = [item.strip() for item in style[option_key].split(",")]
        options = [item for item in options if item]
        # An empty setting previously produced ""; keep that instead of raising.
        return random.choice(options) if options else ""

    bpm = random.randint(int(style["bpm_min"]), int(style["bpm_max"]))
    return (
        bpm,
        pick("drum_beat"),
        pick("synthesizer"),
        pick("rhythmic_steps"),
        pick("bass_style"),
        pick("guitar_style"),
    )
|
|
def generate_music_placeholder(prompt, duration, bpm, drum_beat, bass_style, guitar_style):
    """Build a plain sine tone that stands in for generated music.

    The tone is mono, 16-bit, 44.1 kHz, at half of full scale. Its pitch is
    440 Hz when ``guitar_style`` is ``"clean"`` and 220 Hz otherwise.

    Args:
        prompt: Accepted for signature parity with the real generator; unused.
        duration: Tone length in seconds (anything ``float()`` accepts).
        bpm: Unused.
        drum_beat: Unused.
        bass_style: Unused.
        guitar_style: Selects the tone frequency (see above).

    Returns:
        An ``AudioSegment`` holding the tone.
    """
    sample_rate = 44100
    seconds = float(duration)

    # The previous version first allocated a full-length silent segment and
    # then discarded it; only the synthesized tone is ever returned.
    t = np.linspace(0, seconds, int(sample_rate * seconds), endpoint=False)
    freq = 440 if guitar_style == "clean" else 220
    sine_wave = 0.5 * np.sin(2 * np.pi * freq * t)
    # Scale [-0.5, 0.5] floats to signed 16-bit PCM samples.
    audio_samples = (sine_wave * 32767).astype(np.int16)
    return AudioSegment(
        audio_samples.tobytes(),
        frame_rate=sample_rate,
        sample_width=2,
        channels=1,
    )
|
|
def extend_audio(audio, target_duration_ms):
    """Loop or trim ``audio`` so it is exactly ``target_duration_ms`` long.

    Works on any object whose ``len()`` and slicing are expressed in the same
    unit as ``target_duration_ms`` and that supports ``+`` concatenation.

    Args:
        audio: The clip to extend or trim.
        target_duration_ms: Desired length.

    Returns:
        ``audio`` trimmed to the target if it is already long enough,
        otherwise ``audio`` repeated back-to-back and trimmed to the target.

    Raises:
        ValueError: If ``audio`` is empty but a positive length is requested
            (looping nothing can never reach the target).
    """
    current_duration = len(audio)
    if current_duration >= target_duration_ms:
        return audio[:target_duration_ms]

    if current_duration == 0:
        # Without this guard the loop below would never terminate.
        raise ValueError("Cannot extend empty audio by looping")

    extended_audio = audio
    while len(extended_audio) < target_duration_ms:
        # Rebind instead of "+=" so the caller's object is never mutated in
        # place, whatever sequence type it is.
        extended_audio = extended_audio + audio
    return extended_audio[:target_duration_ms]
|
|
def save_to_mp3(audio, filename):
    """Export ``audio`` as an MP3 called ``filename`` inside ``MP3_DIR``.

    Args:
        audio: Object exposing ``export(path, format=...)``.
        filename: File name (not a full path) for the exported MP3.

    Returns:
        The full path the MP3 was written to.
    """
    destination = os.path.join(MP3_DIR, filename)
    audio.export(destination, format="mp3")
    logger.info(f"Saved MP3 to {destination}")
    return destination
|
|
def update_json_log(band, params, filepath, status):
    """Append one render entry to the JSON render log.

    The entry records the current timestamp (ISO format), ``band``,
    ``params`` (stored under ``"parameters"``), ``filepath`` and ``status``.

    Args:
        band: Band identifier the song was rendered for.
        params: JSON-serializable parameters used for the render.
        filepath: Path of the rendered audio file.
        status: Human-readable outcome of the render.

    Raises:
        json.JSONDecodeError: If the existing log is not valid JSON (it is
            left untouched so it can be inspected).
        TypeError: If the entry cannot be serialized to JSON; the existing
            log is left untouched.
    """
    try:
        with open(JSON_LOG, "r") as f:
            log = json.load(f)
    except FileNotFoundError:
        # The log may have been removed while the app was running; there is
        # nothing to lose by starting over.
        logger.warning(f"JSON log missing, starting a new one: {JSON_LOG}")
        log = []

    render_entry = {
        "timestamp": datetime.now().isoformat(),
        "band": band,
        "parameters": params,
        "filepath": filepath,
        "status": status
    }
    log.append(render_entry)

    # Serialize before touching the file: opening the log in "w" mode first
    # would truncate it and leave it corrupt if serialization failed.
    payload = json.dumps(log, indent=2)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # cannot leave a half-written log behind.
    tmp_path = f"{JSON_LOG}.tmp"
    with open(tmp_path, "w") as f:
        f.write(payload)
    os.replace(tmp_path, JSON_LOG)

    logger.info(f"Updated JSON log: {render_entry}")
|
|
def generate_song(band, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Generate a song for ``band``, loop/trim it to the target length and save it.

    Steps: ask the band's prompt endpoint (``api_name`` from the band's INI
    section) for a prompt, call ``/generate_music`` with it, fall back to the
    placeholder tone when the API returns no file, extend the audio to
    ``TARGET_DURATION_MS``, save it as ``<band>_<timestamp>.mp3`` and record
    the render in the JSON log.

    Args:
        band: INI section name of the band.
        bpm: Tempo in beats per minute.
        drum_beat: Drum style option.
        synthesizer: Synth option.
        rhythmic_steps: Rhythm option.
        bass_style: Bass option.
        guitar_style: Guitar option.

    Returns:
        ``(filepath, status)`` on success, or ``(None, "Error: <message>")``
        if anything fails; exceptions are logged, never propagated.
    """
    try:
        client = Client(API_URL)
        style_params = {
            "bpm": bpm,
            "drum_beat": drum_beat,
            "synthesizer": synthesizer,
            "rhythmic_steps": rhythmic_steps,
            "bass_style": bass_style,
            "guitar_style": guitar_style,
        }
        prompt = client.predict(**style_params, api_name=config[band]["api_name"])
        logger.info(f"Prompt for {band}: {prompt}")

        # Key order is kept as before because this dict is written verbatim
        # into the render log.
        music_params = {
            "instrumental_prompt": prompt,
            "cfg_scale": 3.0,
            "top_k": 300,
            "top_p": 0.9,
            "temperature": 0.8,
            "total_duration": SONG_DURATION,
            **style_params,
            "target_volume": -23.0,
            "preset": "rock",
            "vram_status": "",
            "api_name": "/generate_music",
        }

        # The API's own status text is replaced below, so only the path is used.
        api_filepath, _api_status, _ = client.predict(**music_params)

        if api_filepath:
            audio = AudioSegment.from_file(api_filepath)
            status = "Generated with API, extended to 180 seconds"
        else:
            logger.warning("API returned no audio, using placeholder.")
            audio = generate_music_placeholder(
                prompt, SONG_DURATION, bpm, drum_beat, bass_style, guitar_style
            )
            status = "Generated with placeholder, extended to 180 seconds"

        # Shared tail for both sources: extend, name, save, log.
        audio = extend_audio(audio, TARGET_DURATION_MS)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = save_to_mp3(audio, f"{band}_{timestamp}.mp3")

        update_json_log(band, music_params, filepath, status)
        return filepath, status

    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # raising, but keep the traceback in the log for diagnosis.
        logger.exception(f"Error generating song: {str(e)}")
        return None, f"Error: {str(e)}"
|
|
def get_last_five_songs():
    """Return up to five of the newest render-log entries, newest first.

    Entries are ordered by their ``"timestamp"`` string in descending order.
    Each returned dict carries ``timestamp``, ``band`` (underscores replaced
    by spaces, title-cased), ``filepath``, ``parameters`` and ``status``.

    Returns:
        A list of at most five dicts, or an empty list if the log cannot be
        read or processed (the error is logged).
    """
    try:
        with open(JSON_LOG, "r") as f:
            entries = json.load(f)

        newest_first = sorted(entries, key=lambda item: item["timestamp"], reverse=True)

        recent = []
        for entry in newest_first[:5]:
            display_band = entry["band"].replace("_", " ").title()
            recent.append(
                {
                    "timestamp": entry["timestamp"],
                    "band": display_band,
                    "filepath": entry["filepath"],
                    "parameters": entry["parameters"],
                    "status": entry["status"],
                }
            )
        return recent
    except Exception as e:
        logger.error(f"Error reading JSON log: {str(e)}")
        return []
|
|
def create_gradio_interface():
    """Build the Muzax Gradio app.

    Layout: one tab per band in ``ALLOWED_BANDS`` (parameter controls, a
    Randomize button, a Generate button, audio and status outputs), a
    "Recent Songs" tab with five players fed by the render log, and a
    "Render Log" tab showing the raw log.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    css = """
    .gradio-container {background-color: #2b2b2b; color: #ffffff; font-family: Arial, sans-serif;}
    .gr-button-primary {background-color: #4a90e2; color: #ffffff; border: none; padding: 10px 20px; border-radius: 5px;}
    .gr-button-primary:hover {background-color: #357abd;}
    .gr-button-secondary {background-color: #4a4a4a; color: #ffffff; border: none; padding: 10px 20px; border-radius: 5px;}
    .gr-button-secondary:hover {background-color: #333333;}
    .gr-panel {background-color: #3c3c3c; border: none; border-radius: 8px; padding: 15px;}
    .gr-textbox, .gr-slider, .gr-dropdown, .gr-audio {background-color: #4a4a4a; color: #ffffff; border: none; border-radius: 5px;}
    .gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {color: #ffffff;}
    """

    def load_render_log():
        # Read the whole render log, closing the file deterministically
        # (the previous inline json.load(open(...)) leaked the handle).
        with open(JSON_LOG) as f:
            return json.load(f)

    with gr.Blocks(title="Muzax Rock Generator", css=css) as demo:
        gr.Markdown(
            """
            # Muzax Rock Song Generator
            Create 3-minute rock songs inspired by top bands. Save MP3s to /home/pi5/muzax/mp3.
            """
        )

        with gr.Tabs():
            for band in ALLOWED_BANDS:
                band_title = band.replace("_", " ").title()
                with gr.Tab(label=band_title):
                    gr.Markdown(f"### {band_title} Song Generator")
                    # NOTE(review): the ranges/choices below are fixed, while
                    # Randomize draws from the INI file - confirm the INI
                    # values stay within these options.
                    # Label emoji were restored from mis-decoded UTF-8 bytes.
                    with gr.Column():
                        bpm = gr.Slider(
                            minimum=60,
                            maximum=180,
                            value=120,
                            step=1,
                            label="Tempo (BPM) 🎵",
                            info="Song speed in beats per minute."
                        )
                        drum_beat = gr.Dropdown(
                            choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
                            value="standard rock",
                            label="Drum Beat 🥁",
                            info="Drum style."
                        )
                        synthesizer = gr.Dropdown(
                            choices=["none", "analog synth", "digital pad", "arpeggiated synth"],
                            value="none",
                            label="Synthesizer 🎹",
                            info="Synth sound."
                        )
                        rhythmic_steps = gr.Dropdown(
                            choices=["none", "syncopated steps", "steady steps", "complex steps"],
                            value="steady steps",
                            label="Rhythmic Steps 👣",
                            info="Rhythm complexity."
                        )
                        bass_style = gr.Dropdown(
                            choices=["none", "slap bass", "deep bass", "melodic bass"],
                            value="deep bass",
                            label="Bass Style 🎸",
                            info="Bass guitar style."
                        )
                        guitar_style = gr.Dropdown(
                            choices=["none", "distorted", "clean", "jangle"],
                            value="distorted",
                            label="Guitar Style 🎸",
                            info="Guitar sound."
                        )

                    with gr.Row():
                        randomize_btn = gr.Button("Randomize", variant="secondary")
                        generate_btn = gr.Button("Generate Song", variant="primary")

                    audio_output = gr.Audio(
                        label="Generated Song 🎵",
                        type="filepath",
                        interactive=False
                    )
                    status_output = gr.Textbox(
                        label="Status 📢",
                        placeholder="Status updates here.",
                        interactive=False
                    )

                    # Bind the current band as a default argument: a bare
                    # closure over the loop variable would make every tab's
                    # Randomize button use the last band in the list.
                    randomize_btn.click(
                        fn=lambda band=band: generate_random_params(band),
                        outputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style]
                    )

                    generate_btn.click(
                        fn=generate_song,
                        inputs=[gr.State(value=band), bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style],
                        outputs=[audio_output, status_output]
                    )

            with gr.Tab("Recent Songs"):
                gr.Markdown("### Last 5 Songs")
                # Snapshot taken when the UI is built; "Refresh Songs" replaces it.
                recent_songs = gr.State(value=get_last_five_songs())

                for i in range(5):
                    with gr.Group():
                        gr.Markdown(f"#### Song {i+1}")
                        audio_player = gr.Audio(label=f"Song {i+1}", type="filepath", interactive=False)
                        info_text = gr.Textbox(label=f"Details {i+1}", interactive=False)
                        play_btn = gr.Button(f"Play Song {i+1}", variant="primary")

                        # index=i pins this slot's position at definition time.
                        def play_song(song_list, index=i):
                            if index < len(song_list) and os.path.exists(song_list[index]["filepath"]):
                                return (
                                    song_list[index]["filepath"],
                                    f"Band: {song_list[index]['band']}\nTime: {song_list[index]['timestamp']}\nParams: {json.dumps(song_list[index]['parameters'], indent=2)}\nStatus: {song_list[index]['status']}"
                                )
                            return None, "Song unavailable."

                        play_btn.click(
                            fn=play_song,
                            inputs=[recent_songs],
                            outputs=[audio_player, info_text]
                        )

                refresh_btn = gr.Button("Refresh Songs", variant="secondary")
                refresh_btn.click(
                    fn=get_last_five_songs,
                    outputs=[recent_songs]
                )

            with gr.Tab("Render Log"):
                gr.Markdown("### Render Log")
                log_output = gr.JSON(label="All Renders", value=load_render_log)
                refresh_log_btn = gr.Button("Refresh Log", variant="primary")
                refresh_log_btn.click(
                    fn=load_render_log,
                    outputs=[log_output]
                )

    return demo
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces,
    # port 3223, with a public share link.
    try:
        demo = create_gradio_interface()
        # launch() normally blocks until the server stops when run as a
        # script, so a message logged after it would only show up at
        # shutdown; announce the start-up beforehand instead.
        logger.info("Launching Gradio on 0.0.0.0:3223 with public sharing enabled")
        demo.launch(server_name="0.0.0.0", server_port=3223, share=True)
    except Exception as e:
        # Keep the traceback so start-up failures can be diagnosed.
        logger.exception(f"Failed to launch Gradio: {str(e)}")