pratik-250620 committed on
Commit
7d8ffc1
·
verified ·
1 Parent(s): 358d3bc

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +96 -60
app.py CHANGED
@@ -1015,11 +1015,25 @@ def load_audio_retriever():
1015
 
1016
  @st.cache_resource
1017
  def get_inference_client():
 
1018
  from huggingface_hub import InferenceClient
1019
  token = os.environ.get("HF_TOKEN")
1020
  return InferenceClient(token=token)
1021
 
1022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1023
  # ---------------------------------------------------------------------------
1024
  # Translation (German <-> English)
1025
  # ---------------------------------------------------------------------------
@@ -1316,70 +1330,80 @@ def gen_text(prompt: str, mode: str) -> dict:
1316
 
1317
 
1318
  def generate_image(prompt: str) -> dict:
1319
- """Generate image via HF Inference API, trying free models first. Falls back to retrieval."""
1320
- client = get_inference_client()
1321
- credits_depleted = False
1322
- for model_id in IMAGE_GEN_MODELS:
1323
- if credits_depleted and model_id == "stabilityai/stable-diffusion-xl-base-1.0":
1324
- logger.info("Skipping paid image model (credits depleted)")
1325
- continue
1326
- try:
1327
- image = client.text_to_image(prompt, model=model_id)
1328
- tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
1329
- image.save(tmp.name)
1330
- model_name = model_id.split("/")[-1]
1331
- return {
1332
- "path": tmp.name, "backend": "generative",
1333
- "model": model_name, "failed": False,
1334
- }
1335
- except Exception as e:
1336
- if _is_credit_error(e):
1337
- credits_depleted = True
1338
- logger.warning("Image model %s: credits depleted (402)", model_id)
1339
- else:
1340
- logger.warning("Image gen with %s failed: %s", model_id, e)
1341
- continue
1342
- logger.warning("All image generation models failed — falling back to retrieval")
1343
- result = retrieve_image(prompt)
1344
- if credits_depleted:
1345
- result["credit_error"] = True
1346
- return result
 
 
 
 
 
 
 
 
1347
 
1348
 
1349
  def generate_audio(prompt: str) -> dict:
1350
- """Generate audio via HF Inference API, trying free models first. Falls back to retrieval."""
 
 
 
 
 
 
1351
  client = get_inference_client()
1352
- credits_depleted = False
1353
  for model_id in AUDIO_GEN_MODELS:
1354
- if credits_depleted and model_id == "cvssp/audioldm2":
1355
- logger.info("Skipping paid audio model (credits depleted)")
1356
- continue
1357
  try:
1358
- audio_bytes = client.text_to_audio(prompt, model=model_id)
1359
- suffix = ".flac" if "musicgen" in model_id else ".wav"
1360
- tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False, dir="/tmp")
1361
- if isinstance(audio_bytes, bytes):
1362
- tmp.write(audio_bytes)
 
 
 
 
 
 
1363
  tmp.flush()
1364
- else:
1365
- tmp.write(bytes(audio_bytes))
1366
- tmp.flush()
1367
- model_name = model_id.split("/")[-1]
1368
- return {
1369
- "path": tmp.name, "backend": "generative",
1370
- "model": model_name, "failed": False,
1371
- }
1372
  except Exception as e:
1373
- if _is_credit_error(e):
1374
- credits_depleted = True
1375
- logger.warning("Audio model %s: credits depleted (402)", model_id)
1376
- else:
1377
- logger.warning("Audio gen with %s failed: %s", model_id, e)
1378
- continue
1379
- logger.warning("All audio generation models failed — falling back to retrieval")
1380
  result = retrieve_audio(prompt)
1381
- if credits_depleted:
1382
- result["credit_error"] = True
1383
  return result
1384
 
1385
 
@@ -1475,7 +1499,7 @@ def main():
1475
  L["backend"],
1476
  ["generative", "retrieval"],
1477
  format_func=lambda x: {
1478
- "generative": "Generative (FLUX/SDXL + MusicGen)",
1479
  "retrieval": "Retrieval (CLIP + CLAP index)",
1480
  }[x],
