sam12345324 committed on
Commit
9b1cdf1
·
verified ·
1 Parent(s): ae5f114

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -69
app.py CHANGED
@@ -3,7 +3,7 @@ import json
3
  import asyncio
4
  import base64
5
  import numpy as np
6
- from flask import Flask, request, send_file, jsonify
7
  from io import BytesIO
8
  from utils import decode, pcm_to_wav
9
  from datetime import datetime
@@ -30,6 +30,7 @@ class MockLiveMusicSession:
30
  self.model = model
31
  self.callbacks = None
32
  self.is_playing = False
 
33
 
34
  async def setWeightedPrompts(self, params):
35
  print(f"Setting prompts: {params['weightedPrompts']}")
@@ -40,6 +41,8 @@ class MockLiveMusicSession:
40
  def play(self):
41
  self.is_playing = True
42
  print("Starting music generation")
 
 
43
 
44
  def close(self):
45
  self.is_playing = False
@@ -53,6 +56,8 @@ ai = GoogleGenAI({
53
  })
54
  model = 'lyria-realtime-exp'
55
  sample_rate = 48000
 
 
56
 
57
  # Genre-specific parameters
58
  GENRE_PARAMS = {
@@ -63,8 +68,30 @@ GENRE_PARAMS = {
63
  "slowed": {"base_freq": 55, "mod_freq": 0.2, "amplitude": 0.3}
64
  }
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  @app.route('/generate', methods=['POST'])
67
- async def generate_music():
68
  try:
69
  data = request.get_json()
70
  if not data:
@@ -90,75 +117,36 @@ async def generate_music():
90
  } for i, prompt in enumerate(prompts)
91
  ]
92
 
93
- collected_chunks = []
94
- total_duration = 0
95
- target_duration = 60 # 1 minute
96
-
97
- async def mock_generate_chunks():
98
- nonlocal collected_chunks, total_duration
99
- slowed_factor = config.get('slowed_factor', 1.0)
100
- chunk_duration = 5 * slowed_factor # Increased to 5 seconds per chunk
101
- samples_per_chunk = int(sample_rate * chunk_duration * 2) # Stereo
102
- t = np.linspace(0, chunk_duration, samples_per_chunk, False)
103
- np.random.seed(None) # Ensure randomness per run
104
- while total_duration < target_duration:
105
- total_weight = sum(p['weight'] for p in weighted_prompts)
106
- base_freq = sum(p['weight'] * GENRE_PARAMS.get(p['text'], GENRE_PARAMS["Synthwave"])['base_freq'] for p in weighted_prompts) / total_weight
107
- mod_freq = sum(p['weight'] * GENRE_PARAMS.get(p['text'], GENRE_PARAMS["Synthwave"])['mod_freq'] for p in weighted_prompts) / total_weight
108
- amplitude = sum(p['weight'] * GENRE_PARAMS.get(p['text'], GENRE_PARAMS["Synthwave"])['amplitude'] for p in weighted_prompts) / total_weight
109
- amplitude *= 0.5 if slowed_factor < 1 else 1.0 # Reduce for slowed effect
110
- # Layer multiple frequencies
111
- chunk = np.zeros(samples_per_chunk, dtype=np.int16)
112
- for _ in range(3): # Layer 3 sine waves
113
- freq_offset = np.random.uniform(-10, 10) # Slight random variation
114
- chunk += (amplitude * 32767 * np.sin(2 * np.pi * (base_freq + freq_offset + mod_freq * np.sin(2 * np.pi * 0.1 * t)) * t / sample_rate)).astype(np.int16)
115
- chunk = np.clip(chunk, -32768, 32767) # Prevent overflow
116
- encoded_chunk = base64.b64encode(chunk.tobytes()).decode('utf-8')
117
- decoded_chunk = decode(encoded_chunk)
118
- collected_chunks.append(decoded_chunk)
119
- total_duration += chunk_duration
120
- yield {'serverContent': {'audioChunks': [{'data': encoded_chunk}]}}
121
- await asyncio.sleep(0.1 * slowed_factor)
122
- yield {'close': True}
123
-
124
- session = await ai.live.music.connect({
125
- 'model': model,
126
- 'callbacks': {
127
- 'onmessage': lambda e: None,
128
- 'onerror': lambda e: print(f"Error: {e}"),
129
- 'onclose': lambda: print("Session closed")
130
- }
131
- })
132
 
