File size: 8,395 Bytes
278831f
 
 
 
 
3a54dfa
278831f
2a35cbb
278831f
 
 
 
ce397dc
278831f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1cdf1
278831f
 
 
 
 
 
 
 
 
 
9b1cdf1
 
278831f
 
 
 
 
 
ce397dc
278831f
 
 
 
 
 
9b1cdf1
 
278831f
ce397dc
2a07b7f
8b0ea77
001655e
 
 
 
2a07b7f
 
9b1cdf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a54dfa
 
37c0a28
 
 
 
3a54dfa
 
37c0a28
 
 
 
 
 
3a54dfa
 
 
37c0a28
3a54dfa
37c0a28
2a35cbb
278831f
9b1cdf1
278831f
 
 
 
 
 
 
 
 
bc56ed5
 
278831f
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1cdf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b29a88
9b1cdf1
 
 
 
278831f
 
9b1cdf1
278831f
 
ce397dc
278831f
2a35cbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a54dfa
2a35cbb
37c0a28
 
 
3a54dfa
 
2a35cbb
3a54dfa
 
 
 
 
 
 
 
2a35cbb
 
 
 
278831f
bc56ed5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import asyncio
import base64
import json
import os
import time
import wave
from datetime import datetime
from io import BytesIO

import numpy as np
from flask import Flask, request, Response, jsonify, stream_with_context, send_file

app = Flask(__name__)

# Mock GoogleGenAI class
class GoogleGenAI:
    """Mock client object: stores the credentials and exposes a ``live`` namespace."""

    def __init__(self, config):
        """Copy ``apiKey`` and ``apiVersion`` out of *config* and attach the mock ``live`` object.

        Both keys are mandatory; a missing key raises ``KeyError``.
        """
        # Map each attribute to the config key it is read from, in lookup order.
        for attribute, config_key in (('api_key', 'apiKey'), ('api_version', 'apiVersion')):
            setattr(self, attribute, config[config_key])
        self.live = MockLiveMusic()

class MockLiveMusic:
    """Mock of the client's ``live`` namespace; its only member is the ``music`` connector."""

    def __init__(self):
        """Create the connector object and publish it as ``self.music``."""
        connector = MockMusic()
        self.music = connector

class MockMusic:
    """Mock connector that hands out music sessions."""

    async def connect(self, config):
        """Return a new ``MockLiveMusicSession`` for the model named in ``config['model']``."""
        model_name = config['model']
        return MockLiveMusicSession(model_name)

class MockLiveMusicSession:
    """In-memory stand-in for a live music session.

    The session tracks a playing flag and a setup flag, and reports events
    through an optional ``callbacks`` mapping (keys used here: ``'onmessage'``
    and ``'onclose'``).
    """

    def __init__(self, model):
        """Remember *model* and start idle, with no callbacks registered."""
        self.model = model
        self.callbacks = None
        self.is_playing = False
        self.setup_complete = False

    def _notify(self, event, *payload):
        """Call the callback registered under *event*, if the mapping holds a truthy one."""
        handler = self.callbacks.get(event) if self.callbacks else None
        if handler:
            handler(*payload)

    async def setWeightedPrompts(self, params):
        """Log the prompts found under ``params['weightedPrompts']``."""
        prompts = params['weightedPrompts']
        print(f"Setting prompts: {prompts}")

    async def setMusicGenerationConfig(self, params):
        """Log the settings found under ``params['musicGenerationConfig']``."""
        settings = params['musicGenerationConfig']
        print(f"Setting config: {settings}")

    def play(self):
        """Mark the session as playing and send a ``setupComplete`` message to ``onmessage``."""
        self.is_playing = True
        print("Starting music generation")
        self._notify('onmessage', {'setupComplete': True})

    def close(self):
        """Mark the session as stopped and invoke ``onclose``."""
        self.is_playing = False
        self._notify('onclose')

# Output audio format: read by the synthesizer and handed to the WAV writer.
sample_rate, channels, bits_per_sample = 48000, 2, 16

# Model identifier passed to every session that is created.
model = 'lyria-realtime-exp'

# Module-wide mock client. The key is taken from the GEMINI_API_KEY environment
# variable; a placeholder string is used when the variable is not set.
ai = GoogleGenAI(dict(
    apiKey=os.environ.get('GEMINI_API_KEY', 'PLACEHOLDER_API_KEY'),
    apiVersion='v1alpha',
))