1481
  )
@@ -1520,8 +1544,8 @@ def main():
1520
  "extended_prompt": "Single LLM call with 3x token budget",
1521
  }
1522
  if backend == "generative":
1523
- img_info = "FLUX.1-schnell / SDXL via HF API"
1524
- aud_info = "MusicGen / AudioLDM2 via HF API"
1525
  else:
1526
  img_info = "CLIP retrieval (57 images)"
1527
  aud_info = "CLAP retrieval (104 clips)"
@@ -1879,13 +1903,25 @@ def show_results(R: dict):
1879
  backend = ai.get("backend", "unknown")
1880
 
1881
  if backend == "retrieval" and R.get("backend") == "generative":
1882
- if ai.get("credit_error"):
 
 
 
 
 
 
 
 
 
 
 
 
 
1883
  st.markdown(
1884
  f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
1885
  f'using retrieval fallback.</div>',
1886
  unsafe_allow_html=True)
1887
  else:
1888
- sim = ai.get("similarity", 0)
1889
  st.markdown(
1890
  f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
1891
  f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
 
1015
 
1016
  @st.cache_resource
1017
  def get_inference_client():
1018
+ """Default client for text generation (auto-routes to available providers)."""
1019
  from huggingface_hub import InferenceClient
1020
  token = os.environ.get("HF_TOKEN")
1021
  return InferenceClient(token=token)
1022
 
1023
 
1024
+ @st.cache_resource
1025
+ def get_inference_client_free():
1026
+ """Free serverless client for image generation (hf-inference provider).
1027
+
1028
+ Without explicit provider='hf-inference', the client auto-routes to paid
1029
+ Inference Providers (nscale, fal-ai, etc.) which return 402 when credits
1030
+ are depleted. FLUX.1-schnell is available for free on hf-inference.
1031
+ """
1032
+ from huggingface_hub import InferenceClient
1033
+ token = os.environ.get("HF_TOKEN")
1034
+ return InferenceClient(token=token, provider="hf-inference")
1035
+
1036
+
1037
  # ---------------------------------------------------------------------------
1038
  # Translation (German <-> English)
1039
  # ---------------------------------------------------------------------------
 
1330
 
1331
 
1332
  def generate_image(prompt: str) -> dict:
1333
+ """Generate image via HF Inference API. Uses free serverless endpoint first.
1334
+
1335
+ Strategy:
1336
+ 1. Try FLUX.1-schnell via free hf-inference provider (no credits needed)
1337
+ 2. Try SDXL via default auto-routed provider (may need credits)
1338
+ 3. Fall back to CLIP retrieval
1339
+ """
1340
+ # --- Attempt 1: Free serverless (FLUX.1-schnell) ---
1341
+ try:
1342
+ client_free = get_inference_client_free()
1343
+ image = client_free.text_to_image(prompt, model="black-forest-labs/FLUX.1-schnell")
1344
+ tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
1345
+ image.save(tmp.name)
1346
+ return {
1347
+ "path": tmp.name, "backend": "generative",
1348
+ "model": "FLUX.1-schnell", "failed": False,
1349
+ }
1350
+ except Exception as e:
1351
+ logger.warning("FLUX.1-schnell (free) failed: %s", e)
1352
+
1353
+ # --- Attempt 2: Auto-routed provider (may need credits) ---
1354
+ try:
1355
+ client = get_inference_client()
1356
+ image = client.text_to_image(prompt, model="stabilityai/stable-diffusion-xl-base-1.0")
1357
+ tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
1358
+ image.save(tmp.name)
1359
+ return {
1360
+ "path": tmp.name, "backend": "generative",
1361
+ "model": "SDXL", "failed": False,
1362
+ }
1363
+ except Exception as e:
1364
+ logger.warning("SDXL (auto-route) failed: %s", e)
1365
+
1366
+ # --- Fallback: CLIP retrieval ---
1367
+ logger.info("All image gen failed — using CLIP retrieval")
1368
+ return retrieve_image(prompt)
1369
 
1370
 
1371
  def generate_audio(prompt: str) -> dict:
1372
+ """Generate audio via HF Inference API. Falls back to retrieval.
1373
+
1374
+ Note: The free HF serverless endpoint (hf-inference) does NOT support
1375
+ the 'text-to-audio' task (MusicGen, AudioLDM, etc.). Audio generation
1376
+ requires paid Inference Providers. We attempt the call and gracefully
1377
+ fall back to CLAP-based retrieval when it fails.
1378
+ """
1379
  client = get_inference_client()
 
1380
  for model_id in AUDIO_GEN_MODELS:
 
 
 
1381
  try:
1382
+ # Use requests directly — huggingface_hub InferenceClient
1383
+ # has no text_to_audio method for the hf-inference provider.
1384
+ import requests as _requests
1385
+ _token = os.environ.get("HF_TOKEN", "")
1386
+ _headers = {"Authorization": f"Bearer {_token}"} if _token else {}
1387
+ _url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
1388
+ resp = _requests.post(_url, headers=_headers, json={"inputs": prompt}, timeout=120)
1389
+ if resp.status_code == 200 and len(resp.content) > 100:
1390
+ suffix = ".flac" if "musicgen" in model_id else ".wav"
1391
+ tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False, dir="/tmp")
1392
+ tmp.write(resp.content)
1393
  tmp.flush()