133
- await session.setWeightedPrompts({'weightedPrompts': weighted_prompts})
134
- await session.setMusicGenerationConfig({'musicGenerationConfig': config})
135
  session.play()
136
-
137
- async for message in mock_generate_chunks():
138
- if message.get('close'):
139
- break
140
- if message.get('serverContent', {}).get('audioChunks'):
141
- chunk_data = message['serverContent']['audioChunks'][0]['data']
142
- decoded_chunk = decode(chunk_data)
143
- collected_chunks.append(decoded_chunk)
144
- chunk_duration = len(decoded_chunk) / (sample_rate * 2 * 2)
145
- total_duration += chunk_duration
146
-
147
- total_length = sum(len(chunk) for chunk in collected_chunks)
148
- combined_pcm = np.concatenate([np.frombuffer(chunk, dtype=np.int16) for chunk in collected_chunks])
149
- combined_pcm_bytes = combined_pcm.tobytes()
150
- wav_blob = pcm_to_wav(combined_pcm_bytes, 2, sample_rate, 16)
151
-
152
- output = BytesIO(wav_blob)
153
- timestamp = datetime.now().isoformat().replace(':', '-')
154
- filename = f"prompt-dj-music-1min-{timestamp}.wav"
155
-
156
- return send_file(
157
- output,
158
- mimetype='audio/wav',
159
- as_attachment=True,
160
- download_name=filename
161
- )
162
 
163
  except Exception as e:
164
  return jsonify({'error': str(e)}), 500
 
3
  import asyncio
4
  import base64
5
  import numpy as np
6
+ from flask import Flask, request, Response, jsonify, stream_with_context
7
  from io import BytesIO
8
  from utils import decode, pcm_to_wav
9
  from datetime import datetime
 
30
  self.model = model
31
  self.callbacks = None
32
  self.is_playing = False
33
+ self.setup_complete = False
34
 
35
  async def setWeightedPrompts(self, params):
36
  print(f"Setting prompts: {params['weightedPrompts']}")
 
41
  def play(self):
42
  self.is_playing = True
43
  print("Starting music generation")
44
+ if self.callbacks and self.callbacks.get('onmessage'):
45
+ self.callbacks['onmessage']({'setupComplete': True})
46
 
47
  def close(self):
48
  self.is_playing = False
 
56
  })
57
  model = 'lyria-realtime-exp'
58
  sample_rate = 48000
59
+ channels = 2
60
+ bits_per_sample = 16
61
 
62
  # Genre-specific parameters
63
  GENRE_PARAMS = {
 
68
  "slowed": {"base_freq": 55, "mod_freq": 0.2, "amplitude": 0.3}
69
  }
70
 
