# Spaces: Sleeping  (page-status text captured together with the file; not part of the program)
import asyncio
import base64
import json
import os
import time
import wave
from datetime import datetime
from io import BytesIO

import numpy as np
from flask import Flask, request, Response, jsonify, stream_with_context, send_file
# Flask application object; it is started by the ``__main__`` guard at the
# end of this file.
app = Flask(import_name=__name__)
# Mock GoogleGenAI class
class GoogleGenAI:
    """Stand-in client object used instead of a real GenAI client.

    Keeps the credentials found in ``config`` and exposes a ``live``
    attribute holding a ``MockLiveMusic`` instance.
    """

    def __init__(self, config):
        """Store ``config['apiKey']`` and ``config['apiVersion']``.

        Raises:
            KeyError: if either key is missing from ``config``.
        """
        key = config['apiKey']
        version = config['apiVersion']
        self.api_key = key
        self.api_version = version
        self.live = MockLiveMusic()
class MockLiveMusic:
    """Container object whose ``music`` attribute is a ``MockMusic``."""

    def __init__(self):
        """Create the nested ``MockMusic`` endpoint object."""
        endpoint = MockMusic()
        self.music = endpoint
class MockMusic:
    """Mock music endpoint that hands out ``MockLiveMusicSession`` objects."""

    async def connect(self, config):
        """Return a new mock session for the model named in ``config['model']``.

        Raises:
            KeyError: if ``config`` has no ``'model'`` entry.
        """
        model_name = config['model']
        return MockLiveMusicSession(model_name)
class MockLiveMusicSession:
    """In-process mock of a live music session.

    State is limited to a playing flag, a setup flag and an optional
    ``callbacks`` dict whose ``'onmessage'`` / ``'onclose'`` entries are
    invoked by :meth:`play` and :meth:`close`.
    """

    def __init__(self, model):
        """Remember ``model``; start with no callbacks and not playing."""
        self.model = model
        self.callbacks = None
        self.is_playing = False
        self.setup_complete = False

    def _fire(self, event, *payload):
        """Invoke the callback registered under ``event``, if there is one."""
        handler = (self.callbacks or {}).get(event)
        if handler:
            handler(*payload)

    async def setWeightedPrompts(self, params):
        """Log the prompts found under ``params['weightedPrompts']``."""
        prompts = params['weightedPrompts']
        print(f"Setting prompts: {prompts}")

    async def setMusicGenerationConfig(self, params):
        """Log the settings found under ``params['musicGenerationConfig']``."""
        settings = params['musicGenerationConfig']
        print(f"Setting config: {settings}")

    def play(self):
        """Mark the session as playing and send a setup-complete message."""
        self.is_playing = True
        print("Starting music generation")
        self._fire('onmessage', {'setupComplete': True})

    def close(self):
        """Mark the session as stopped and notify the ``'onclose'`` callback."""
        self.is_playing = False
        self._fire('onclose')
# Initialize AI client
# The key is read from the GEMINI_API_KEY environment variable; a placeholder
# string is used when the variable is unset.
ai = GoogleGenAI({
    'apiVersion': 'v1alpha',
    'apiKey': os.environ.get('GEMINI_API_KEY', 'PLACEHOLDER_API_KEY'),
})
# Model name handed to MockLiveMusicSession by the request handlers.
model = 'lyria-realtime-exp'
# Output PCM format shared by the synthesizer and the WAV writer.
sample_rate = 48000
channels = 2
bits_per_sample = 16

# Genre-specific parameters
# base_freq / mod_freq feed the oscillator in generate_audio_chunk and
# amplitude scales its output; unknown genres fall back to "Synthwave".
GENRE_PARAMS = {
    "Synthwave": {"base_freq": 220, "mod_freq": 2, "amplitude": 0.7},
    "Dreamwave": {"base_freq": 110, "mod_freq": 0.5, "amplitude": 0.5},
    "Chillsynth": {"base_freq": 165, "mod_freq": 1, "amplitude": 0.6},
    "Lovewave": {"base_freq": 130, "mod_freq": 0.8, "amplitude": 0.4},
    "slowed": {"base_freq": 55, "mod_freq": 0.2, "amplitude": 0.3},
}


