bharatverse11 commited on
Commit
98c1dcb
Β·
verified Β·
1 Parent(s): 3e01497

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -1
app.py CHANGED
@@ -1109,7 +1109,7 @@ async def mix_tracks(request: MixRequest):
1109
  )
1110
 
1111
 
1112
- # ── Generate ─────────────────────────────────────────────────────────────────
1113
 
1114
  @app.post("/generate", response_model=GenerateBeatResponse)
1115
  async def generate_beat_route(request: GenerateBeatRequest):
@@ -1144,6 +1144,95 @@ async def generate_beat_route(request: GenerateBeatRequest):
1144
  )
1145
 
1146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1147
  # ── Output ───────────────────────────────────────────────────────────────────
1148
 
1149
  @app.get("/output/{file_id}")
 
1109
  )
1110
 
1111
 
1112
+ # ── Generate (Procedural Synth) ──────────────────────────────────────────────
1113
 
1114
  @app.post("/generate", response_model=GenerateBeatResponse)
1115
  async def generate_beat_route(request: GenerateBeatRequest):
 
1144
  )
1145
 
1146
 
1147
+ # ── Generate AI (MusicGen via HF Inference API) ─────────────────────────────
1148
+
1149
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
1150
+ MUSICGEN_API_URL = "https://api-inference.huggingface.co/models/facebook/musicgen-small"
1151
+
1152
+ class GenerateAIRequest(BaseModel):
1153
+ prompt: str = Field(..., min_length=3, max_length=500)
1154
+ duration: int = Field(default=10, ge=3, le=30)
1155
+
1156
+ class GenerateAIResponse(BaseModel):
1157
+ output_file_id: str
1158
+ prompt: str
1159
+ duration: float
1160
+ model: str = "facebook/musicgen-small"
1161
+ sample_rate: int = 32000
1162
+ message: str = "AI beat generated successfully."
1163
+
1164
+ @app.post("/generate-ai", response_model=GenerateAIResponse)
1165
+ async def generate_beat_ai(request: GenerateAIRequest):
1166
+ """Generate a beat using Meta's MusicGen via HuggingFace Inference API (free GPU)."""
1167
+ output_id = generate_file_id()
1168
+ output_path = OUTPUT_DIR / f"{output_id}.wav"
1169
+
1170
+ headers = {}
1171
+ if HF_TOKEN:
1172
+ headers["Authorization"] = f"Bearer {HF_TOKEN}"
1173
+
1174
+ payload = {
1175
+ "inputs": request.prompt,
1176
+ "parameters": {
1177
+ "max_new_tokens": request.duration * 50, # ~50 tokens per second
1178
+ },
1179
+ }
1180
+
1181
+ try:
1182
+ print(f"MusicGen AI generating: '{request.prompt}' ({request.duration}s)")
1183
+ async with httpx.AsyncClient(timeout=180) as client:
1184
+ response = await client.post(
1185
+ MUSICGEN_API_URL,
1186
+ headers=headers,
1187
+ json=payload,
1188
+ )
1189
+
1190
+ if response.status_code == 503:
1191
+ # Model is loading
1192
+ raise HTTPException(
1193
+ status_code=503,
1194
+ detail="MusicGen model is loading, please try again in ~30 seconds."
1195
+ )
1196
+ if response.status_code != 200:
1197
+ error_msg = response.text[:200]
1198
+ raise HTTPException(
1199
+ status_code=502,
1200
+ detail=f"HF Inference API error ({response.status_code}): {error_msg}"
1201
+ )
1202
+
1203
+ # Response is raw audio bytes (FLAC format)
1204
+ audio_bytes = response.content
1205
+
1206
+ # Save the raw audio first
1207
+ temp_path = str(output_path).replace(".wav", "_raw.flac")
1208
+ with open(temp_path, "wb") as f:
1209
+ f.write(audio_bytes)
1210
+
1211
+ # Convert to WAV using librosa
1212
+ y, sr = librosa.load(temp_path, sr=None, mono=True)
1213
+ sf.write(str(output_path), y, sr, subtype="PCM_16")
1214
+
1215
+ # Clean up temp
1216
+ os.remove(temp_path)
1217
+
1218
+ actual_duration = round(len(y) / sr, 2)
1219
+ print(f"MusicGen AI complete: {actual_duration}s")
1220
+
1221
+ except HTTPException:
1222
+ raise
1223
+ except Exception as exc:
1224
+ import traceback
1225
+ traceback.print_exc()
1226
+ raise HTTPException(status_code=500, detail=f"AI generation failed: {str(exc)}") from exc
1227
+
1228
+ return GenerateAIResponse(
1229
+ output_file_id=output_id,
1230
+ prompt=request.prompt,
1231
+ duration=actual_duration,
1232
+ sample_rate=int(sr),
1233
+ )
1234
+
1235
+
1236
  # ── Output ───────────────────────────────────────────────────────────────────
1237
 
1238
  @app.get("/output/{file_id}")