sam12345324 commited on
Commit
278831f
·
verified ·
1 Parent(s): 3bc15d1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import base64
5
+ import numpy as np
6
+ from flask import Flask, request, send_file, jsonify
7
+ from io import BytesIO
8
+ from utils import decode, pcm_to_wav
9
+ from datetime import datetime
10
+
11
+ app = Flask(__name__)
12
+
13
+ # Mock GoogleGenAI class (replace with actual SDK if available)
14
+ class GoogleGenAI:
15
+ def __init__(self, config):
16
+ self.api_key = config['apiKey']
17
+ self.api_version = config['apiVersion']
18
+ self.live = MockLiveMusic()
19
+
20
+ class MockLiveMusic:
21
+ def __init__(self):
22
+ self.music = MockMusic()
23
+
24
+ class MockMusic:
25
+ async def connect(self, config):
26
+ return MockLiveMusicSession(config['model'])
27
+
28
+ class MockLiveMusicSession:
29
+ def __init__(self, model):
30
+ self.model = model
31
+ self.callbacks = None
32
+ self.is_playing = False
33
+
34
+ async def setWeightedPrompts(self, params):
35
+ print(f"Setting prompts: {params['weightedPrompts']}")
36
+
37
+ async def setMusicGenerationConfig(self, params):
38
+ print(f"Setting config: {params['musicGenerationConfig']}")
39
+
40
+ def play(self):
41
+ self.is_playing = True
42
+ print("Starting music generation")
43
+
44
+ def close(self):
45
+ self.is_playing = False
46
+ if self.callbacks and self.callbacks.get('onclose'):
47
+ self.callbacks['onclose']()
48
+
49
+ # Initialize AI client
50
+ ai = GoogleGenAI({
51
+ 'apiKey': os.getenv('GEMINI_API_KEY', 'PLACEHOLDER_API_KEY'),
52
+ 'apiVersion': 'v1alpha'
53
+ })
54
+ model = 'lyria-realtime-exp'
55
+ sample_rate = 48000
56
+
57
+ @app.route('/generate', methods=['POST'])
58
+ async def generate_music():
59
+ try:
60
+ data = request.get_json()
61
+ if not data:
62
+ return jsonify({'error': 'No JSON data provided'}), 400
63
+
64
+ prompts = data.get('prompts', [])
65
+ config = data.get('config', {
66
+ 'temperature': 1.1,
67
+ 'topK': 40,
68
+ 'guidance': 4.0
69
+ })
70
+
71
+ if not prompts:
72
+ return jsonify({'error': 'At least one prompt is required'}), 400
73
+
74
+ weighted_prompts = [
75
+ {
76
+ 'promptId': f"prompt-{i}",
77
+ 'text': prompt['text'],
78
+ 'weight': prompt.get('weight', 1.0),
79
+ 'color': prompt.get('color', '#9900ff')
80
+ } for i, prompt in enumerate(prompts)
81
+ ]
82
+
83
+ collected_chunks = []
84
+ total_duration = 0
85
+ target_duration = 600 # 10 minutes in seconds
86
+
87
+ async def mock_generate_chunks():
88
+ nonlocal collected_chunks, total_duration
89
+ chunk_duration = 2 # seconds
90
+ samples_per_chunk = int(sample_rate * chunk_duration * 2) # stereo, 16-bit
91
+ while total_duration < target_duration:
92
+ chunk = np.random.randint(-32768, 32768, samples_per_chunk, dtype=np.int16)
93
+ encoded_chunk = base64.b64encode(chunk.tobytes()).decode('utf-8')
94
+ decoded_chunk = decode(encoded_chunk)
95
+ collected_chunks.append(decoded_chunk)
96
+ total_duration += chunk_duration
97
+ yield {'serverContent': {'audioChunks': [{'data': encoded_chunk}]}}
98
+ await asyncio.sleep(0.1) # Simulate network delay
99
+ yield {'close': True}
100
+
101
+ session = await ai.live.music.connect({
102
+ 'model': model,
103
+ 'callbacks': {
104
+ 'onmessage': lambda e: None,
105
+ 'onerror': lambda e: print(f"Error: {e}"),
106
+ 'onclose': lambda: print("Session closed")
107
+ }
108
+ })
109
+
110
+ await session.setWeightedPrompts({'weightedPrompts': weighted_prompts})
111
+ await session.setMusicGenerationConfig({'musicGenerationConfig': config})
112
+ session.play()
113
+
114
+ async for message in mock_generate_chunks():
115
+ if message.get('close'):
116
+ break
117
+ if message.get('serverContent', {}).get('audioChunks'):
118
+ chunk_data = message['serverContent']['audioChunks'][0]['data']
119
+ decoded_chunk = decode(chunk_data)
120
+ collected_chunks.append(decoded_chunk)
121
+ chunk_duration = len(decoded_chunk) / (sample_rate * 2 * 2)
122
+ total_duration += chunk_duration
123
+
124
+ total_length = sum(len(chunk) for chunk in collected_chunks)
125
+ combined_pcm = np.concatenate([np.frombuffer(chunk, dtype=np.int16) for chunk in collected_chunks])
126
+ combined_pcm_bytes = combined_pcm.tobytes()
127
+ wav_blob = pcm_to_wav(combined_pcm_bytes, 2, sample_rate, 16)
128
+
129
+ output = BytesIO(wav_blob)
130
+ timestamp = datetime.now().isoformat().replace(':', '-')
131
+ filename = f"prompt-dj-music-10min-{timestamp}.wav"
132
+
133
+ return send_file(
134
+ output,
135
+ mimetype='audio/wav',
136
+ as_attachment=True,
137
+ download_name=filename
138
+ )
139
+
140
+ except Exception as e:
141
+ return jsonify({'error': str(e)}), 500
142
+
143
+ if __name__ == '__main__':
144
+ app.run(host='0.0.0.0', port=5000)