# Per-genre synthesis settings, keyed by prompt text. Each row below is
# (genre, base_freq, mod_freq, amplitude): the synthesizer uses base_freq and
# mod_freq inside the sine argument and amplitude as the output scale.
GENRE_PARAMS = {
    genre: {"base_freq": base_freq, "mod_freq": mod_freq, "amplitude": amplitude}
    for genre, base_freq, mod_freq, amplitude in (
        ("Synthwave", 220, 2, 0.7),
        ("Dreamwave", 110, 0.5, 0.5),
        ("Chillsynth", 165, 1, 0.6),
        ("Lovewave", 130, 0.8, 0.4),
        ("slowed", 55, 0.2, 0.3),
    )
}

def generate_audio_chunk(prompts, config, total_duration):
    """Synthesize one chunk of interleaved 16-bit PCM audio for the given prompts.

    The chunk lasts ``5 * slowed_factor`` seconds at the module-level
    ``sample_rate`` and carries the same signal on each of the module-level
    ``channels``. The tone is a blend of three sine layers whose frequency,
    modulation and level are the weight-averaged ``GENRE_PARAMS`` of the
    prompts (unknown prompt texts fall back to the "Synthwave" entry).

    Args:
        prompts: Iterable of dicts with a ``'text'`` and a numeric ``'weight'``.
            When it is empty, or the weights do not sum to a positive value,
            the prompts are blended with equal weight (or the "Synthwave"
            defaults are used when there are no prompts at all).
        config: Mapping; only ``'slowed_factor'`` (default 1.0) is read.
        total_duration: Seconds generated so far. Accepted for interface
            compatibility; the chunk is synthesized on a chunk-local time
            axis and does not use it.

    Returns:
        ``bytes`` holding ``frames * channels`` little-endian int16 samples,
        frame-interleaved (ch0, ch1, ch0, ch1, ...).

    Raises:
        ValueError: if ``slowed_factor`` is not a positive finite number.
    """
    slowed_factor = config.get('slowed_factor', 1.0)
    if (
        isinstance(slowed_factor, bool)
        or not isinstance(slowed_factor, (int, float, np.integer, np.floating))
        or not 0 < slowed_factor < float('inf')
    ):
        # A zero-length chunk never advances the caller's duration counter,
        # and a string would be "multiplied" by repetition, so reject early.
        raise ValueError(
            f"slowed_factor must be a positive finite number, got {slowed_factor!r}"
        )

    chunk_duration = 5 * slowed_factor  # 5 seconds per chunk
    frames_per_chunk = int(sample_rate * chunk_duration)
    # Time axis in seconds, one entry per frame (not per sample).
    t = np.arange(frames_per_chunk) / sample_rate

    # Weighted average of genre parameters
    fallback = GENRE_PARAMS["Synthwave"]
    genre_settings = [GENRE_PARAMS.get(p['text'], fallback) for p in prompts]
    weights = [p['weight'] for p in prompts]
    if not genre_settings:
        genre_settings, weights = [fallback], [1.0]
    total_weight = sum(weights)
    if not total_weight > 0:
        # Avoid dividing by zero: treat every prompt as equally important.
        weights = [1.0] * len(genre_settings)
        total_weight = float(len(genre_settings))

    def blend(key):
        return sum(w * g[key] for w, g in zip(weights, genre_settings)) / total_weight

    base_freq = blend('base_freq')
    mod_freq = blend('mod_freq')
    amplitude = blend('amplitude')
    amplitude *= 0.5 if slowed_factor < 1 else 1.0  # Reduce for slowed effect

    # Generate layered audio with 3 slightly detuned frequencies. ``t`` is
    # already in seconds, so it must not be divided by the sample rate again.
    layer_count = 3
    wobble = mod_freq * np.sin(2 * np.pi * 0.1 * t)  # identical for every layer
    mono = np.zeros(frames_per_chunk, dtype=np.float64)
    for _ in range(layer_count):
        freq_offset = np.random.uniform(-10, 10)
        mono += np.sin(2 * np.pi * (base_freq + freq_offset + wobble) * t)
    # Average the layers so the peak stays at ``amplitude`` instead of clipping.
    mono *= amplitude / layer_count

    # Convert to 16-bit, then repeat every frame once per channel so the
    # samples come out interleaved the way PCM/WAV consumers expect.
    pcm = np.clip(np.round(mono * 32767), -32768, 32767).astype('<i2')
    return np.repeat(pcm, channels).tobytes()