def generate_audio_chunk(prompts, config, total_duration):
    """Synthesize one chunk of interleaved 16-bit PCM audio.

    The oscillator parameters are the weight-averaged ``GENRE_PARAMS`` of
    the prompts. Three sine layers with small random detuning are summed,
    normalised by the layer count, duplicated onto every channel and
    returned as little-endian signed 16-bit samples.

    Args:
        prompts: Non-empty sequence of dicts with a ``'text'`` genre name
            and a numeric ``'weight'``. The weights must not sum to zero.
        config: Dict; ``config['slowed_factor']`` (default ``1.0``) scales
            the chunk length (``5 * slowed_factor`` seconds). Values below
            1 also halve the amplitude.
        total_duration: Seconds generated so far. Accepted for interface
            compatibility; the synthesis does not use it.

    Returns:
        ``bytes`` holding ``int(sample_rate * 5 * slowed_factor)`` frames of
        ``channels`` interleaved samples each.

    Raises:
        ValueError: if ``prompts`` is empty, the weights sum to zero, or
            ``slowed_factor`` is not a positive finite number.
    """
    slowed_factor = config.get('slowed_factor', 1.0)
    # A non-positive factor would yield empty chunks and never advance the
    # callers' duration counters, so reject it here (NaN fails the test too).
    if not 0 < slowed_factor < float('inf'):
        raise ValueError("slowed_factor must be a positive, finite number")
    if not prompts:
        raise ValueError("at least one prompt is required")
    total_weight = sum(p['weight'] for p in prompts)
    if total_weight == 0:
        raise ValueError("prompt weights must not sum to zero")

    # Weighted average of genre parameters
    fallback = GENRE_PARAMS["Synthwave"]
    weighted = [
        (p['weight'], GENRE_PARAMS.get(p['text'], fallback)) for p in prompts
    ]

    def blend(key):
        return sum(w * params[key] for w, params in weighted) / total_weight

    base_freq = blend('base_freq')
    mod_freq = blend('mod_freq')
    amplitude = blend('amplitude')
    amplitude *= 0.5 if slowed_factor < 1 else 1.0  # Reduce for slowed effect

    chunk_duration = 5 * slowed_factor  # 5 seconds per chunk
    frames = int(sample_rate * chunk_duration)
    # Time axis in seconds. It must not be divided by sample_rate again
    # inside the sine, otherwise the tone drops far below the audible range.
    t = np.arange(frames, dtype=np.float64) / sample_rate
    wobble = mod_freq * np.sin(2 * np.pi * 0.1 * t)

    # Generate layered audio with 3 frequencies
    layers = 3
    mono = np.zeros(frames, dtype=np.float64)
    for _ in range(layers):
        freq_offset = np.random.uniform(-10, 10)
        mono += np.sin(2 * np.pi * (base_freq + freq_offset + wobble) * t)
    # Normalise so the summed layers peak at `amplitude` instead of clipping.
    mono *= amplitude / layers

    # Repeat each sample once per channel -> frame-interleaved layout
    # (L, R, L, R, ...), which is what PCM/WAV consumers expect.
    interleaved = np.repeat(mono, channels)
    pcm = np.clip(interleaved * 32768, -32768, 32767).astype('<i2')
    return pcm.tobytes()
def pcm_to_wav_buffer(pcm_data, sample_rate=48000, channels=2, bits_per_sample=16):
    """Wrap raw PCM bytes in a WAV container held in memory.

    Args:
        pcm_data: Raw sample bytes; must be non-empty.
        sample_rate: Frame rate written to the WAV header.
        channels: Channel count written to the WAV header.
        bits_per_sample: Sample size in bits; the header stores
            ``bits_per_sample // 8`` bytes per sample.

    Returns:
        A ``BytesIO`` positioned at offset 0 containing the WAV file.

    Raises:
        ValueError: if ``pcm_data`` is empty.
        Exception: any error raised while writing is logged with ``print``
            and re-raised unchanged.
    """
    if not pcm_data:
        raise ValueError("PCM data is empty")
    try:
        out = BytesIO()
        # Closing the writer finalises the header; it leaves `out` open
        # because the stream was supplied by the caller, not opened by wave.
        with wave.open(out, 'wb') as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bits_per_sample // 8)
            writer.setframerate(sample_rate)
            writer.writeframes(pcm_data)
        out.seek(0)
        return out
    except Exception as e:
        print(f"Error creating WAV buffer: {e}")
        raise
