Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -1015,11 +1015,25 @@ def load_audio_retriever():
|
|
| 1015 |
|
| 1016 |
@st.cache_resource
|
| 1017 |
def get_inference_client():
|
|
|
|
| 1018 |
from huggingface_hub import InferenceClient
|
| 1019 |
token = os.environ.get("HF_TOKEN")
|
| 1020 |
return InferenceClient(token=token)
|
| 1021 |
|
| 1022 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1023 |
# ---------------------------------------------------------------------------
|
| 1024 |
# Translation (German <-> English)
|
| 1025 |
# ---------------------------------------------------------------------------
|
|
@@ -1316,70 +1330,80 @@ def gen_text(prompt: str, mode: str) -> dict:
|
|
| 1316 |
|
| 1317 |
|
| 1318 |
def generate_image(prompt: str) -> dict:
|
| 1319 |
-
"""Generate image via HF Inference API
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
|
| 1334 |
-
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
|
| 1346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1347 |
|
| 1348 |
|
| 1349 |
def generate_audio(prompt: str) -> dict:
|
| 1350 |
-
"""Generate audio via HF Inference API
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1351 |
client = get_inference_client()
|
| 1352 |
-
credits_depleted = False
|
| 1353 |
for model_id in AUDIO_GEN_MODELS:
|
| 1354 |
-
if credits_depleted and model_id == "cvssp/audioldm2":
|
| 1355 |
-
logger.info("Skipping paid audio model (credits depleted)")
|
| 1356 |
-
continue
|
| 1357 |
try:
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
| 1361 |
-
|
| 1362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1363 |
tmp.flush()
|
| 1364 |
-
|
| 1365 |
-
|
| 1366 |
-
|
| 1367 |
-
|
| 1368 |
-
|
| 1369 |
-
|
| 1370 |
-
"model": model_name, "failed": False,
|
| 1371 |
-
}
|
| 1372 |
except Exception as e:
|
| 1373 |
-
|
| 1374 |
-
|
| 1375 |
-
|
| 1376 |
-
|
| 1377 |
-
logger.warning("Audio gen with %s failed: %s", model_id, e)
|
| 1378 |
-
continue
|
| 1379 |
-
logger.warning("All audio generation models failed — falling back to retrieval")
|
| 1380 |
result = retrieve_audio(prompt)
|
| 1381 |
-
|
| 1382 |
-
result["credit_error"] = True
|
| 1383 |
return result
|
| 1384 |
|
| 1385 |
|
|
@@ -1475,7 +1499,7 @@ def main():
|
|
| 1475 |
L["backend"],
|
| 1476 |
["generative", "retrieval"],
|
| 1477 |
format_func=lambda x: {
|
| 1478 |
-
"generative": "Generative (FLUX
|
| 1479 |
"retrieval": "Retrieval (CLIP + CLAP index)",
|
| 1480 |
}[x],
|
| 1481 |
)
|
|
@@ -1520,8 +1544,8 @@ def main():
|
|
| 1520 |
"extended_prompt": "Single LLM call with 3x token budget",
|
| 1521 |
}
|
| 1522 |
if backend == "generative":
|
| 1523 |
-
img_info = "FLUX.1-schnell
|
| 1524 |
-
aud_info = "
|
| 1525 |
else:
|
| 1526 |
img_info = "CLIP retrieval (57 images)"
|
| 1527 |
aud_info = "CLAP retrieval (104 clips)"
|
|
@@ -1879,13 +1903,25 @@ def show_results(R: dict):
|
|
| 1879 |
backend = ai.get("backend", "unknown")
|
| 1880 |
|
| 1881 |
if backend == "retrieval" and R.get("backend") == "generative":
|
| 1882 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1883 |
st.markdown(
|
| 1884 |
f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
|
| 1885 |
f'using retrieval fallback.</div>',
|
| 1886 |
unsafe_allow_html=True)
|
| 1887 |
else:
|
| 1888 |
-
sim = ai.get("similarity", 0)
|
| 1889 |
st.markdown(
|
| 1890 |
f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
|
| 1891 |
f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
|
|
|
|
| 1015 |
|
| 1016 |
@st.cache_resource
def get_inference_client():
    """Default client for text generation (auto-routes to available providers)."""
    from huggingface_hub import InferenceClient

    # No explicit provider: let huggingface_hub pick whichever provider is up.
    return InferenceClient(token=os.environ.get("HF_TOKEN"))
|
| 1022 |
|
| 1023 |
|
| 1024 |
+
@st.cache_resource
def get_inference_client_free():
    """Free serverless client for image generation (hf-inference provider).

    Without explicit provider='hf-inference', the client auto-routes to paid
    Inference Providers (nscale, fal-ai, etc.) which return 402 when credits
    are depleted. FLUX.1-schnell is available for free on hf-inference.
    """
    from huggingface_hub import InferenceClient

    hf_token = os.environ.get("HF_TOKEN")
    # Pin the free serverless backend so we never get routed to a paid provider.
    return InferenceClient(token=hf_token, provider="hf-inference")
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
# ---------------------------------------------------------------------------
|
| 1038 |
# Translation (German <-> English)
|
| 1039 |
# ---------------------------------------------------------------------------
|
|
|
|
| 1330 |
|
| 1331 |
|
| 1332 |
def generate_image(prompt: str) -> dict:
    """Generate image via HF Inference API. Uses free serverless endpoint first.

    Strategy:
    1. Try FLUX.1-schnell via free hf-inference provider (no credits needed)
    2. Try SDXL via default auto-routed provider (may need credits)
    3. Fall back to CLIP retrieval

    Args:
        prompt: Text prompt for image generation.

    Returns:
        dict with "path", "backend", "model", "failed" keys on success,
        or the retrieve_image() result when both generation attempts fail.
    """
    # (client factory, model id, display name) — ordered cheapest first.
    # Both attempts share identical save/return logic, so drive them from data
    # instead of two copy-pasted try blocks.
    attempts = [
        (get_inference_client_free,
         "black-forest-labs/FLUX.1-schnell", "FLUX.1-schnell"),
        (get_inference_client,
         "stabilityai/stable-diffusion-xl-base-1.0", "SDXL"),
    ]
    for make_client, model_id, display_name in attempts:
        try:
            image = make_client().text_to_image(prompt, model=model_id)
            tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
            tmp.close()  # close the fd immediately; PIL writes by path below
            image.save(tmp.name)
            return {
                "path": tmp.name, "backend": "generative",
                "model": display_name, "failed": False,
            }
        except Exception as e:
            logger.warning("Image gen with %s failed: %s", display_name, e)

    # --- Fallback: CLIP retrieval ---
    logger.info("All image gen failed — using CLIP retrieval")
    return retrieve_image(prompt)
|
| 1369 |
|
| 1370 |
|
| 1371 |
def generate_audio(prompt: str) -> dict:
    """Generate audio via HF Inference API. Falls back to retrieval.

    Note: The free HF serverless endpoint (hf-inference) does NOT support
    the 'text-to-audio' task (MusicGen, AudioLDM, etc.). Audio generation
    requires paid Inference Providers. We attempt the call and gracefully
    fall back to CLAP-based retrieval when it fails.

    Args:
        prompt: Text prompt for audio generation.

    Returns:
        dict describing the generated clip on success, or the
        retrieve_audio() result with "generation_unavailable" set to True.
    """
    # Use requests directly — huggingface_hub InferenceClient has no
    # text_to_audio method for the hf-inference provider.
    import requests as _requests

    # Auth setup is loop-invariant: build it once, not per model attempt.
    _token = os.environ.get("HF_TOKEN", "")
    _headers = {"Authorization": f"Bearer {_token}"} if _token else {}

    for model_id in AUDIO_GEN_MODELS:
        try:
            _url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
            resp = _requests.post(_url, headers=_headers, json={"inputs": prompt}, timeout=120)
            # >100 bytes guards against tiny error payloads delivered with 200.
            if resp.status_code == 200 and len(resp.content) > 100:
                suffix = ".flac" if "musicgen" in model_id else ".wav"
                tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False, dir="/tmp")
                tmp.write(resp.content)
                tmp.flush()
                tmp.close()  # release the fd; callers only need the path
                return {
                    "path": tmp.name, "backend": "generative",
                    "model": model_id.split("/")[-1], "failed": False,
                }
            logger.warning("Audio model %s returned %s", model_id, resp.status_code)
        except Exception as e:
            logger.warning("Audio gen with %s failed: %s", model_id, e)
            continue

    # All generation attempts failed — use CLAP retrieval.
    logger.info("Audio generation unavailable on free tier — using CLAP retrieval")
    result = retrieve_audio(prompt)
    result["generation_unavailable"] = True
    return result