def pcm_to_wav_buffer(pcm_data, sample_rate=48000, channels=2, bits_per_sample=16):
    """Wrap raw PCM bytes in a WAV container held in memory.

    Args:
        pcm_data: Raw PCM sample bytes; must not be empty.
        sample_rate: Frame rate written to the WAV header.
        channels: Channel count written to the WAV header.
        bits_per_sample: Sample size in bits; the header stores
            ``bits_per_sample // 8`` bytes per sample.

    Returns:
        A ``BytesIO`` positioned at offset 0 that contains the complete WAV file.

    Raises:
        ValueError: if *pcm_data* is empty.
        Exception: any error raised while writing is printed and re-raised.
    """
    if not pcm_data:
        raise ValueError("PCM data is empty")

    try:
        wav_stream = BytesIO()
        # Leaving the ``with`` block closes the writer, which finalizes the
        # header; the underlying BytesIO stays open for the caller.
        with wave.open(wav_stream, 'wb') as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bits_per_sample // 8)
            writer.setframerate(sample_rate)
            writer.writeframes(pcm_data)
        wav_stream.seek(0)
        return wav_stream
    except Exception as err:
        print(f"Error creating WAV buffer: {err}")
        raise

# Settings used when the request body carries no "config" object.
_DEFAULT_GENERATION_CONFIG = {
    'temperature': 1.1,
    'topK': 40,
    'guidance': 4.0,
    'slowed_factor': 1.0
}
_TARGET_DURATION_SECONDS = 60  # audio produced per request (1 minute)
_BASE_CHUNK_SECONDS = 5        # nominal chunk length before slowed_factor scaling
# Accepted slowed_factor range. The lower bound keeps the number of loop
# iterations finite (a factor of 0 would never advance the duration counter);
# the upper bound keeps a single chunk from exceeding the one-minute target
# and allocating an arbitrarily large buffer on behalf of a remote caller.
_MIN_SLOWED_FACTOR = 0.1
_MAX_SLOWED_FACTOR = 12.0


def _is_finite_number(value):
    """Return True for int/float values that are neither bool, NaN nor infinite."""
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and value == value  # NaN is the only value unequal to itself
        and abs(value) != float('inf')
    )


def _parse_generation_request(data):
    """Validate a decoded request body shared by both generation routes.

    Returns:
        ``(weighted_prompts, config, error)``. On success ``error`` is None;
        otherwise it is a message suitable for a 400 response and the other
        two items are None.
    """
    if not isinstance(data, dict):
        return None, None, 'Request body must be a JSON object'

    prompts = data.get('prompts', [])
    if not prompts:
        return None, None, 'At least one prompt is required'
    if not isinstance(prompts, list):
        return None, None, "'prompts' must be a list"

    weighted_prompts = []
    for i, prompt in enumerate(prompts):
        if not isinstance(prompt, dict) or not isinstance(prompt.get('text'), str):
            return None, None, f"Prompt {i} must be an object with a string 'text'"
        weight = prompt.get('weight', 1.0)
        if not _is_finite_number(weight) or weight < 0:
            return None, None, f"Prompt {i} has an invalid 'weight'"
        weighted_prompts.append({
            'promptId': f"prompt-{i}",
            'text': prompt['text'],
            'weight': weight,
            'color': prompt.get('color', '#9900ff')
        })
    # The synthesizer divides by the weight total, so it must be positive.
    if sum(p['weight'] for p in weighted_prompts) <= 0:
        return None, None, 'Prompt weights must sum to a positive value'

    config = data.get('config')
    if config is None:
        config = dict(_DEFAULT_GENERATION_CONFIG)
    elif not isinstance(config, dict):
        return None, None, "'config' must be an object"

    slowed_factor = config.get('slowed_factor', 1.0)
    if (
        not _is_finite_number(slowed_factor)
        or not _MIN_SLOWED_FACTOR <= slowed_factor <= _MAX_SLOWED_FACTOR
    ):
        return None, None, (
            f"'slowed_factor' must be a number between "
            f"{_MIN_SLOWED_FACTOR} and {_MAX_SLOWED_FACTOR}"
        )

    return weighted_prompts, config, None