# NOTE(review): no @app.route decorator is visible for this view in this part
# of the file -- confirm it is registered with the Flask app somewhere.
def generate_music():
    """Stream about one minute of generated audio as JSON lines.

    Expects a JSON object body with:
        prompts: non-empty list of objects, each with a string ``text`` and
            optional numeric ``weight`` (default 1.0) and ``color``.
        config: optional object; ``slowed_factor`` (default 1.0) scales the
            length of each chunk (``5 * slowed_factor`` seconds).

    Returns:
        A streaming ``Response`` whose body is newline-delimited JSON: first
        ``{"setupComplete": true}``, then one
        ``{"serverContent": {"audioChunks": [{"data": <base64 PCM>}]}}``
        object per chunk. If generation fails after streaming has started,
        a final ``{"error": ...}`` line is sent. Invalid requests get a JSON
        error with status 400; unexpected failures before streaming starts
        get status 500.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None here
        # (and the 400 below) instead of an exception reported as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No JSON data provided'}), 400
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        prompts = data.get('prompts', [])
        if not prompts:
            return jsonify({'error': 'At least one prompt is required'}), 400
        if not isinstance(prompts, list):
            return jsonify({'error': "'prompts' must be a list"}), 400

        config = data.get('config')
        if config is None:  # key absent or explicit null
            config = {
                'temperature': 1.1,
                'topK': 40,
                'guidance': 4.0,
                'slowed_factor': 1.0,
            }
        if not isinstance(config, dict):
            return jsonify({'error': "'config' must be an object"}), 400

        target_duration = 60  # 1 minute
        slowed_factor = config.get('slowed_factor', 1.0)
        # Zero/negative factors would never advance the loop below; huge
        # ones would allocate an enormous chunk. One chunk may be at most
        # as long as the whole clip.
        if (
            isinstance(slowed_factor, bool)
            or not isinstance(slowed_factor, (int, float))
            or not 0 < 5 * slowed_factor <= target_duration
        ):
            return jsonify({
                'error': "'slowed_factor' must be a number in the range "
                         f"(0, {target_duration / 5}]"
            }), 400
        chunk_seconds = 5 * slowed_factor

        weighted_prompts = []
        for i, prompt in enumerate(prompts):
            if not isinstance(prompt, dict) or not isinstance(prompt.get('text'), str):
                return jsonify({
                    'error': f"Prompt {i} must be an object with a string 'text'"
                }), 400
            weight = prompt.get('weight', 1.0)
            if isinstance(weight, bool) or not isinstance(weight, (int, float)):
                return jsonify({'error': f"Prompt {i} has a non-numeric 'weight'"}), 400
            weighted_prompts.append({
                'promptId': f"prompt-{i}",
                'text': prompt['text'],
                'weight': weight,
                'color': prompt.get('color', '#9900ff'),
            })
        if sum(p['weight'] for p in weighted_prompts) == 0:
            return jsonify({'error': 'Prompt weights must not sum to zero'}), 400

        session = MockLiveMusicSession(model)
        session.callbacks = {
            'onmessage': lambda msg: None,
            'onerror': lambda e: print(f"Error: {e}"),
            'onclose': lambda: print("Session closed")
        }

        def generate_stream():
            total_duration = 0
            session.setup_complete = True
            try:
                yield json.dumps({'setupComplete': True}) + '\n'
                while total_duration < target_duration and session.is_playing:
                    chunk_data = generate_audio_chunk(weighted_prompts, config, total_duration)
                    encoded_chunk = base64.b64encode(chunk_data).decode('utf-8')
                    message = {
                        'serverContent': {
                            'audioChunks': [{'data': encoded_chunk}]
                        }
                    }
                    yield json.dumps(message) + '\n'
                    total_duration += chunk_seconds
                    # Simulate real-time generation. A plain sleep is enough
                    # in this synchronous generator; asyncio.run() per chunk
                    # built a fresh event loop every time and fails when a
                    # loop is already running.
                    time.sleep(0.1)
            except Exception as exc:
                # Headers are already sent, so report the failure in-band.
                session.callbacks['onerror'](exc)
                yield json.dumps({'error': str(exc)}) + '\n'
            finally:
                # Also runs when the client disconnects mid-stream.
                session.close()

        session.play()
        # NOTE(review): the body is newline-delimited JSON, not SSE
        # "data:" frames; the mimetype is kept as-is for existing clients --
        # confirm what the consumer expects.
        return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')
    except Exception as e:
        return jsonify({'error': str(e)}), 500
# NOTE(review): no @app.route decorator is visible for this view in this part
# of the file -- confirm it is registered with the Flask app somewhere.
def generate_music_file():
    """Generate about one minute of audio and return it as a WAV download.

    Expects a JSON object body with:
        prompts: non-empty list of objects, each with a string ``text`` and
            optional numeric ``weight`` (default 1.0) and ``color``.
        config: optional object; ``slowed_factor`` (default 1.0) scales the
            length of each chunk (``5 * slowed_factor`` seconds).

    Returns:
        The WAV file as an attachment named
        ``generated_music_<YYYYmmdd_HHMMSS>.wav``. Invalid requests get a
        JSON error with status 400; any other failure gets status 500.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None here
        # (and the 400 below) instead of an exception reported as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No JSON data provided'}), 400
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        prompts = data.get('prompts', [])
        if not prompts:
            return jsonify({'error': 'At least one prompt is required'}), 400
        if not isinstance(prompts, list):
            return jsonify({'error': "'prompts' must be a list"}), 400

        config = data.get('config')
        if config is None:  # key absent or explicit null
            config = {
                'temperature': 1.1,
                'topK': 40,
                'guidance': 4.0,
                'slowed_factor': 1.0,
            }
        if not isinstance(config, dict):
            return jsonify({'error': "'config' must be an object"}), 400

        target_duration = 60  # 1 minute
        slowed_factor = config.get('slowed_factor', 1.0)
        # Zero/negative factors would never advance the loop below; huge
        # ones would allocate an enormous chunk. One chunk may be at most
        # as long as the whole clip.
        if (
            isinstance(slowed_factor, bool)
            or not isinstance(slowed_factor, (int, float))
            or not 0 < 5 * slowed_factor <= target_duration
        ):
            return jsonify({
                'error': "'slowed_factor' must be a number in the range "
                         f"(0, {target_duration / 5}]"
            }), 400
        chunk_seconds = 5 * slowed_factor

        weighted_prompts = []
        for i, prompt in enumerate(prompts):
            if not isinstance(prompt, dict) or not isinstance(prompt.get('text'), str):
                return jsonify({
                    'error': f"Prompt {i} must be an object with a string 'text'"
                }), 400
            weight = prompt.get('weight', 1.0)
            if isinstance(weight, bool) or not isinstance(weight, (int, float)):
                return jsonify({'error': f"Prompt {i} has a non-numeric 'weight'"}), 400
            weighted_prompts.append({
                'promptId': f"prompt-{i}",
                'text': prompt['text'],
                'weight': weight,
                'color': prompt.get('color', '#9900ff'),
            })
        if sum(p['weight'] for p in weighted_prompts) == 0:
            return jsonify({'error': 'Prompt weights must not sum to zero'}), 400

        # Collect all audio chunks
        total_duration = 0
        audio_chunks = []
        session = MockLiveMusicSession(model)
        session.is_playing = True
        try:
            while total_duration < target_duration and session.is_playing:
                chunk_data = generate_audio_chunk(weighted_prompts, config, total_duration)
                audio_chunks.append(chunk_data)
                total_duration += chunk_seconds
        finally:
            # Close the session even if chunk generation raises.
            session.close()

        # Combine chunks and create WAV file in memory
        pcm_data = b''.join(audio_chunks)
        if not pcm_data:
            return jsonify({'error': 'No audio data generated'}), 500
        wav_buffer = pcm_to_wav_buffer(pcm_data, sample_rate, channels, bits_per_sample)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"generated_music_{timestamp}.wav"
        return send_file(
            wav_buffer,
            mimetype='audio/wav',
            as_attachment=True,
            download_name=filename
        )
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Start Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')