1394
+ model_name = model_id.split("/")[-1]
1395
+ return {
1396
+ "path": tmp.name, "backend": "generative",
1397
+ "model": model_name, "failed": False,
1398
+ }
1399
+ logger.warning("Audio model %s returned %s", model_id, resp.status_code)
 
 
1400
  except Exception as e:
1401
+ logger.warning("Audio gen with %s failed: %s", model_id, e)
1402
+ continue
1403
+ # All generation attempts failed — use CLAP retrieval
1404
+ logger.info("Audio generation unavailable on free tier — using CLAP retrieval")
 
 
 
1405
  result = retrieve_audio(prompt)
1406
+ result["generation_unavailable"] = True
 
1407
  return result
1408
 
1409
 
 
1499
  L["backend"],
1500
  ["generative", "retrieval"],
1501
  format_func=lambda x: {
1502
+ "generative": "Generative (FLUX + CLAP retrieval)",
1503
  "retrieval": "Retrieval (CLIP + CLAP index)",
1504
  }[x],
1505
  )
 
1544
  "extended_prompt": "Single LLM call with 3x token budget",
1545
  }
1546
  if backend == "generative":
1547
+ img_info = "FLUX.1-schnell (free) via HF API"
1548
+ aud_info = "CLAP retrieval (audio gen not on free tier)"
1549
  else:
1550
  img_info = "CLIP retrieval (57 images)"
1551
  aud_info = "CLAP retrieval (104 clips)"
 
1903
  backend = ai.get("backend", "unknown")
1904
 
1905
  if backend == "retrieval" and R.get("backend") == "generative":
1906
+ sim = ai.get("similarity", 0)
1907
+ if ai.get("generation_unavailable"):
1908
+ if kid_mode:
1909
+ msg = ("Soundo hat ein passendes Lied aus seiner Sammlung geholt!"
1910
+ if lang == "de" else
1911
+ "Soundo picked a matching sound from the library!")
1912
+ st.markdown(f'<div class="{warn_cls}">{msg}</div>',
1913
+ unsafe_allow_html=True)
1914
+ else:
1915
+ st.markdown(
1916
+ f'<div class="{warn_cls}">Audio generation not available on free tier '
1917
+ f'\u2014 using CLAP retrieval (sim={sim:.3f}).</div>',
1918
+ unsafe_allow_html=True)
1919
+ elif ai.get("credit_error"):
1920
  st.markdown(
1921
  f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
1922
  f'using retrieval fallback.</div>',
1923
  unsafe_allow_html=True)
1924
  else:
 
1925
  st.markdown(
1926
  f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
1927
  f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',