@app.route('/generate', methods=['POST'])
def generate_music():
    """Stream about one minute of generated audio as newline-delimited JSON.

    Request body: ``{"prompts": [{"text": ..., "weight": ..., "color": ...}],
    "config": {...}}``. The response starts with ``{"setupComplete": true}``
    followed by one ``serverContent`` message per chunk, each holding
    base64-encoded PCM. Invalid input yields a 400 JSON error; an unexpected
    failure before streaming starts yields a 500 JSON error. A failure while
    streaming is reported in-band as a final ``{"error": ...}`` line because
    the status line has already been sent by then.
    """
    try:
        # silent=True turns a missing/invalid JSON body into None so it is
        # answered with the 400 below instead of an exception turned into 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No JSON data provided'}), 400

        weighted_prompts, config, error = _parse_generation_request(data)
        if error:
            return jsonify({'error': error}), 400

        session = MockLiveMusicSession(model)
        session.callbacks = {
            'onmessage': lambda msg: None,
            'onerror': lambda e: print(f"Error: {e}"),
            'onclose': lambda: print("Session closed")
        }
        chunk_seconds = _BASE_CHUNK_SECONDS * config.get('slowed_factor', 1.0)

        def generate_stream():
            total_duration = 0
            session.setup_complete = True
            try:
                yield json.dumps({'setupComplete': True}) + '\n'

                while total_duration < _TARGET_DURATION_SECONDS and session.is_playing:
                    chunk_data = generate_audio_chunk(weighted_prompts, config, total_duration)
                    encoded_chunk = base64.b64encode(chunk_data).decode('utf-8')
                    message = {
                        'serverContent': {
                            'audioChunks': [{'data': encoded_chunk}]
                        }
                    }
                    yield json.dumps(message) + '\n'
                    total_duration += chunk_seconds
                    # Simulate real-time generation. A plain sleep suffices in
                    # this synchronous generator; no event loop is needed.
                    time.sleep(0.1)
            except Exception as e:
                session.callbacks['onerror'](e)
                yield json.dumps({'error': str(e)}) + '\n'
            finally:
                # Also runs when the client disconnects mid-stream, so the
                # session is always stopped and 'onclose' always fires.
                session.close()

        session.play()
        # NOTE(review): the payload is newline-delimited JSON, not SSE-framed
        # ("data: ...\n\n"); the mimetype is kept for existing clients.
        return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')

    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/generate_file', methods=['POST'])
def generate_music_file():
    """Generate about one minute of audio and return it as a WAV download.

    Accepts the same request body as ``/generate``. Invalid input yields a
    400 JSON error; any other failure yields a 500 JSON error. On success the
    response is an ``audio/wav`` attachment named
    ``generated_music_<YYYYmmdd_HHMMSS>.wav``.
    """
    try:
        # silent=True: see generate_music(); bad bodies become a 400.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No JSON data provided'}), 400

        weighted_prompts, config, error = _parse_generation_request(data)
        if error:
            return jsonify({'error': error}), 400

        chunk_seconds = _BASE_CHUNK_SECONDS * config.get('slowed_factor', 1.0)

        # Collect all audio chunks
        total_duration = 0
        audio_chunks = []
        session = MockLiveMusicSession(model)
        session.is_playing = True
        try:
            while total_duration < _TARGET_DURATION_SECONDS and session.is_playing:
                audio_chunks.append(
                    generate_audio_chunk(weighted_prompts, config, total_duration)
                )
                total_duration += chunk_seconds
        finally:
            # Stop the session even when chunk generation fails.
            session.close()

        # Combine chunks and create WAV file in memory
        pcm_data = b''.join(audio_chunks)
        if not pcm_data:
            return jsonify({'error': 'No audio data generated'}), 500

        wav_buffer = pcm_to_wav_buffer(pcm_data, sample_rate, channels, bits_per_sample)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"generated_music_{timestamp}.wav"

        return send_file(
            wav_buffer,
            mimetype='audio/wav',
            as_attachment=True,
            download_name=filename
        )

    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Executed directly: serve on port 7860, bound to all interfaces.
    app.run(port=7860, host='0.0.0.0')