|
| 1408 |
|
| 1409 |
|
|
|
|
| 1499 |
L["backend"],
|
| 1500 |
["generative", "retrieval"],
|
| 1501 |
format_func=lambda x: {
|
| 1502 |
+
"generative": "Generative (FLUX + CLAP retrieval)",
|
| 1503 |
"retrieval": "Retrieval (CLIP + CLAP index)",
|
| 1504 |
}[x],
|
| 1505 |
)
|
|
|
|
| 1544 |
"extended_prompt": "Single LLM call with 3x token budget",
|
| 1545 |
}
|
| 1546 |
if backend == "generative":
|
| 1547 |
+
img_info = "FLUX.1-schnell (free) via HF API"
|
| 1548 |
+
aud_info = "CLAP retrieval (audio gen not on free tier)"
|
| 1549 |
else:
|
| 1550 |
img_info = "CLIP retrieval (57 images)"
|
| 1551 |
aud_info = "CLAP retrieval (104 clips)"
|
|
|
|
| 1903 |
backend = ai.get("backend", "unknown")
|
| 1904 |
|
| 1905 |
if backend == "retrieval" and R.get("backend") == "generative":
|
| 1906 |
+
sim = ai.get("similarity", 0)
|
| 1907 |
+
if ai.get("generation_unavailable"):
|
| 1908 |
+
if kid_mode:
|
| 1909 |
+
msg = ("Soundo hat ein passendes Lied aus seiner Sammlung geholt!"
|
| 1910 |
+
if lang == "de" else
|
| 1911 |
+
"Soundo picked a matching sound from the library!")
|
| 1912 |
+
st.markdown(f'<div class="{warn_cls}">{msg}</div>',
|
| 1913 |
+
unsafe_allow_html=True)
|
| 1914 |
+
else:
|
| 1915 |
+
st.markdown(
|
| 1916 |
+
f'<div class="{warn_cls}">Audio generation not available on free tier '
|
| 1917 |
+
f'\u2014 using CLAP retrieval (sim={sim:.3f}).</div>',
|
| 1918 |
+
unsafe_allow_html=True)
|
| 1919 |
+
elif ai.get("credit_error"):
|
| 1920 |
st.markdown(
|
| 1921 |
f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
|
| 1922 |
f'using retrieval fallback.</div>',
|
| 1923 |
unsafe_allow_html=True)
|
| 1924 |
else:
|
|
|
|
| 1925 |
st.markdown(
|
| 1926 |
f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
|
| 1927 |
f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
|