71
def generate_audio_chunk(prompts, config, total_duration):
    """Synthesize one chunk of interleaved 16-bit PCM audio for the given prompts.

    The genre parameters (``base_freq``, ``mod_freq``, ``amplitude``) of every
    prompt are blended by prompt weight, and three slightly detuned,
    frequency-modulated sine layers are summed into a mono signal. The mono
    signal is then duplicated sample-by-sample across ``channels`` so the
    result is interleaved (L, R, L, R, ...) PCM.

    Args:
        prompts: Iterable of dicts with a ``'text'`` key (looked up in
            ``GENRE_PARAMS``; unknown texts fall back to ``"Synthwave"``) and
            a numeric ``'weight'`` key.
        config: Dict of generation settings. Only ``'slowed_factor'``
            (default ``1.0``) is read here; it scales the chunk length.
        total_duration: Seconds of audio generated so far. Currently unused;
            kept so existing callers do not need to change.

    Returns:
        ``bytes`` holding ``int(sample_rate * 5 * slowed_factor) * channels``
        signed 16-bit samples.

    Raises:
        ValueError: If ``slowed_factor`` is not positive, which would
            otherwise produce an empty chunk and stall any caller that loops
            until a target duration is reached.
    """
    slowed_factor = config.get('slowed_factor', 1.0)
    chunk_duration = 5 * slowed_factor  # 5 seconds per chunk, scaled by the slow-down factor
    if chunk_duration <= 0:
        raise ValueError(f"slowed_factor must be positive, got {slowed_factor!r}")

    # Work in frames (one sample per channel) so the time axis, the mono
    # buffer and the final interleaved buffer all agree on length.
    frames_per_chunk = int(sample_rate * chunk_duration)
    t = np.linspace(0, chunk_duration, frames_per_chunk, endpoint=False)

    # Weighted average of genre parameters.
    prompts = list(prompts)
    fallback = GENRE_PARAMS["Synthwave"]
    total_weight = sum(p['weight'] for p in prompts)

    def blend(key):
        # With no usable weight there is nothing to average; use the same
        # fallback genre that unknown prompt texts already map to.
        if total_weight <= 0:
            return fallback[key]
        weighted = sum(
            p['weight'] * GENRE_PARAMS.get(p['text'], fallback)[key]
            for p in prompts
        )
        return weighted / total_weight

    base_freq = blend('base_freq')
    mod_freq = blend('mod_freq')
    amplitude = blend('amplitude')
    amplitude *= 0.5 if slowed_factor < 1 else 1.0  # Reduce for slowed effect

    # Generate layered audio with 3 frequencies. ``t`` is already in seconds,
    # so the phase is 2*pi*f*t with no further division by the sample rate.
    mono = np.zeros(frames_per_chunk, dtype=np.float64)
    lfo = mod_freq * np.sin(2 * np.pi * 0.1 * t)
    for _ in range(3):
        freq_offset = np.random.uniform(-10, 10)  # Slight random detune per layer
        mono += amplitude * np.sin(2 * np.pi * (base_freq + freq_offset + lfo) * t)

    # Repeat each frame once per channel -> interleaved multi-channel PCM.
    interleaved = np.repeat(mono, channels)
    pcm = np.clip(interleaved * 32768, -32768, 32767).astype(np.int16)  # Convert to 16-bit
    return pcm.tobytes()
92
+
93
  @app.route('/generate', methods=['POST'])
94
+ def generate_music():
95
  try:
96
  data = request.get_json()
97
  if not data:
 
117
  } for i, prompt in enumerate(prompts)
118
  ]
119
 
120
+ session = MockLiveMusicSession(model)
121
+ session.callbacks = {
122
+ 'onmessage': lambda msg: None,
123
+ 'onerror': lambda e: print(f"Error: {e}"),
124
+ 'onclose': lambda: print("Session closed")
125
+ }
126
+
127
def generate_stream():
    """Yield newline-delimited JSON messages for one streamed generation run.

    The first message is ``{"setupComplete": true}``. Each following message
    wraps one base64-encoded audio chunk under
    ``serverContent.audioChunks[0].data``. Streaming stops once 60 seconds of
    audio have been produced or ``session.is_playing`` becomes false. The
    session's ``onclose`` callback, if any, is invoked when the generator
    finishes for any reason, including the consumer closing it early.
    """
    # Local import: a plain blocking sleep replaces spinning up a fresh
    # asyncio event loop per chunk, which also fails with RuntimeError when
    # a loop is already running in the current thread.
    import time

    target_duration = 60  # 1 minute
    chunk_seconds = 5 * config.get('slowed_factor', 1.0)
    total_duration = 0
    session.setup_complete = True

    try:
        yield json.dumps({'setupComplete': True}) + '\n'

        # A non-positive chunk length would never advance total_duration,
        # so the loop is skipped entirely in that case.
        while chunk_seconds > 0 and total_duration < target_duration and session.is_playing:
            chunk_data = generate_audio_chunk(weighted_prompts, config, total_duration)
            encoded_chunk = base64.b64encode(chunk_data).decode('utf-8')
            message = {
                'serverContent': {
                    'audioChunks': [{'data': encoded_chunk}]
                }
            }
            yield json.dumps(message) + '\n'
            total_duration += chunk_seconds
            time.sleep(0.1)  # Simulate real-time generation
    finally:
        # Runs on normal completion and on early close (e.g. client disconnect).
        on_close = session.callbacks.get('onclose') if session.callbacks else None
        if on_close:
            on_close()
 
 
 
 
 
 
 
 
 
 
 
 
147
 
 
 
148
  session.play()
149
+ return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  except Exception as e:
152
  return jsonify({'error': str(e)}), 500