Upload folder using huggingface_hub
Browse files- app.py +1198 -186
- src/coherence/calibration.py +249 -0
- src/coherence/cmsci_engine.py +536 -0
- src/coherence/gram_volume.py +124 -0
- src/coherence/negative_bank.py +212 -0
- src/config/settings.py +44 -0
- src/embeddings/prob_adapter_trainer.py +256 -0
- src/embeddings/probabilistic_adapter.py +216 -0
- src/embeddings/space_alignment.py +201 -0
app.py
CHANGED
|
@@ -180,29 +180,743 @@ section[data-testid="stSidebar"] > div:first-child { padding-top: 1.2rem; }
|
|
| 180 |
# Example prompts
|
| 181 |
# ---------------------------------------------------------------------------
|
| 182 |
EXAMPLE_PROMPTS = {
|
| 183 |
-
"
|
| 184 |
-
"
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
"
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
"
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
"
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
}
|
| 204 |
DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
# ---------------------------------------------------------------------------
|
| 207 |
# Planning prompt template (same as src/planner/prompts/unified.txt)
|
| 208 |
# ---------------------------------------------------------------------------
|
|
@@ -306,22 +1020,128 @@ def get_inference_client():
|
|
| 306 |
return InferenceClient(token=token)
|
| 307 |
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
# ---------------------------------------------------------------------------
|
| 310 |
# HF Inference API helpers
|
| 311 |
# ---------------------------------------------------------------------------
|
| 312 |
|
| 313 |
-
|
|
|
|
| 314 |
"mistralai/Mistral-7B-Instruct-v0.3",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
"HuggingFaceH4/zephyr-7b-beta",
|
| 316 |
"microsoft/Phi-3-mini-4k-instruct",
|
| 317 |
-
"
|
| 318 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
|
| 321 |
-
"""Call HF Inference API chat completion, trying multiple models.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
client = get_inference_client()
|
| 323 |
last_error = None
|
|
|
|
|
|
|
| 324 |
for model_id in TEXT_GEN_MODELS:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
try:
|
| 326 |
response = client.chat_completion(
|
| 327 |
model=model_id,
|
|
@@ -337,9 +1157,15 @@ def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float =
|
|
| 337 |
return text
|
| 338 |
except Exception as e:
|
| 339 |
last_error = e
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
continue
|
| 342 |
-
|
|
|
|
|
|
|
| 343 |
|
| 344 |
|
| 345 |
def _parse_plan_json(raw: str) -> Optional[Dict[str, Any]]:
|
|
@@ -425,11 +1251,14 @@ def plan_extended(prompt: str) -> Optional[Any]:
|
|
| 425 |
# Generation / retrieval functions
|
| 426 |
# ---------------------------------------------------------------------------
|
| 427 |
|
| 428 |
-
# HF Inference API model IDs
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
| 430 |
AUDIO_GEN_MODELS = [
|
| 431 |
-
"
|
| 432 |
-
"
|
| 433 |
]
|
| 434 |
|
| 435 |
def gen_text(prompt: str, mode: str) -> dict:
|
|
@@ -487,25 +1316,44 @@ def gen_text(prompt: str, mode: str) -> dict:
|
|
| 487 |
|
| 488 |
|
| 489 |
def generate_image(prompt: str) -> dict:
|
| 490 |
-
"""Generate image via HF Inference API
|
| 491 |
client = get_inference_client()
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
|
| 505 |
def generate_audio(prompt: str) -> dict:
|
| 506 |
-
"""Generate audio via HF Inference API,
|
| 507 |
client = get_inference_client()
|
|
|
|
| 508 |
for model_id in AUDIO_GEN_MODELS:
|
|
|
|
|
|
|
|
|
|
| 509 |
try:
|
| 510 |
audio_bytes = client.text_to_audio(prompt, model=model_id)
|
| 511 |
suffix = ".flac" if "musicgen" in model_id else ".wav"
|
|
@@ -514,7 +1362,6 @@ def generate_audio(prompt: str) -> dict:
|
|
| 514 |
tmp.write(audio_bytes)
|
| 515 |
tmp.flush()
|
| 516 |
else:
|
| 517 |
-
# Some API versions return object with .read() or similar
|
| 518 |
tmp.write(bytes(audio_bytes))
|
| 519 |
tmp.flush()
|
| 520 |
model_name = model_id.split("/")[-1]
|
|
@@ -523,11 +1370,17 @@ def generate_audio(prompt: str) -> dict:
|
|
| 523 |
"model": model_name, "failed": False,
|
| 524 |
}
|
| 525 |
except Exception as e:
|
| 526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
continue
|
| 528 |
-
# All generative models failed — fall back to retrieval
|
| 529 |
logger.warning("All audio generation models failed — falling back to retrieval")
|
| 530 |
-
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
|
| 533 |
def retrieve_image(prompt: str) -> dict:
|
|
@@ -599,31 +1452,36 @@ def main():
|
|
| 599 |
layout="wide",
|
| 600 |
initial_sidebar_state="expanded",
|
| 601 |
)
|
| 602 |
-
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
|
| 603 |
|
| 604 |
-
#
|
| 605 |
-
st.markdown(
|
| 606 |
-
'<div class="hero-wrap">'
|
| 607 |
-
'<div class="hero-title">Multimodal Coherence AI</div>'
|
| 608 |
-
'<div class="hero-sub">Generate semantically coherent <b>text + image + audio</b> bundles '
|
| 609 |
-
'and evaluate cross-modal alignment with the <b>MSCI</b> metric.</div>'
|
| 610 |
-
'</div>', unsafe_allow_html=True)
|
| 611 |
-
|
| 612 |
-
# Sidebar
|
| 613 |
with st.sidebar:
|
| 614 |
st.markdown("#### Configuration")
|
| 615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
backend = st.selectbox(
|
| 617 |
-
"
|
| 618 |
["generative", "retrieval"],
|
| 619 |
format_func=lambda x: {
|
| 620 |
-
"generative": "Generative (SDXL +
|
| 621 |
"retrieval": "Retrieval (CLIP + CLAP index)",
|
| 622 |
}[x],
|
| 623 |
)
|
| 624 |
|
| 625 |
mode = st.selectbox(
|
| 626 |
-
"
|
| 627 |
["direct", "planner", "council", "extended_prompt"],
|
| 628 |
format_func=lambda x: {
|
| 629 |
"direct": "Direct",
|
|
@@ -634,13 +1492,25 @@ def main():
|
|
| 634 |
)
|
| 635 |
|
| 636 |
st.divider()
|
| 637 |
-
st.markdown("####
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
|
| 645 |
st.divider()
|
| 646 |
mode_desc = {
|
|
@@ -650,35 +1520,57 @@ def main():
|
|
| 650 |
"extended_prompt": "Single LLM call with 3x token budget",
|
| 651 |
}
|
| 652 |
if backend == "generative":
|
| 653 |
-
img_info = "SDXL via HF API"
|
| 654 |
-
aud_info = "
|
| 655 |
else:
|
| 656 |
img_info = "CLIP retrieval (57 images)"
|
| 657 |
aud_info = "CLAP retrieval (104 clips)"
|
|
|
|
| 658 |
st.markdown(
|
| 659 |
f'<div class="sidebar-info">'
|
| 660 |
f'<b>Text</b> HF Inference API<br>'
|
| 661 |
f'<b>Planning</b> {mode_desc[mode]}<br>'
|
| 662 |
f'<b>Image</b> {img_info}<br>'
|
| 663 |
-
f'<b>Audio</b> {aud_info}<br><br>'
|
| 664 |
f'<b>Metric</b> MSCI = 0.45 × s<sub>t,i</sub> + 0.45 × s<sub>t,a</sub><br><br>'
|
| 665 |
f'<b>Models</b><br>'
|
| 666 |
f'CLIP ViT-B/32 (coherence eval)<br>'
|
| 667 |
f'CLAP HTSAT-unfused (coherence eval)'
|
| 668 |
f'</div>', unsafe_allow_html=True)
|
| 669 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 670 |
# Prompt input
|
| 671 |
default_prompt = st.session_state.get("prompt_input", "")
|
| 672 |
prompt = st.text_area(
|
| 673 |
"Scene", value=default_prompt, height=80,
|
| 674 |
-
placeholder="
|
| 675 |
label_visibility="collapsed",
|
| 676 |
)
|
| 677 |
|
| 678 |
# Button + chips
|
| 679 |
bc1, bc2 = st.columns([1, 3])
|
| 680 |
with bc1:
|
| 681 |
-
go = st.button("
|
| 682 |
with bc2:
|
| 683 |
mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
|
| 684 |
mcls = "chip-amber" if mode != "direct" else "chip-purple"
|
|
@@ -687,27 +1579,45 @@ def main():
|
|
| 687 |
bchip = '<span class="chip chip-pink"><span class="chip-dot chip-dot-pink"></span>Generative</span>'
|
| 688 |
else:
|
| 689 |
bchip = '<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
st.markdown(
|
| 691 |
f'<div class="chip-row">'
|
| 692 |
f'{bchip}'
|
| 693 |
f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
|
| 694 |
f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
|
|
|
|
| 695 |
f'</div>', unsafe_allow_html=True)
|
| 696 |
|
| 697 |
# Welcome state
|
| 698 |
if not go and "last_result" not in st.session_state:
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
return
|
| 706 |
|
| 707 |
if go and prompt.strip():
|
| 708 |
-
st.session_state["last_result"] = run_pipeline(prompt.strip(), mode, backend)
|
|
|
|
| 709 |
|
| 710 |
if "last_result" in st.session_state:
|
|
|
|
|
|
|
| 711 |
show_results(st.session_state["last_result"])
|
| 712 |
|
| 713 |
|
|
@@ -715,16 +1625,29 @@ def main():
|
|
| 715 |
# Pipeline
|
| 716 |
# ---------------------------------------------------------------------------
|
| 717 |
|
| 718 |
-
def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
|
| 719 |
-
R: dict = {"mode": mode, "backend": backend}
|
| 720 |
t_all = time.time()
|
| 721 |
|
| 722 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
plan_label = "Generating text..." if mode == "direct" else f"Planning ({mode}) + generating text..."
|
| 724 |
with st.status(plan_label, expanded=True) as s:
|
| 725 |
t0 = time.time()
|
| 726 |
try:
|
| 727 |
-
R["text"] = gen_text(
|
| 728 |
R["t_text"] = time.time() - t0
|
| 729 |
has_plan = R["text"].get("plan") is not None
|
| 730 |
lbl = f"Text ready ({R['t_text']:.1f}s)"
|
|
@@ -733,14 +1656,20 @@ def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
|
|
| 733 |
s.update(label=lbl, state="complete")
|
| 734 |
except Exception as e:
|
| 735 |
s.update(label=f"Text failed: {e}", state="error")
|
| 736 |
-
R["text"] = {"text":
|
| 737 |
R["t_text"] = time.time() - t0
|
| 738 |
|
| 739 |
-
|
| 740 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
|
| 742 |
# 2) Image
|
| 743 |
-
img_label = "Generating image
|
| 744 |
with st.status(img_label, expanded=True) as s:
|
| 745 |
t0 = time.time()
|
| 746 |
try:
|
|
@@ -791,13 +1720,14 @@ def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
|
|
| 791 |
R["audio"] = None
|
| 792 |
R["t_aud"] = time.time() - t0
|
| 793 |
|
| 794 |
-
# 4) Coherence evaluation
|
| 795 |
with st.status("Evaluating coherence...", expanded=True) as s:
|
| 796 |
t0 = time.time()
|
| 797 |
try:
|
| 798 |
imgp = R.get("image", {}).get("path") if R.get("image") else None
|
| 799 |
audp = R.get("audio", {}).get("path") if R.get("audio") else None
|
| 800 |
-
R["
|
|
|
|
| 801 |
R["t_eval"] = time.time() - t0
|
| 802 |
msci = R["coherence"].get("scores", {}).get("msci")
|
| 803 |
s.update(label=f"MSCI = {msci:.4f} ({R['t_eval']:.1f}s)", state="complete")
|
|
@@ -821,23 +1751,52 @@ def show_results(R: dict):
|
|
| 821 |
msci = sc.get("msci")
|
| 822 |
st_i = sc.get("st_i")
|
| 823 |
st_a = sc.get("st_a")
|
|
|
|
|
|
|
| 824 |
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 834 |
|
| 835 |
# Timing strip
|
| 836 |
tt = R.get("t_total", 0)
|
| 837 |
sep = '<span class="t-sep">|</span>'
|
|
|
|
|
|
|
| 838 |
st.markdown(
|
| 839 |
-
f'<div class="
|
| 840 |
f'<span class="t-total">Total {tt:.1f}s</span>{sep}'
|
|
|
|
| 841 |
f'<span>Text {R.get("t_text", 0):.1f}s</span>{sep}'
|
| 842 |
f'<span>Image {R.get("t_img", 0):.1f}s</span>{sep}'
|
| 843 |
f'<span>Audio {R.get("t_aud", 0):.1f}s</span>{sep}'
|
|
@@ -846,141 +1805,194 @@ def show_results(R: dict):
|
|
| 846 |
|
| 847 |
st.markdown("---")
|
| 848 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 849 |
# Three columns: text | image | audio
|
| 850 |
ct, ci, ca = st.columns([1.15, 1, 0.85])
|
| 851 |
|
| 852 |
with ct:
|
| 853 |
-
st.markdown('<div class="
|
| 854 |
txt = R.get("text", {}).get("text", "")
|
| 855 |
text_err = R.get("text", {}).get("text_error")
|
| 856 |
if text_err:
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
with ci:
|
| 863 |
-
st.markdown('<div class="
|
| 864 |
ii = R.get("image")
|
| 865 |
if ii and ii.get("path"):
|
| 866 |
ip = Path(ii["path"])
|
| 867 |
backend = ii.get("backend", "unknown")
|
| 868 |
|
| 869 |
-
if backend == "retrieval" and
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 875 |
|
| 876 |
if ip.exists():
|
| 877 |
st.image(str(ip), use_container_width=True)
|
| 878 |
model = ii.get("model", "")
|
| 879 |
if backend == "generative":
|
| 880 |
-
|
|
|
|
|
|
|
| 881 |
else:
|
| 882 |
sim = ii.get("similarity", 0)
|
| 883 |
dom = ii.get("domain", "other")
|
| 884 |
ic = DOMAIN_ICONS.get(dom, "\U0001f4cd")
|
| 885 |
st.caption(f"{ic} {dom} \u00b7 sim **{sim:.3f}** \u00b7 Retrieved")
|
| 886 |
else:
|
| 887 |
-
st.info("No image.")
|
| 888 |
|
| 889 |
with ca:
|
| 890 |
-
st.markdown('<div class="
|
| 891 |
ai = R.get("audio")
|
| 892 |
if ai and ai.get("path"):
|
| 893 |
ap = Path(ai["path"])
|
| 894 |
backend = ai.get("backend", "unknown")
|
| 895 |
|
| 896 |
-
if backend == "retrieval" and
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
|
| 903 |
if ap.exists():
|
| 904 |
st.audio(str(ap))
|
| 905 |
model = ai.get("model", "")
|
| 906 |
if backend == "generative":
|
| 907 |
-
|
|
|
|
|
|
|
| 908 |
else:
|
| 909 |
sim = ai.get("similarity", 0)
|
| 910 |
st.caption(f"sim **{sim:.3f}** \u00b7 Retrieved")
|
| 911 |
else:
|
| 912 |
-
st.info("No audio.")
|
| 913 |
|
| 914 |
st.markdown("---")
|
| 915 |
|
| 916 |
-
# Expandable details
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
mode = R.get("mode", "direct")
|
| 943 |
-
if mode == "direct":
|
| 944 |
-
st.write("Direct mode \u2014 no semantic plan. Prompt used as-is for all modalities.")
|
| 945 |
else:
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 962 |
else:
|
| 963 |
-
st.write("No
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
st.markdown(bars, unsafe_allow_html=True)
|
| 976 |
else:
|
| 977 |
-
st.
|
| 978 |
-
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
st.write("No data.")
|
| 984 |
|
| 985 |
|
| 986 |
if __name__ == "__main__":
|
|
|
|
| 180 |
# Example prompts
|
| 181 |
# ---------------------------------------------------------------------------
|
| 182 |
EXAMPLE_PROMPTS = {
|
| 183 |
+
"en": {
|
| 184 |
+
"Nature": [
|
| 185 |
+
"A peaceful forest at dawn with birdsong and morning mist",
|
| 186 |
+
"A field of golden wheat under a warm summer sunset",
|
| 187 |
+
"A dense jungle with exotic birds calling from the canopy",
|
| 188 |
+
],
|
| 189 |
+
"Urban": [
|
| 190 |
+
"A bustling city street at night with neon lights and traffic",
|
| 191 |
+
"A quiet alley in an old town with distant footsteps echoing",
|
| 192 |
+
"A cafe terrace on a busy boulevard with clinking glasses",
|
| 193 |
+
],
|
| 194 |
+
"Water": [
|
| 195 |
+
"Ocean waves crashing on a sandy beach at sunset",
|
| 196 |
+
"Rain falling on a pond with ripples spreading across the surface",
|
| 197 |
+
"A mountain stream flowing over rocks through a pine forest",
|
| 198 |
+
],
|
| 199 |
+
"Mixed": [
|
| 200 |
+
"A lighthouse on a cliff during a thunderstorm at night",
|
| 201 |
+
"A bonfire on a beach with waves and guitar music at night",
|
| 202 |
+
"A train passing through countryside with distant church bells",
|
| 203 |
+
],
|
| 204 |
+
},
|
| 205 |
+
"de": {
|
| 206 |
+
"Natur": [
|
| 207 |
+
"Ein friedlicher Wald bei Sonnenaufgang mit Vogelgesang und Morgennebel",
|
| 208 |
+
"Ein goldenes Weizenfeld unter einem warmen Sommerabend",
|
| 209 |
+
"Ein dichter Dschungel mit exotischen V\u00f6geln im Bl\u00e4tterdach",
|
| 210 |
+
],
|
| 211 |
+
"Stadt": [
|
| 212 |
+
"Eine belebte Stra\u00dfe bei Nacht mit Neonlichtern und Verkehr",
|
| 213 |
+
"Eine ruhige Gasse in einer Altstadt mit fernen Schritten",
|
| 214 |
+
"Eine Caf\u00e9-Terrasse an einem belebten Boulevard mit klinkenden Gl\u00e4sern",
|
| 215 |
+
],
|
| 216 |
+
"Wasser": [
|
| 217 |
+
"Meereswellen am Sandstrand bei Sonnenuntergang",
|
| 218 |
+
"Regen f\u00e4llt auf einen Teich mit sich ausbreitenden Wellen",
|
| 219 |
+
"Ein Bergbach flie\u00dft \u00fcber Felsen durch einen Kiefernwald",
|
| 220 |
+
],
|
| 221 |
+
"Gemischt": [
|
| 222 |
+
"Ein Leuchtturm auf einer Klippe w\u00e4hrend eines Gewitters bei Nacht",
|
| 223 |
+
"Ein Lagerfeuer am Strand mit Wellen und Gitarrenmusik bei Nacht",
|
| 224 |
+
"Ein Zug f\u00e4hrt durch die Landschaft mit fernen Kirchenglocken",
|
| 225 |
+
],
|
| 226 |
+
},
|
| 227 |
}
|
| 228 |
DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
|
| 229 |
|
| 230 |
+
# ---------------------------------------------------------------------------
|
| 231 |
+
# Kid Mode — example prompts (German, fun themes for children)
|
| 232 |
+
# ---------------------------------------------------------------------------
|
| 233 |
+
KID_EXAMPLE_PROMPTS = {
|
| 234 |
+
"de": {
|
| 235 |
+
"\U0001f47e Abenteuer": [
|
| 236 |
+
"Pikachu in einem magischen Wald bei Sonnenuntergang",
|
| 237 |
+
"Ein Minecraft-Dorf auf einer Insel mitten im Ozean",
|
| 238 |
+
"Ein kleiner Drache fliegt \u00fcber eine Burg bei Nacht",
|
| 239 |
+
"Ein Weltraumabenteuer mit Raketen und bunten Planeten",
|
| 240 |
+
],
|
| 241 |
+
"\U0001f43e Tiere": [
|
| 242 |
+
"Ein freundlicher Hund rettet ein K\u00e4tzchen im Regen",
|
| 243 |
+
"Dinosaurier spielen Fu\u00dfball auf einer sonnigen Wiese",
|
| 244 |
+
"Ein Einhorn galoppiert \u00fcber einen leuchtenden Regenbogen",
|
| 245 |
+
"Pinguine machen eine Schneeballschlacht am S\u00fcdpol",
|
| 246 |
+
"Ein kleiner Fuchs entdeckt einen geheimen Garten",
|
| 247 |
+
],
|
| 248 |
+
"\u2728 Fantasie": [
|
| 249 |
+
"Ein Zauberer braut einen glitzernden Trank in einem Schloss",
|
| 250 |
+
"Eine Fee fliegt durch einen Wald voller leuchtender Pilze",
|
| 251 |
+
"Ein verzaubertes Baumhaus in den Wolken mit Regenbogenbr\u00fccke",
|
| 252 |
+
"Ein Roboter und ein Teddy gehen zusammen auf Schatzsuche",
|
| 253 |
+
"Ein magischer Unterwasserpalast mit sprechenden Fischen",
|
| 254 |
+
],
|
| 255 |
+
"\U0001f602 Lustig": [
|
| 256 |
+
"Eine Katze f\u00e4hrt Skateboard durch eine bunte Stadt",
|
| 257 |
+
"Aliens landen im Schulgarten und spielen Verstecken",
|
| 258 |
+
"Ein Elefant versucht sich auf einem Trampolin",
|
| 259 |
+
"Ein Schneemann isst Eis am Strand im Sommer",
|
| 260 |
+
"Monster unter dem Bett machen eine Pyjamaparty",
|
| 261 |
+
],
|
| 262 |
+
"\U0001f3ae Spielwelt": [
|
| 263 |
+
"Super Mario springt durch eine Welt aus S\u00fc\u00dfigkeiten",
|
| 264 |
+
"Ein Ritter k\u00e4mpft gegen einen freundlichen Drachen",
|
| 265 |
+
"Eine Unterwasser-Rennstrecke mit U-Booten und Delfinen",
|
| 266 |
+
"Ein Baumhaus-Dorf im Dschungel mit H\u00e4ngebr\u00fccken",
|
| 267 |
+
"Tiere bauen zusammen eine riesige Sandburg am Meer",
|
| 268 |
+
],
|
| 269 |
+
},
|
| 270 |
+
"en": {
|
| 271 |
+
"\U0001f47e Adventure": [
|
| 272 |
+
"Pikachu in a magical forest at sunset",
|
| 273 |
+
"A Minecraft village on an island in the middle of the ocean",
|
| 274 |
+
"A little dragon flying over a castle at night",
|
| 275 |
+
"A space adventure with rockets and colorful planets",
|
| 276 |
+
],
|
| 277 |
+
"\U0001f43e Animals": [
|
| 278 |
+
"A friendly dog rescuing a kitten in the rain",
|
| 279 |
+
"Dinosaurs playing football on a sunny meadow",
|
| 280 |
+
"A unicorn galloping over a glowing rainbow",
|
| 281 |
+
"Penguins having a snowball fight at the South Pole",
|
| 282 |
+
"A little fox discovering a secret garden",
|
| 283 |
+
],
|
| 284 |
+
"\u2728 Fantasy": [
|
| 285 |
+
"A wizard brewing a sparkling potion in a castle",
|
| 286 |
+
"A fairy flying through a forest of glowing mushrooms",
|
| 287 |
+
"An enchanted treehouse in the clouds with a rainbow bridge",
|
| 288 |
+
"A robot and a teddy bear going on a treasure hunt together",
|
| 289 |
+
"A magical underwater palace with talking fish",
|
| 290 |
+
],
|
| 291 |
+
"\U0001f602 Funny": [
|
| 292 |
+
"A cat riding a skateboard through a colorful city",
|
| 293 |
+
"Aliens landing in the school garden and playing hide and seek",
|
| 294 |
+
"An elephant trying to jump on a trampoline",
|
| 295 |
+
"A snowman eating ice cream at the beach in summer",
|
| 296 |
+
"Monsters under the bed having a pajama party",
|
| 297 |
+
],
|
| 298 |
+
"\U0001f3ae Game World": [
|
| 299 |
+
"Super Mario jumping through a world made of candy",
|
| 300 |
+
"A knight fighting a friendly dragon",
|
| 301 |
+
"An underwater race track with submarines and dolphins",
|
| 302 |
+
"A treehouse village in the jungle with rope bridges",
|
| 303 |
+
"Animals building a giant sandcastle at the beach",
|
| 304 |
+
],
|
| 305 |
+
},
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
# Kid Mode — CSS theme (bright, bubbly, playful)
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
KID_CSS = """
|
| 312 |
+
<style>
|
| 313 |
+
/* ============================================================
|
| 314 |
+
KID MODE — Full theme override
|
| 315 |
+
============================================================ */
|
| 316 |
+
|
| 317 |
+
/* Kill the top gap */
|
| 318 |
+
.block-container { padding-top: 0.5rem !important; }
|
| 319 |
+
header[data-testid="stHeader"] { display: none !important; }
|
| 320 |
+
|
| 321 |
+
/* Force light colorful background on EVERYTHING */
|
| 322 |
+
.stApp, .stApp > div, .main, .main .block-container,
|
| 323 |
+
[data-testid="stAppViewContainer"], [data-testid="stAppViewBlockContainer"],
|
| 324 |
+
section.main, section.main > div {
|
| 325 |
+
background: linear-gradient(170deg, #dbeafe 0%, #fce7f3 35%, #fef3c7 65%, #dcfce7 100%) !important;
|
| 326 |
+
color: #1e293b !important;
|
| 327 |
+
}
|
| 328 |
+
/* Sidebar light theme */
|
| 329 |
+
section[data-testid="stSidebar"], section[data-testid="stSidebar"] > div {
|
| 330 |
+
background: linear-gradient(180deg, #ede9fe 0%, #fce7f3 100%) !important;
|
| 331 |
+
color: #1e293b !important;
|
| 332 |
+
}
|
| 333 |
+
section[data-testid="stSidebar"] label,
|
| 334 |
+
section[data-testid="stSidebar"] .stMarkdown,
|
| 335 |
+
section[data-testid="stSidebar"] span,
|
| 336 |
+
section[data-testid="stSidebar"] p {
|
| 337 |
+
color: #334155 !important;
|
| 338 |
+
}
|
| 339 |
+
/* Force dark text everywhere */
|
| 340 |
+
.stMarkdown, .stMarkdown p, .stMarkdown span, .stMarkdown div,
|
| 341 |
+
.stTextArea textarea, label, .stSelectbox label {
|
| 342 |
+
color: #1e293b !important;
|
| 343 |
+
}
|
| 344 |
+
.stTextArea textarea {
|
| 345 |
+
background: rgba(255,255,255,0.85) !important;
|
| 346 |
+
border: 2px solid #c4b5fd !important;
|
| 347 |
+
border-radius: 18px !important;
|
| 348 |
+
font-size: 1rem !important;
|
| 349 |
+
}
|
| 350 |
+
.stTextArea textarea:focus {
|
| 351 |
+
border-color: #8b5cf6 !important;
|
| 352 |
+
box-shadow: 0 0 0 4px rgba(139,92,246,0.15) !important;
|
| 353 |
+
}
|
| 354 |
+
/* Status containers */
|
| 355 |
+
[data-testid="stStatusWidget"] {
|
| 356 |
+
background: rgba(255,255,255,0.6) !important;
|
| 357 |
+
border-radius: 14px !important;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
/* Floating background elements */
|
| 361 |
+
.kid-bg {
|
| 362 |
+
position: fixed; top: 0; left: 0; width: 100%; height: 100%;
|
| 363 |
+
pointer-events: none; z-index: 0; overflow: hidden;
|
| 364 |
+
}
|
| 365 |
+
.kid-bg-item {
|
| 366 |
+
position: absolute; opacity: 0.15;
|
| 367 |
+
animation: kid-float linear infinite;
|
| 368 |
+
}
|
| 369 |
+
@keyframes kid-float {
|
| 370 |
+
0% { transform: translateY(105vh) rotate(0deg) scale(0.8); opacity: 0; }
|
| 371 |
+
8% { opacity: 0.35; }
|
| 372 |
+
92% { opacity: 0.35; }
|
| 373 |
+
100% { transform: translateY(-10vh) rotate(360deg) scale(1.1); opacity: 0; }
|
| 374 |
+
}
|
| 375 |
+
/* Twinkle for stars */
|
| 376 |
+
@keyframes kid-twinkle {
|
| 377 |
+
0%, 100% { opacity: 0.15; transform: scale(0.8); }
|
| 378 |
+
50% { opacity: 0.5; transform: scale(1.2); }
|
| 379 |
+
}
|
| 380 |
+
.kid-star-fixed {
|
| 381 |
+
position: absolute; pointer-events: none;
|
| 382 |
+
animation: kid-twinkle ease-in-out infinite;
|
| 383 |
+
}
|
| 384 |
+
/* Clouds */
|
| 385 |
+
.kid-cloud {
|
| 386 |
+
position: absolute; pointer-events: none; opacity: 0.18;
|
| 387 |
+
width: 120px; height: 50px; background: white;
|
| 388 |
+
border-radius: 50px; animation: kid-drift linear infinite;
|
| 389 |
+
}
|
| 390 |
+
.kid-cloud::before {
|
| 391 |
+
content: ''; position: absolute; background: white; border-radius: 50%;
|
| 392 |
+
width: 55px; height: 55px; top: -25px; left: 20px;
|
| 393 |
+
}
|
| 394 |
+
.kid-cloud::after {
|
| 395 |
+
content: ''; position: absolute; background: white; border-radius: 50%;
|
| 396 |
+
width: 40px; height: 40px; top: -18px; left: 55px;
|
| 397 |
+
}
|
| 398 |
+
@keyframes kid-drift {
|
| 399 |
+
0% { transform: translateX(-150px); }
|
| 400 |
+
100% { transform: translateX(calc(100vw + 150px)); }
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
/* Hero — big colorful title */
|
| 404 |
+
.kid-hero {
|
| 405 |
+
text-align: center; padding: 0.8rem 0 0.3rem; position: relative; z-index: 1;
|
| 406 |
+
}
|
| 407 |
+
.kid-hero-title {
|
| 408 |
+
font-size: 3.2rem; font-weight: 900; letter-spacing: -0.02em;
|
| 409 |
+
background: linear-gradient(135deg, #ec4899, #f97316, #eab308, #22c55e, #3b82f6, #8b5cf6);
|
| 410 |
+
background-size: 300% 300%;
|
| 411 |
+
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
|
| 412 |
+
animation: kid-gradient 4s ease infinite;
|
| 413 |
+
text-shadow: none;
|
| 414 |
+
}
|
| 415 |
+
@keyframes kid-gradient {
|
| 416 |
+
0% { background-position: 0% 50%; }
|
| 417 |
+
50% { background-position: 100% 50%; }
|
| 418 |
+
100% { background-position: 0% 50%; }
|
| 419 |
+
}
|
| 420 |
+
.kid-hero-sub {
|
| 421 |
+
font-size: 1.15rem; color: #475569; margin-top: 0.2rem; font-weight: 500;
|
| 422 |
+
}
|
| 423 |
+
.kid-hero-sub b { color: #7c3aed; }
|
| 424 |
+
|
| 425 |
+
/* Mascots — bigger, animated, with speech bubbles */
|
| 426 |
+
.kid-mascot-row {
|
| 427 |
+
display: flex; justify-content: center; gap: 2rem; margin: 0.8rem 0 0.5rem;
|
| 428 |
+
position: relative; z-index: 1;
|
| 429 |
+
}
|
| 430 |
+
.kid-mascot {
|
| 431 |
+
display: flex; flex-direction: column; align-items: center;
|
| 432 |
+
padding: 0.8rem 1.2rem 0.5rem; border-radius: 24px;
|
| 433 |
+
background: rgba(255,255,255,0.9);
|
| 434 |
+
border: 3px solid rgba(255,255,255,1);
|
| 435 |
+
box-shadow: 0 8px 30px rgba(0,0,0,0.08), 0 2px 8px rgba(139,92,246,0.1);
|
| 436 |
+
transition: transform 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
|
| 437 |
+
cursor: default; position: relative;
|
| 438 |
+
min-width: 105px;
|
| 439 |
+
}
|
| 440 |
+
.kid-mascot:hover {
|
| 441 |
+
transform: scale(1.12) rotate(-3deg);
|
| 442 |
+
box-shadow: 0 12px 40px rgba(139,92,246,0.25);
|
| 443 |
+
}
|
| 444 |
+
.kid-mascot svg { display: block; margin: 0 auto; }
|
| 445 |
+
.kid-mascot-name {
|
| 446 |
+
font-size: 0.9rem; font-weight: 800; margin-top: 0.15rem;
|
| 447 |
+
letter-spacing: 0.04em;
|
| 448 |
+
}
|
| 449 |
+
.kid-mascot:nth-child(1) .kid-mascot-name { color: #3b82f6; }
|
| 450 |
+
.kid-mascot:nth-child(2) .kid-mascot-name { color: #ec4899; }
|
| 451 |
+
.kid-mascot:nth-child(3) .kid-mascot-name { color: #f97316; }
|
| 452 |
+
/* Continuous gentle bounce */
|
| 453 |
+
.kid-mascot:nth-child(1) { animation: kid-bob 2s ease-in-out infinite; }
|
| 454 |
+
.kid-mascot:nth-child(2) { animation: kid-bob 2s ease-in-out 0.3s infinite; }
|
| 455 |
+
.kid-mascot:nth-child(3) { animation: kid-bob 2s ease-in-out 0.6s infinite; }
|
| 456 |
+
@keyframes kid-bob {
|
| 457 |
+
0%, 100% { transform: translateY(0); }
|
| 458 |
+
50% { transform: translateY(-6px); }
|
| 459 |
+
}
|
| 460 |
+
.kid-mascot:hover { animation: none; }
|
| 461 |
+
/* Speech bubble */
|
| 462 |
+
.kid-speech {
|
| 463 |
+
position: absolute; top: -32px; left: 50%; transform: translateX(-50%);
|
| 464 |
+
background: #fef3c7; color: #92400e; font-size: 0.65rem; font-weight: 700;
|
| 465 |
+
padding: 3px 10px; border-radius: 12px; white-space: nowrap;
|
| 466 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
|
| 467 |
+
opacity: 0; transition: opacity 0.2s;
|
| 468 |
+
}
|
| 469 |
+
.kid-speech::after {
|
| 470 |
+
content: ''; position: absolute; bottom: -5px; left: 50%; margin-left: -5px;
|
| 471 |
+
border-left: 5px solid transparent; border-right: 5px solid transparent;
|
| 472 |
+
border-top: 5px solid #fef3c7;
|
| 473 |
+
}
|
| 474 |
+
.kid-mascot:hover .kid-speech { opacity: 1; }
|
| 475 |
+
|
| 476 |
+
/* Score cards — kid version */
|
| 477 |
+
.kid-scores {
|
| 478 |
+
display: grid; grid-template-columns: repeat(4, 1fr);
|
| 479 |
+
gap: 0.8rem; margin: 0.6rem 0; position: relative; z-index: 1;
|
| 480 |
+
}
|
| 481 |
+
@media (max-width: 768px) { .kid-scores { grid-template-columns: repeat(2, 1fr); } }
|
| 482 |
+
.kid-sc {
|
| 483 |
+
border-radius: 22px; padding: 1.1rem 0.8rem; text-align: center;
|
| 484 |
+
background: rgba(255,255,255,0.85);
|
| 485 |
+
border: 2.5px solid rgba(255,255,255,1);
|
| 486 |
+
box-shadow: 0 6px 24px rgba(0,0,0,0.06);
|
| 487 |
+
position: relative; overflow: hidden;
|
| 488 |
+
animation: kid-pop 0.4s cubic-bezier(0.34, 1.56, 0.64, 1) both;
|
| 489 |
+
}
|
| 490 |
+
.kid-sc:nth-child(1) { animation-delay: 0s; }
|
| 491 |
+
.kid-sc:nth-child(2) { animation-delay: 0.1s; }
|
| 492 |
+
.kid-sc:nth-child(3) { animation-delay: 0.2s; }
|
| 493 |
+
.kid-sc:nth-child(4) { animation-delay: 0.3s; }
|
| 494 |
+
@keyframes kid-pop {
|
| 495 |
+
0% { transform: scale(0.7); opacity: 0; }
|
| 496 |
+
100% { transform: scale(1); opacity: 1; }
|
| 497 |
+
}
|
| 498 |
+
.kid-sc::before {
|
| 499 |
+
content: ''; position: absolute; top: 0; left: 0; right: 0; height: 5px;
|
| 500 |
+
border-radius: 22px 22px 0 0;
|
| 501 |
+
}
|
| 502 |
+
.kid-sc-great::before { background: linear-gradient(90deg, #22c55e, #06b6d4); }
|
| 503 |
+
.kid-sc-ok::before { background: linear-gradient(90deg, #f59e0b, #f97316); }
|
| 504 |
+
.kid-sc-low::before { background: linear-gradient(90deg, #ef4444, #ec4899); }
|
| 505 |
+
.kid-sc-main::before { background: linear-gradient(90deg, #8b5cf6, #ec4899, #f97316, #eab308); background-size: 200%; animation: kid-gradient 3s ease infinite; }
|
| 506 |
+
.kid-sc-lbl {
|
| 507 |
+
font-size: 0.72rem; font-weight: 800; color: #64748b;
|
| 508 |
+
text-transform: uppercase; letter-spacing: 0.06em;
|
| 509 |
+
}
|
| 510 |
+
.kid-sc-stars { font-size: 1.8rem; margin: 0.3rem 0; line-height: 1.1; }
|
| 511 |
+
.kid-sc-emoji { font-size: 2.4rem; margin: 0.15rem 0; }
|
| 512 |
+
.kid-sc-val {
|
| 513 |
+
font-size: 0.7rem; color: #94a3b8; font-family: 'JetBrains Mono', monospace;
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
/* Verdict banner */
|
| 517 |
+
.kid-verdict {
|
| 518 |
+
text-align: center; font-size: 1.4rem; font-weight: 800;
|
| 519 |
+
color: #334155; margin: 0.4rem 0 0.6rem;
|
| 520 |
+
animation: kid-pop 0.5s cubic-bezier(0.34, 1.56, 0.64, 1) both;
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
/* Section labels */
|
| 524 |
+
.kid-sec-label {
|
| 525 |
+
font-size: 0.85rem; font-weight: 900; letter-spacing: 0.06em;
|
| 526 |
+
text-transform: uppercase; color: #7c3aed !important;
|
| 527 |
+
padding-bottom: 0.35rem; border-bottom: 3px solid #c4b5fd;
|
| 528 |
+
margin-bottom: 0.6rem;
|
| 529 |
+
}
|
| 530 |
+
.kid-text-card {
|
| 531 |
+
border-radius: 20px; padding: 1.2rem 1.3rem;
|
| 532 |
+
background: rgba(255,255,255,0.8);
|
| 533 |
+
border: 2px solid rgba(255,255,255,1);
|
| 534 |
+
box-shadow: 0 4px 20px rgba(0,0,0,0.05);
|
| 535 |
+
font-size: 0.95rem; line-height: 1.8; color: #334155 !important;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
.kid-timing {
|
| 539 |
+
display: flex; gap: 0.5rem; flex-wrap: wrap; align-items: center;
|
| 540 |
+
padding: 0.45rem 0.9rem; border-radius: 16px;
|
| 541 |
+
background: rgba(255,255,255,0.6);
|
| 542 |
+
border: 2px solid rgba(255,255,255,0.9);
|
| 543 |
+
font-size: 0.72rem; color: #64748b !important; margin: 0.4rem 0;
|
| 544 |
+
}
|
| 545 |
+
.kid-timing span { color: #64748b !important; }
|
| 546 |
+
.kid-timing .t-total { color: #7c3aed !important; font-weight: 700; }
|
| 547 |
+
.kid-timing .t-sep { color: #cbd5e1 !important; }
|
| 548 |
+
|
| 549 |
+
/* Warn banner */
|
| 550 |
+
.kid-warn {
|
| 551 |
+
border-radius: 16px; padding: 0.8rem 1.1rem; margin-bottom: 0.6rem;
|
| 552 |
+
border-left: 4px solid #f97316; font-size: 0.85rem; color: #9a3412 !important;
|
| 553 |
+
background: rgba(255,237,213,0.7);
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
/* Button override */
|
| 557 |
+
.stButton > button[kind="primary"] {
|
| 558 |
+
background: linear-gradient(135deg, #8b5cf6, #ec4899) !important;
|
| 559 |
+
border: none !important; border-radius: 16px !important;
|
| 560 |
+
font-weight: 800 !important; font-size: 1.05rem !important;
|
| 561 |
+
padding: 0.6rem 1.5rem !important;
|
| 562 |
+
box-shadow: 0 4px 15px rgba(139,92,246,0.3) !important;
|
| 563 |
+
transition: transform 0.2s, box-shadow 0.2s !important;
|
| 564 |
+
}
|
| 565 |
+
.stButton > button[kind="primary"]:hover {
|
| 566 |
+
transform: scale(1.03) !important;
|
| 567 |
+
box-shadow: 0 6px 25px rgba(139,92,246,0.4) !important;
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
/* Divider */
|
| 571 |
+
hr { border-color: rgba(139,92,246,0.15) !important; }
|
| 572 |
+
</style>
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
# ---------------------------------------------------------------------------
|
| 576 |
+
# Kid Mode — mascot HTML, star ratings, emoji feedback
|
| 577 |
+
# ---------------------------------------------------------------------------
|
| 578 |
+
|
| 579 |
+
MASCOT_HTML = """
|
| 580 |
+
<!-- Rich floating background -->
|
| 581 |
+
<div class="kid-bg">
|
| 582 |
+
<!-- Wave 1: floating emoji rising (spread across page) -->
|
| 583 |
+
<div class="kid-bg-item" style="font-size:30px;left:2%;animation-duration:14s;">\u2b50</div>
|
| 584 |
+
<div class="kid-bg-item" style="font-size:24px;left:8%;animation-duration:18s;animation-delay:2s;">\U0001f98b</div>
|
| 585 |
+
<div class="kid-bg-item" style="font-size:26px;left:14%;animation-duration:16s;animation-delay:5s;">\U0001f49c</div>
|
| 586 |
+
<div class="kid-bg-item" style="font-size:20px;left:20%;animation-duration:22s;animation-delay:1s;">\U0001f680</div>
|
| 587 |
+
<div class="kid-bg-item" style="font-size:32px;left:26%;animation-duration:13s;animation-delay:3s;">\u2728</div>
|
| 588 |
+
<div class="kid-bg-item" style="font-size:22px;left:32%;animation-duration:19s;animation-delay:7s;">\U0001f338</div>
|
| 589 |
+
<div class="kid-bg-item" style="font-size:28px;left:38%;animation-duration:15s;animation-delay:4s;">\U0001f31f</div>
|
| 590 |
+
<div class="kid-bg-item" style="font-size:18px;left:44%;animation-duration:20s;animation-delay:0s;">\U0001f984</div>
|
| 591 |
+
<div class="kid-bg-item" style="font-size:26px;left:50%;animation-duration:17s;animation-delay:6s;">\U0001f308</div>
|
| 592 |
+
<div class="kid-bg-item" style="font-size:24px;left:56%;animation-duration:14s;animation-delay:2s;">\U0001f49b</div>
|
| 593 |
+
<div class="kid-bg-item" style="font-size:20px;left:62%;animation-duration:21s;animation-delay:8s;">\U0001f33c</div>
|
| 594 |
+
<div class="kid-bg-item" style="font-size:30px;left:68%;animation-duration:16s;animation-delay:1s;">\u2b50</div>
|
| 595 |
+
<div class="kid-bg-item" style="font-size:22px;left:74%;animation-duration:18s;animation-delay:5s;">\U0001f98b</div>
|
| 596 |
+
<div class="kid-bg-item" style="font-size:28px;left:80%;animation-duration:13s;animation-delay:3s;">\u2728</div>
|
| 597 |
+
<div class="kid-bg-item" style="font-size:24px;left:86%;animation-duration:20s;animation-delay:9s;">\U0001f49a</div>
|
| 598 |
+
<div class="kid-bg-item" style="font-size:18px;left:92%;animation-duration:15s;animation-delay:4s;">\U0001f30d</div>
|
| 599 |
+
<div class="kid-bg-item" style="font-size:26px;left:97%;animation-duration:17s;animation-delay:0s;">\U0001f680</div>
|
| 600 |
+
<!-- Wave 2: offset for constant density -->
|
| 601 |
+
<div class="kid-bg-item" style="font-size:22px;left:5%;animation-duration:19s;animation-delay:10s;">\U0001f33c</div>
|
| 602 |
+
<div class="kid-bg-item" style="font-size:28px;left:15%;animation-duration:15s;animation-delay:11s;">\U0001f49b</div>
|
| 603 |
+
<div class="kid-bg-item" style="font-size:18px;left:25%;animation-duration:21s;animation-delay:9s;">\U0001f984</div>
|
| 604 |
+
<div class="kid-bg-item" style="font-size:26px;left:35%;animation-duration:16s;animation-delay:12s;">\u2b50</div>
|
| 605 |
+
<div class="kid-bg-item" style="font-size:24px;left:45%;animation-duration:18s;animation-delay:8s;">\U0001f98b</div>
|
| 606 |
+
<div class="kid-bg-item" style="font-size:20px;left:55%;animation-duration:14s;animation-delay:13s;">\U0001f308</div>
|
| 607 |
+
<div class="kid-bg-item" style="font-size:30px;left:65%;animation-duration:20s;animation-delay:10s;">\u2728</div>
|
| 608 |
+
<div class="kid-bg-item" style="font-size:22px;left:75%;animation-duration:17s;animation-delay:11s;">\U0001f338</div>
|
| 609 |
+
<div class="kid-bg-item" style="font-size:26px;left:85%;animation-duration:13s;animation-delay:14s;">\U0001f49a</div>
|
| 610 |
+
<div class="kid-bg-item" style="font-size:24px;left:95%;animation-duration:19s;animation-delay:9s;">\U0001f31f</div>
|
| 611 |
+
<!-- Wave 3: more for richness -->
|
| 612 |
+
<div class="kid-bg-item" style="font-size:20px;left:10%;animation-duration:17s;animation-delay:15s;">\U0001f680</div>
|
| 613 |
+
<div class="kid-bg-item" style="font-size:26px;left:30%;animation-duration:14s;animation-delay:16s;">\U0001f338</div>
|
| 614 |
+
<div class="kid-bg-item" style="font-size:22px;left:50%;animation-duration:19s;animation-delay:14s;">\U0001f984</div>
|
| 615 |
+
<div class="kid-bg-item" style="font-size:28px;left:70%;animation-duration:15s;animation-delay:17s;">\U0001f49c</div>
|
| 616 |
+
<div class="kid-bg-item" style="font-size:24px;left:90%;animation-duration:18s;animation-delay:15s;">\U0001f33c</div>
|
| 617 |
+
<!-- Twinkling stars (fixed) -->
|
| 618 |
+
<div class="kid-star-fixed" style="font-size:18px;top:5%;left:8%;animation-duration:2.5s;">\u2b50</div>
|
| 619 |
+
<div class="kid-star-fixed" style="font-size:14px;top:12%;left:30%;animation-duration:3s;animation-delay:0.5s;">\u2b50</div>
|
| 620 |
+
<div class="kid-star-fixed" style="font-size:16px;top:8%;left:55%;animation-duration:2.8s;animation-delay:1s;">\u2b50</div>
|
| 621 |
+
<div class="kid-star-fixed" style="font-size:12px;top:15%;left:80%;animation-duration:3.5s;animation-delay:0.3s;">\u2b50</div>
|
| 622 |
+
<div class="kid-star-fixed" style="font-size:15px;top:35%;left:5%;animation-duration:4s;animation-delay:0.8s;">\u2b50</div>
|
| 623 |
+
<div class="kid-star-fixed" style="font-size:11px;top:50%;left:92%;animation-duration:3.2s;animation-delay:1.5s;">\u2b50</div>
|
| 624 |
+
<div class="kid-star-fixed" style="font-size:17px;top:65%;left:15%;animation-duration:2.6s;animation-delay:0.2s;">\u2b50</div>
|
| 625 |
+
<div class="kid-star-fixed" style="font-size:13px;top:75%;left:70%;animation-duration:3.8s;animation-delay:2s;">\u2b50</div>
|
| 626 |
+
<div class="kid-star-fixed" style="font-size:10px;top:88%;left:45%;animation-duration:3s;animation-delay:1.2s;">\u2b50</div>
|
| 627 |
+
<div class="kid-star-fixed" style="font-size:14px;top:42%;left:88%;animation-duration:2.4s;animation-delay:0.7s;">\u2b50</div>
|
| 628 |
+
<!-- Clouds -->
|
| 629 |
+
<div class="kid-cloud" style="top:3%;animation-duration:40s;"></div>
|
| 630 |
+
<div class="kid-cloud" style="top:20%;animation-duration:55s;animation-delay:12s;width:90px;height:38px;"></div>
|
| 631 |
+
<div class="kid-cloud" style="top:45%;animation-duration:48s;animation-delay:25s;width:100px;height:42px;"></div>
|
| 632 |
+
<div class="kid-cloud" style="top:65%;animation-duration:52s;animation-delay:8s;width:80px;height:34px;"></div>
|
| 633 |
+
<div class="kid-cloud" style="top:85%;animation-duration:44s;animation-delay:20s;"></div>
|
| 634 |
+
</div>
|
| 635 |
+
<!-- Corner characters: cute SVG creatures -->
|
| 636 |
+
<!-- Cat (bottom-left) -->
|
| 637 |
+
<div style="position:fixed;bottom:15px;left:260px;z-index:2;opacity:0.4;pointer-events:none;animation:kid-bob 3s ease-in-out infinite;">
|
| 638 |
+
<svg width="55" height="50" viewBox="0 0 55 50">
|
| 639 |
+
<polygon points="9,16 4,2 17,12" fill="#f97316"/>
|
| 640 |
+
<polygon points="46,16 51,2 39,12" fill="#f97316"/>
|
| 641 |
+
<ellipse cx="27" cy="27" rx="20" ry="16" fill="#fb923c"/>
|
| 642 |
+
<ellipse cx="20" cy="25" rx="2.5" ry="3" fill="#1e293b"/>
|
| 643 |
+
<ellipse cx="34" cy="25" rx="2.5" ry="3" fill="#1e293b"/>
|
| 644 |
+
<circle cx="21" cy="24" r="0.8" fill="white"/>
|
| 645 |
+
<circle cx="35" cy="24" r="0.8" fill="white"/>
|
| 646 |
+
<ellipse cx="27" cy="30" rx="2" ry="1.2" fill="#f472b6"/>
|
| 647 |
+
<path d="M24 32 Q27 35 30 32" stroke="#ea580c" stroke-width="1" fill="none"/>
|
| 648 |
+
<line x1="7" y1="27" x2="0" y2="25" stroke="#fdba74" stroke-width="1.2"/>
|
| 649 |
+
<line x1="7" y1="29" x2="0" y2="30" stroke="#fdba74" stroke-width="1.2"/>
|
| 650 |
+
<line x1="47" y1="27" x2="55" y2="25" stroke="#fdba74" stroke-width="1.2"/>
|
| 651 |
+
<line x1="47" y1="29" x2="55" y2="30" stroke="#fdba74" stroke-width="1.2"/>
|
| 652 |
+
<path d="M13 43 Q7 47 10 50" stroke="#fb923c" stroke-width="3.5" fill="none" stroke-linecap="round"/>
|
| 653 |
+
</svg></div>
|
| 654 |
+
<!-- Dog (bottom-right) -->
|
| 655 |
+
<div style="position:fixed;bottom:15px;right:25px;z-index:2;opacity:0.4;pointer-events:none;animation:kid-bob 3.5s ease-in-out 0.5s infinite;">
|
| 656 |
+
<svg width="55" height="50" viewBox="0 0 55 50">
|
| 657 |
+
<ellipse cx="10" cy="10" rx="9" ry="13" fill="#a16207" transform="rotate(-20,10,10)"/>
|
| 658 |
+
<ellipse cx="45" cy="10" rx="9" ry="13" fill="#a16207" transform="rotate(20,45,10)"/>
|
| 659 |
+
<circle cx="27" cy="25" r="18" fill="#d97706"/>
|
| 660 |
+
<ellipse cx="20" cy="22" rx="2.5" ry="3" fill="#1e293b"/>
|
| 661 |
+
<ellipse cx="34" cy="22" rx="2.5" ry="3" fill="#1e293b"/>
|
| 662 |
+
<circle cx="21" cy="21" r="0.8" fill="white"/>
|
| 663 |
+
<circle cx="35" cy="21" r="0.8" fill="white"/>
|
| 664 |
+
<ellipse cx="27" cy="29" rx="3.5" ry="2.5" fill="#1e293b"/>
|
| 665 |
+
<ellipse cx="27" cy="28" rx="2" ry="1.2" fill="#f472b6"/>
|
| 666 |
+
<path d="M22 33 Q27 38 32 33" stroke="#92400e" stroke-width="1.2" fill="none"/>
|
| 667 |
+
</svg></div>
|
| 668 |
+
<!-- Unicorn (top-right) -->
|
| 669 |
+
<div style="position:fixed;top:75px;right:25px;z-index:2;opacity:0.35;pointer-events:none;animation:kid-bob 4s ease-in-out 1s infinite;">
|
| 670 |
+
<svg width="50" height="55" viewBox="0 0 50 55">
|
| 671 |
+
<polygon points="25,0 22,15 28,15" fill="#fbbf24"/>
|
| 672 |
+
<circle cx="25" cy="25" r="14" fill="white" stroke="#e9d5ff" stroke-width="1"/>
|
| 673 |
+
<ellipse cx="19" cy="23" rx="2.5" ry="3" fill="#1e293b"/>
|
| 674 |
+
<ellipse cx="31" cy="23" rx="2.5" ry="3" fill="#1e293b"/>
|
| 675 |
+
<circle cx="20" cy="22" r="0.8" fill="white"/>
|
| 676 |
+
<circle cx="32" cy="22" r="0.8" fill="white"/>
|
| 677 |
+
<circle cx="14" cy="28" r="3" fill="#fecdd3" opacity="0.5"/>
|
| 678 |
+
<circle cx="36" cy="28" r="3" fill="#fecdd3" opacity="0.5"/>
|
| 679 |
+
<path d="M20 30 Q25 34 30 30" stroke="#ec4899" stroke-width="1.2" fill="none"/>
|
| 680 |
+
<path d="M11 16 Q5 10 7 18" stroke="#c4b5fd" stroke-width="2.5" fill="none" stroke-linecap="round"/>
|
| 681 |
+
<path d="M13 14 Q8 6 9 15" stroke="#fbcfe8" stroke-width="2" fill="none" stroke-linecap="round"/>
|
| 682 |
+
<path d="M39 16 Q45 10 43 18" stroke="#bfdbfe" stroke-width="2.5" fill="none" stroke-linecap="round"/>
|
| 683 |
+
<path d="M37 14 Q42 6 41 15" stroke="#fde68a" stroke-width="2" fill="none" stroke-linecap="round"/>
|
| 684 |
+
</svg></div>
|
| 685 |
+
<!-- Rocket (top-left past sidebar) -->
|
| 686 |
+
<div style="position:fixed;top:65px;left:260px;z-index:2;opacity:0.35;pointer-events:none;animation:kid-bob 3.2s ease-in-out 0.8s infinite;">
|
| 687 |
+
<svg width="35" height="55" viewBox="0 0 35 55">
|
| 688 |
+
<ellipse cx="17" cy="22" rx="10" ry="18" fill="#ef4444"/>
|
| 689 |
+
<ellipse cx="17" cy="22" rx="6.5" ry="12" fill="#fca5a5"/>
|
| 690 |
+
<circle cx="17" cy="19" r="4.5" fill="#dbeafe"/>
|
| 691 |
+
<circle cx="17" cy="19" r="2.5" fill="#3b82f6"/>
|
| 692 |
+
<polygon points="17,1 14,10 20,10" fill="#ef4444"/>
|
| 693 |
+
<polygon points="7,34 2,43 12,36" fill="#f97316"/>
|
| 694 |
+
<polygon points="27,34 32,43 22,36" fill="#f97316"/>
|
| 695 |
+
<ellipse cx="17" cy="40" rx="4" ry="3.5" fill="#fbbf24"/>
|
| 696 |
+
<ellipse cx="17" cy="44" rx="2.5" ry="5" fill="#fb923c" opacity="0.7"/>
|
| 697 |
+
<ellipse cx="17" cy="49" rx="1.5" ry="3.5" fill="#fbbf24" opacity="0.4"/>
|
| 698 |
+
</svg></div>
|
| 699 |
+
<!-- SVG Mascots -->
|
| 700 |
+
<div class="kid-mascot-row">
|
| 701 |
+
<div class="kid-mascot">
|
| 702 |
+
<div class="kid-speech">Ich schreibe!</div>
|
| 703 |
+
<svg width="70" height="75" viewBox="0 0 70 75">
|
| 704 |
+
<!-- Textino: cute blue robot -->
|
| 705 |
+
<!-- Antenna -->
|
| 706 |
+
<line x1="35" y1="8" x2="35" y2="0" stroke="#60a5fa" stroke-width="2.5" stroke-linecap="round"/>
|
| 707 |
+
<circle cx="35" cy="0" r="4" fill="#fbbf24"/>
|
| 708 |
+
<!-- Head -->
|
| 709 |
+
<rect x="10" y="8" width="50" height="32" rx="12" fill="#3b82f6"/>
|
| 710 |
+
<!-- Face screen -->
|
| 711 |
+
<rect x="15" y="13" width="40" height="22" rx="8" fill="#dbeafe"/>
|
| 712 |
+
<!-- Eyes -->
|
| 713 |
+
<circle cx="27" cy="23" r="5" fill="white"/>
|
| 714 |
+
<circle cx="43" cy="23" r="5" fill="white"/>
|
| 715 |
+
<circle cx="28" cy="23" r="3" fill="#1e293b"/>
|
| 716 |
+
<circle cx="44" cy="23" r="3" fill="#1e293b"/>
|
| 717 |
+
<!-- Eye shine -->
|
| 718 |
+
<circle cx="29" cy="22" r="1" fill="white"/>
|
| 719 |
+
<circle cx="45" cy="22" r="1" fill="white"/>
|
| 720 |
+
<!-- Smile -->
|
| 721 |
+
<path d="M25 29 Q35 35 45 29" stroke="#3b82f6" stroke-width="2" fill="none" stroke-linecap="round"/>
|
| 722 |
+
<!-- Body -->
|
| 723 |
+
<rect x="18" y="40" width="34" height="22" rx="8" fill="#60a5fa"/>
|
| 724 |
+
<!-- Arms -->
|
| 725 |
+
<rect x="5" y="42" width="13" height="8" rx="4" fill="#93c5fd"/>
|
| 726 |
+
<rect x="52" y="42" width="13" height="8" rx="4" fill="#93c5fd"/>
|
| 727 |
+
<!-- Pencil in right hand -->
|
| 728 |
+
<line x1="65" y1="42" x2="69" y2="32" stroke="#f97316" stroke-width="3" stroke-linecap="round"/>
|
| 729 |
+
<polygon points="69,32 67,28 71,28" fill="#fbbf24"/>
|
| 730 |
+
<!-- Belly button -->
|
| 731 |
+
<circle cx="35" cy="51" r="3" fill="#3b82f6"/>
|
| 732 |
+
<!-- Feet -->
|
| 733 |
+
<rect x="20" y="62" width="12" height="8" rx="4" fill="#3b82f6"/>
|
| 734 |
+
<rect x="38" y="62" width="12" height="8" rx="4" fill="#3b82f6"/>
|
| 735 |
+
</svg>
|
| 736 |
+
<div class="kid-mascot-name">Textino</div>
|
| 737 |
+
</div>
|
| 738 |
+
<div class="kid-mascot">
|
| 739 |
+
<div class="kid-speech">Ich male!</div>
|
| 740 |
+
<svg width="70" height="75" viewBox="0 0 70 75">
|
| 741 |
+
<!-- Pixela: cute pink artist character -->
|
| 742 |
+
<!-- Beret -->
|
| 743 |
+
<ellipse cx="35" cy="10" rx="22" ry="8" fill="#ec4899"/>
|
| 744 |
+
<circle cx="35" cy="5" r="5" fill="#f472b6"/>
|
| 745 |
+
<!-- Head -->
|
| 746 |
+
<circle cx="35" cy="25" r="20" fill="#fda4af"/>
|
| 747 |
+
<!-- Rosy cheeks -->
|
| 748 |
+
<circle cx="22" cy="29" r="5" fill="#fecdd3" opacity="0.7"/>
|
| 749 |
+
<circle cx="48" cy="29" r="5" fill="#fecdd3" opacity="0.7"/>
|
| 750 |
+
<!-- Eyes -->
|
| 751 |
+
<ellipse cx="27" cy="23" rx="4.5" ry="5" fill="white"/>
|
| 752 |
+
<ellipse cx="43" cy="23" rx="4.5" ry="5" fill="white"/>
|
| 753 |
+
<circle cx="28" cy="23" r="3" fill="#1e293b"/>
|
| 754 |
+
<circle cx="44" cy="23" r="3" fill="#1e293b"/>
|
| 755 |
+
<circle cx="29" cy="22" r="1" fill="white"/>
|
| 756 |
+
<circle cx="45" cy="22" r="1" fill="white"/>
|
| 757 |
+
<!-- Cat mouth -->
|
| 758 |
+
<path d="M30 31 L35 34 L40 31" stroke="#e11d48" stroke-width="1.5" fill="none" stroke-linecap="round"/>
|
| 759 |
+
<!-- Body -->
|
| 760 |
+
<rect x="20" y="45" width="30" height="18" rx="10" fill="#fb7185"/>
|
| 761 |
+
<!-- Arms -->
|
| 762 |
+
<rect x="7" y="47" width="13" height="7" rx="3.5" fill="#fda4af"/>
|
| 763 |
+
<rect x="50" y="47" width="13" height="7" rx="3.5" fill="#fda4af"/>
|
| 764 |
+
<!-- Paintbrush in right hand -->
|
| 765 |
+
<line x1="63" y1="47" x2="68" y2="35" stroke="#a16207" stroke-width="2.5" stroke-linecap="round"/>
|
| 766 |
+
<ellipse cx="68" cy="33" rx="4" ry="5" fill="#8b5cf6" transform="rotate(-15,68,33)"/>
|
| 767 |
+
<!-- Paint palette in left hand -->
|
| 768 |
+
<ellipse cx="4" cy="50" rx="8" ry="5" fill="#fde68a" transform="rotate(10,4,50)"/>
|
| 769 |
+
<circle cx="2" cy="48" r="2" fill="#ef4444"/>
|
| 770 |
+
<circle cx="6" cy="47" r="2" fill="#3b82f6"/>
|
| 771 |
+
<circle cx="4" cy="52" r="2" fill="#22c55e"/>
|
| 772 |
+
<!-- Feet -->
|
| 773 |
+
<ellipse cx="28" cy="67" rx="7" ry="5" fill="#ec4899"/>
|
| 774 |
+
<ellipse cx="42" cy="67" rx="7" ry="5" fill="#ec4899"/>
|
| 775 |
+
</svg>
|
| 776 |
+
<div class="kid-mascot-name">Pixela</div>
|
| 777 |
+
</div>
|
| 778 |
+
<div class="kid-mascot">
|
| 779 |
+
<div class="kid-speech">Ich spiele!</div>
|
| 780 |
+
<svg width="70" height="75" viewBox="0 0 70 75">
|
| 781 |
+
<!-- Soundo: cute orange music character -->
|
| 782 |
+
<!-- Headphones band -->
|
| 783 |
+
<path d="M12 25 Q12 5 35 5 Q58 5 58 25" stroke="#f97316" stroke-width="4" fill="none" stroke-linecap="round"/>
|
| 784 |
+
<!-- Headphone pads -->
|
| 785 |
+
<rect x="6" y="20" width="12" height="16" rx="6" fill="#f97316"/>
|
| 786 |
+
<rect x="52" y="20" width="12" height="16" rx="6" fill="#f97316"/>
|
| 787 |
+
<rect x="8" y="22" width="8" height="12" rx="4" fill="#fdba74"/>
|
| 788 |
+
<rect x="54" y="22" width="8" height="12" rx="4" fill="#fdba74"/>
|
| 789 |
+
<!-- Head -->
|
| 790 |
+
<circle cx="35" cy="28" r="18" fill="#fed7aa"/>
|
| 791 |
+
<!-- Eyes - happy closed -->
|
| 792 |
+
<path d="M24 26 Q28 22 32 26" stroke="#1e293b" stroke-width="2.5" fill="none" stroke-linecap="round"/>
|
| 793 |
+
<path d="M38 26 Q42 22 46 26" stroke="#1e293b" stroke-width="2.5" fill="none" stroke-linecap="round"/>
|
| 794 |
+
<!-- Big open smile -->
|
| 795 |
+
<path d="M25 33 Q35 42 45 33" stroke="#ea580c" stroke-width="2" fill="#fef3c7" stroke-linecap="round"/>
|
| 796 |
+
<!-- Body -->
|
| 797 |
+
<rect x="22" y="46" width="26" height="16" rx="8" fill="#fb923c"/>
|
| 798 |
+
<!-- Arms -->
|
| 799 |
+
<rect x="9" y="48" width="13" height="7" rx="3.5" fill="#fdba74"/>
|
| 800 |
+
<rect x="48" y="48" width="13" height="7" rx="3.5" fill="#fdba74"/>
|
| 801 |
+
<!-- Music notes floating -->
|
| 802 |
+
<text x="60" y="15" font-size="14" fill="#8b5cf6" opacity="0.8">\u266a</text>
|
| 803 |
+
<text x="4" y="12" font-size="11" fill="#ec4899" opacity="0.7">\u266b</text>
|
| 804 |
+
<text x="55" y="45" font-size="10" fill="#f97316" opacity="0.6">\u266a</text>
|
| 805 |
+
<!-- Feet -->
|
| 806 |
+
<ellipse cx="29" cy="66" rx="7" ry="5" fill="#f97316"/>
|
| 807 |
+
<ellipse cx="41" cy="66" rx="7" ry="5" fill="#f97316"/>
|
| 808 |
+
</svg>
|
| 809 |
+
<div class="kid-mascot-name">Soundo</div>
|
| 810 |
+
</div>
|
| 811 |
+
</div>
|
| 812 |
+
"""
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def _kid_stars(v: Optional[float]) -> str:
|
| 816 |
+
"""Convert a 0-1 score to 1-5 star rating HTML."""
|
| 817 |
+
if v is None:
|
| 818 |
+
return "\u2b50" * 0
|
| 819 |
+
n = max(1, min(5, round(v * 10))) # 0.1→1 star, 0.5→5 stars
|
| 820 |
+
return "\u2b50" * n + "\u2606" * (5 - n) # filled + empty
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
def _kid_emoji(v: Optional[float]) -> str:
|
| 824 |
+
"""Return emoji face based on coherence score."""
|
| 825 |
+
if v is None:
|
| 826 |
+
return "\U0001f914"
|
| 827 |
+
if v >= 0.45:
|
| 828 |
+
return "\U0001f929" # star-struck
|
| 829 |
+
if v >= 0.35:
|
| 830 |
+
return "\U0001f60a" # happy
|
| 831 |
+
if v >= 0.25:
|
| 832 |
+
return "\U0001f642" # slightly smiling
|
| 833 |
+
return "\U0001f61f" # worried
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
def _kid_verdict(v: Optional[float], lang: str = "de") -> str:
|
| 837 |
+
"""Return kid-friendly verdict text."""
|
| 838 |
+
if v is None:
|
| 839 |
+
return "Hmm..." if lang == "de" else "Hmm..."
|
| 840 |
+
if lang == "de":
|
| 841 |
+
if v >= 0.45:
|
| 842 |
+
return "Super! Alles passt perfekt zusammen! \U0001f389"
|
| 843 |
+
if v >= 0.35:
|
| 844 |
+
return "Gut gemacht! Das passt ziemlich gut! \U0001f44d"
|
| 845 |
+
if v >= 0.25:
|
| 846 |
+
return "Geht so \u2014 ein bisschen passt es! \U0001f914"
|
| 847 |
+
return "Hmm, das passt noch nicht so gut \U0001f61e"
|
| 848 |
+
else:
|
| 849 |
+
if v >= 0.45:
|
| 850 |
+
return "Amazing! Everything fits perfectly together! \U0001f389"
|
| 851 |
+
if v >= 0.35:
|
| 852 |
+
return "Well done! That fits pretty well! \U0001f44d"
|
| 853 |
+
if v >= 0.25:
|
| 854 |
+
return "So-so \u2014 it fits a little bit! \U0001f914"
|
| 855 |
+
return "Hmm, that doesn't quite fit yet \U0001f61e"
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
def kid_score_card(label: str, value: Optional[float], is_main: bool = False) -> str:
    """Kid-friendly score card with stars and emoji."""
    # Pick the accent class: the main card gets the rainbow bar; others are
    # colored by score band (None/0 are falsy and fall through to "low").
    if is_main:
        cls = "kid-sc-main"
    elif value and value >= 0.45:
        cls = "kid-sc-great"
    elif value and value >= 0.30:
        cls = "kid-sc-ok"
    else:
        cls = "kid-sc-low"
    star_html = _kid_stars(value)
    face = _kid_emoji(value) if is_main else ""
    shown_value = "\u2014" if value is None else f"{value:.3f}"
    face_html = f'<div class="kid-sc-emoji">{face}</div>' if face else ""
    pieces = [
        f'<div class="kid-sc {cls} kid-confetti">',
        f'<div class="kid-sc-lbl">{label}</div>',
        face_html,
        f'<div class="kid-sc-stars">{star_html}</div>',
        f'<div class="kid-sc-val">{shown_value}</div>',
        '</div>',
    ]
    return "".join(pieces)
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
# Kid-mode UI labels
|
| 879 |
+
# Kid-mode UI label strings keyed by language code ("de"/"en").
# Values may contain HTML (e.g. <b>...</b>) and literal \u/\U emoji escapes;
# keep both language dicts key-for-key in sync.
UI_LABELS_KID = {
    # German labels (default language).
    "de": {
        "hero_title": "Multimodale KI f\u00fcr Kids",
        "hero_sub": "Beschreibe eine Szene und die KI erzeugt <b>Text + Bild + Audio</b> dazu!",
        "config": "Einstellungen",
        "backend": "Wie soll es erstellt werden?",
        "planning": "Planungsmodus",
        "language": "Sprache",
        "examples": "Ideen zum Ausprobieren",
        "scene_placeholder": "Beschreibe deine Szene hier... z.B. 'Ein Einhorn fliegt \u00fcber einen Regenbogen' \U0001f308",
        "generate_btn": "\u2728 Los geht's!",
        "welcome_text": "Beschreibe eine Szene und klicke auf <b>\u2728 Los geht's!</b>",
        "welcome_hint": "oder w\u00e4hle eine Idee aus der Seitenleiste \U0001f449",
        "scores_label": "\U0001f3af Wie gut passt alles zusammen?",
        # Mascot progress labels shown while each modality is generated.
        "gen_text_label": "\U0001f916 Textino schreibt...",
        "gen_image_label": "\U0001f3a8 Pixela malt...",
        "gen_audio_label": "\U0001f3b5 Soundo spielt...",
        "translated_note": "Aus dem Deutschen \u00fcbersetzt",
        "original_label": "Original (Deutsch)",
    },
    # English labels (fallback language).
    "en": {
        "hero_title": "Multimodal AI for Kids",
        "hero_sub": "Describe a scene and the AI creates <b>text + image + audio</b> for it!",
        "config": "Settings",
        "backend": "How should it be created?",
        "planning": "Planning Mode",
        "language": "Language",
        "examples": "Ideas to Try",
        "scene_placeholder": "Describe your scene here... e.g., 'A unicorn flying over a rainbow' \U0001f308",
        "generate_btn": "\u2728 Let's Go!",
        "welcome_text": "Describe a scene and click <b>\u2728 Let's Go!</b>",
        "welcome_hint": "or pick an idea from the sidebar \U0001f449",
        "scores_label": "\U0001f3af How well does everything fit together?",
        # Mascot progress labels shown while each modality is generated.
        "gen_text_label": "\U0001f916 Textino writes...",
        "gen_image_label": "\U0001f3a8 Pixela paints...",
        "gen_audio_label": "\U0001f3b5 Soundo plays...",
        "translated_note": "Translated from German",
        "original_label": "Original (German)",
    },
}
|
| 919 |
+
|
| 920 |
# ---------------------------------------------------------------------------
|
| 921 |
# Planning prompt template (same as src/planner/prompts/unified.txt)
|
| 922 |
# ---------------------------------------------------------------------------
|
|
|
|
| 1020 |
return InferenceClient(token=token)
|
| 1021 |
|
| 1022 |
|
| 1023 |
+
# ---------------------------------------------------------------------------
|
| 1024 |
+
# Translation (German <-> English)
|
| 1025 |
+
# ---------------------------------------------------------------------------
|
| 1026 |
+
|
| 1027 |
+
# Direction key ("<src>-<dst>") -> Hugging Face model id used for translation
# via the Inference API.
TRANSLATION_MODELS = {
    "de-en": "Helsinki-NLP/opus-mt-de-en",
    "en-de": "Helsinki-NLP/opus-mt-en-de",
}
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
def translate(text: str, direction: str) -> str:
|
| 1034 |
+
"""Translate text using HF Inference API. direction: 'de-en' or 'en-de'."""
|
| 1035 |
+
if not text or not text.strip():
|
| 1036 |
+
return text
|
| 1037 |
+
model_id = TRANSLATION_MODELS[direction]
|
| 1038 |
+
client = get_inference_client()
|
| 1039 |
+
try:
|
| 1040 |
+
result = client.translation(text, model=model_id)
|
| 1041 |
+
if isinstance(result, str):
|
| 1042 |
+
return result
|
| 1043 |
+
# huggingface_hub returns a TranslationOutput object
|
| 1044 |
+
return result.translation_text if hasattr(result, "translation_text") else str(result)
|
| 1045 |
+
except Exception as e:
|
| 1046 |
+
logger.warning("Translation (%s) failed: %s — returning original", direction, e)
|
| 1047 |
+
return text
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
def translate_de_to_en(text: str) -> str:
|
| 1051 |
+
return translate(text, "de-en")
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
def translate_en_to_de(text: str) -> str:
|
| 1055 |
+
return translate(text, "en-de")
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
# ---------------------------------------------------------------------------
|
| 1059 |
+
# UI labels (i18n)
|
| 1060 |
+
# ---------------------------------------------------------------------------
|
| 1061 |
+
|
| 1062 |
+
UI_LABELS = {
|
| 1063 |
+
"en": {
|
| 1064 |
+
"hero_title": "Multimodal Coherence AI",
|
| 1065 |
+
"hero_sub": 'Generate semantically coherent <b>text + image + audio</b> bundles '
|
| 1066 |
+
'and evaluate cross-modal alignment with the <b>MSCI</b> metric.',
|
| 1067 |
+
"config": "Configuration",
|
| 1068 |
+
"backend": "Backend",
|
| 1069 |
+
"planning": "Planning Mode",
|
| 1070 |
+
"language": "Language",
|
| 1071 |
+
"examples": "Examples",
|
| 1072 |
+
"scene_placeholder": "Describe a scene... e.g., 'A peaceful forest at dawn with birdsong and morning mist'",
|
| 1073 |
+
"generate_btn": "Generate Bundle",
|
| 1074 |
+
"welcome_text": 'Enter a scene description and click <b>Generate Bundle</b>',
|
| 1075 |
+
"welcome_hint": "or pick an example from the sidebar",
|
| 1076 |
+
"scores_label": "Coherence Scores",
|
| 1077 |
+
"gen_text_label": "Generated Text",
|
| 1078 |
+
"gen_image_label": "Generated Image",
|
| 1079 |
+
"gen_audio_label": "Generated Audio",
|
| 1080 |
+
"translated_note": "Translated from German",
|
| 1081 |
+
"original_label": "Original (German)",
|
| 1082 |
+
},
|
| 1083 |
+
"de": {
|
| 1084 |
+
"hero_title": "Multimodale Koh\u00e4renz-KI",
|
| 1085 |
+
"hero_sub": 'Erzeuge semantisch koh\u00e4rente <b>Text + Bild + Audio</b> B\u00fcndel '
|
| 1086 |
+
'und bewerte die modale \u00dcbereinstimmung mit der <b>MSCI</b>-Metrik.',
|
| 1087 |
+
"config": "Einstellungen",
|
| 1088 |
+
"backend": "Verfahren",
|
| 1089 |
+
"planning": "Planungsmodus",
|
| 1090 |
+
"language": "Sprache",
|
| 1091 |
+
"examples": "Beispiele",
|
| 1092 |
+
"scene_placeholder": "Beschreibe eine Szene... z.B. 'Ein friedlicher Wald bei Sonnenaufgang mit Vogelgesang'",
|
| 1093 |
+
"generate_btn": "B\u00fcndel erzeugen",
|
| 1094 |
+
"welcome_text": 'Beschreibe eine Szene und klicke auf <b>B\u00fcndel erzeugen</b>',
|
| 1095 |
+
"welcome_hint": "oder w\u00e4hle ein Beispiel aus der Seitenleiste",
|
| 1096 |
+
"scores_label": "Koh\u00e4renz-Bewertung",
|
| 1097 |
+
"gen_text_label": "Erzeugter Text",
|
| 1098 |
+
"gen_image_label": "Erzeugtes Bild",
|
| 1099 |
+
"gen_audio_label": "Erzeugtes Audio",
|
| 1100 |
+
"translated_note": "Aus dem Deutschen \u00fcbersetzt",
|
| 1101 |
+
"original_label": "Original (Deutsch)",
|
| 1102 |
+
},
|
| 1103 |
+
}
|
| 1104 |
+
|
| 1105 |
+
|
| 1106 |
# ---------------------------------------------------------------------------
|
| 1107 |
# HF Inference API helpers
|
| 1108 |
# ---------------------------------------------------------------------------
|
| 1109 |
|
| 1110 |
+
# Primary models (may consume credits via Inference Providers)
|
| 1111 |
+
TEXT_GEN_MODELS_PAID = [
|
| 1112 |
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 1113 |
+
"meta-llama/Llama-3.2-3B-Instruct",
|
| 1114 |
+
]
|
| 1115 |
+
# Free serverless models (rate-limited but no credit cost)
|
| 1116 |
+
TEXT_GEN_MODELS_FREE = [
|
| 1117 |
"HuggingFaceH4/zephyr-7b-beta",
|
| 1118 |
"microsoft/Phi-3-mini-4k-instruct",
|
| 1119 |
+
"google/gemma-2-2b-it",
|
| 1120 |
]
|
| 1121 |
+
# Combined: try free first, then paid
|
| 1122 |
+
TEXT_GEN_MODELS = TEXT_GEN_MODELS_FREE + TEXT_GEN_MODELS_PAID
|
| 1123 |
+
|
| 1124 |
+
def _is_credit_error(e: Exception) -> bool:
|
| 1125 |
+
"""Check if an exception is a 402 Payment Required (credits depleted)."""
|
| 1126 |
+
msg = str(e).lower()
|
| 1127 |
+
return "402" in msg or "payment required" in msg or "credit" in msg
|
| 1128 |
+
|
| 1129 |
|
| 1130 |
def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
|
| 1131 |
+
"""Call HF Inference API chat completion, trying multiple models.
|
| 1132 |
+
|
| 1133 |
+
Tries free serverless models first, then paid models.
|
| 1134 |
+
Skips paid models entirely if a 402 credit error is detected.
|
| 1135 |
+
"""
|
| 1136 |
client = get_inference_client()
|
| 1137 |
last_error = None
|
| 1138 |
+
credits_depleted = False
|
| 1139 |
+
|
| 1140 |
for model_id in TEXT_GEN_MODELS:
|
| 1141 |
+
# Skip paid models if we already know credits are gone
|
| 1142 |
+
if credits_depleted and model_id in TEXT_GEN_MODELS_PAID:
|
| 1143 |
+
logger.info("Skipping paid model %s (credits depleted)", model_id)
|
| 1144 |
+
continue
|
| 1145 |
try:
|
| 1146 |
response = client.chat_completion(
|
| 1147 |
model=model_id,
|
|
|
|
| 1157 |
return text
|
| 1158 |
except Exception as e:
|
| 1159 |
last_error = e
|
| 1160 |
+
if _is_credit_error(e):
|
| 1161 |
+
credits_depleted = True
|
| 1162 |
+
logger.warning("Chat model %s: credits depleted (402)", model_id)
|
| 1163 |
+
else:
|
| 1164 |
+
logger.warning("Chat model %s failed: %s", model_id, e)
|
| 1165 |
continue
|
| 1166 |
+
|
| 1167 |
+
detail = "Credit balance is depleted." if credits_depleted else f"Last error: {last_error}"
|
| 1168 |
+
raise RuntimeError(f"All text models failed. {detail}")
|
| 1169 |
|
| 1170 |
|
| 1171 |
def _parse_plan_json(raw: str) -> Optional[Dict[str, Any]]:
|
|
|
|
| 1251 |
# Generation / retrieval functions
|
| 1252 |
# ---------------------------------------------------------------------------
|
| 1253 |
|
| 1254 |
+
# HF Inference API model IDs — free models first, paid fallback
|
| 1255 |
+
IMAGE_GEN_MODELS = [
|
| 1256 |
+
"black-forest-labs/FLUX.1-schnell", # Free serverless
|
| 1257 |
+
"stabilityai/stable-diffusion-xl-base-1.0", # May need credits
|
| 1258 |
+
]
|
| 1259 |
AUDIO_GEN_MODELS = [
|
| 1260 |
+
"facebook/musicgen-small", # Free serverless
|
| 1261 |
+
"cvssp/audioldm2", # May need credits
|
| 1262 |
]
|
| 1263 |
|
| 1264 |
def gen_text(prompt: str, mode: str) -> dict:
|
|
|
|
| 1316 |
|
| 1317 |
|
| 1318 |
def generate_image(prompt: str) -> dict:
|
| 1319 |
+
"""Generate image via HF Inference API, trying free models first. Falls back to retrieval."""
|
| 1320 |
client = get_inference_client()
|
| 1321 |
+
credits_depleted = False
|
| 1322 |
+
for model_id in IMAGE_GEN_MODELS:
|
| 1323 |
+
if credits_depleted and model_id == "stabilityai/stable-diffusion-xl-base-1.0":
|
| 1324 |
+
logger.info("Skipping paid image model (credits depleted)")
|
| 1325 |
+
continue
|
| 1326 |
+
try:
|
| 1327 |
+
image = client.text_to_image(prompt, model=model_id)
|
| 1328 |
+
tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
|
| 1329 |
+
image.save(tmp.name)
|
| 1330 |
+
model_name = model_id.split("/")[-1]
|
| 1331 |
+
return {
|
| 1332 |
+
"path": tmp.name, "backend": "generative",
|
| 1333 |
+
"model": model_name, "failed": False,
|
| 1334 |
+
}
|
| 1335 |
+
except Exception as e:
|
| 1336 |
+
if _is_credit_error(e):
|
| 1337 |
+
credits_depleted = True
|
| 1338 |
+
logger.warning("Image model %s: credits depleted (402)", model_id)
|
| 1339 |
+
else:
|
| 1340 |
+
logger.warning("Image gen with %s failed: %s", model_id, e)
|
| 1341 |
+
continue
|
| 1342 |
+
logger.warning("All image generation models failed — falling back to retrieval")
|
| 1343 |
+
result = retrieve_image(prompt)
|
| 1344 |
+
if credits_depleted:
|
| 1345 |
+
result["credit_error"] = True
|
| 1346 |
+
return result
|
| 1347 |
|
| 1348 |
|
| 1349 |
def generate_audio(prompt: str) -> dict:
|
| 1350 |
+
"""Generate audio via HF Inference API, trying free models first. Falls back to retrieval."""
|
| 1351 |
client = get_inference_client()
|
| 1352 |
+
credits_depleted = False
|
| 1353 |
for model_id in AUDIO_GEN_MODELS:
|
| 1354 |
+
if credits_depleted and model_id == "cvssp/audioldm2":
|
| 1355 |
+
logger.info("Skipping paid audio model (credits depleted)")
|
| 1356 |
+
continue
|
| 1357 |
try:
|
| 1358 |
audio_bytes = client.text_to_audio(prompt, model=model_id)
|
| 1359 |
suffix = ".flac" if "musicgen" in model_id else ".wav"
|
|
|
|
| 1362 |
tmp.write(audio_bytes)
|
| 1363 |
tmp.flush()
|
| 1364 |
else:
|
|
|
|
| 1365 |
tmp.write(bytes(audio_bytes))
|
| 1366 |
tmp.flush()
|
| 1367 |
model_name = model_id.split("/")[-1]
|
|
|
|
| 1370 |
"model": model_name, "failed": False,
|
| 1371 |
}
|
| 1372 |
except Exception as e:
|
| 1373 |
+
if _is_credit_error(e):
|
| 1374 |
+
credits_depleted = True
|
| 1375 |
+
logger.warning("Audio model %s: credits depleted (402)", model_id)
|
| 1376 |
+
else:
|
| 1377 |
+
logger.warning("Audio gen with %s failed: %s", model_id, e)
|
| 1378 |
continue
|
|
|
|
| 1379 |
logger.warning("All audio generation models failed — falling back to retrieval")
|
| 1380 |
+
result = retrieve_audio(prompt)
|
| 1381 |
+
if credits_depleted:
|
| 1382 |
+
result["credit_error"] = True
|
| 1383 |
+
return result
|
| 1384 |
|
| 1385 |
|
| 1386 |
def retrieve_image(prompt: str) -> dict:
|
|
|
|
| 1452 |
layout="wide",
|
| 1453 |
initial_sidebar_state="expanded",
|
| 1454 |
)
|
|
|
|
| 1455 |
|
| 1456 |
+
# Sidebar — settings first (needed for CSS choice)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1457 |
with st.sidebar:
|
| 1458 |
st.markdown("#### Configuration")
|
| 1459 |
|
| 1460 |
+
kid_mode = st.toggle("\U0001f476 Kid Mode", value=False)
|
| 1461 |
+
|
| 1462 |
+
lang = st.selectbox(
|
| 1463 |
+
"Language / Sprache",
|
| 1464 |
+
["en", "de"],
|
| 1465 |
+
format_func=lambda x: {"en": "English", "de": "Deutsch"}[x],
|
| 1466 |
+
)
|
| 1467 |
+
|
| 1468 |
+
# Select labels based on kid mode and language
|
| 1469 |
+
if kid_mode:
|
| 1470 |
+
L = UI_LABELS_KID.get(lang, UI_LABELS_KID["en"])
|
| 1471 |
+
else:
|
| 1472 |
+
L = UI_LABELS[lang]
|
| 1473 |
+
|
| 1474 |
backend = st.selectbox(
|
| 1475 |
+
L["backend"],
|
| 1476 |
["generative", "retrieval"],
|
| 1477 |
format_func=lambda x: {
|
| 1478 |
+
"generative": "Generative (FLUX/SDXL + MusicGen)",
|
| 1479 |
"retrieval": "Retrieval (CLIP + CLAP index)",
|
| 1480 |
}[x],
|
| 1481 |
)
|
| 1482 |
|
| 1483 |
mode = st.selectbox(
|
| 1484 |
+
L["planning"],
|
| 1485 |
["direct", "planner", "council", "extended_prompt"],
|
| 1486 |
format_func=lambda x: {
|
| 1487 |
"direct": "Direct",
|
|
|
|
| 1492 |
)
|
| 1493 |
|
| 1494 |
st.divider()
|
| 1495 |
+
st.markdown(f"#### {L['examples']}")
|
| 1496 |
+
|
| 1497 |
+
# Kid mode uses fun themed prompts; normal mode uses domain prompts
|
| 1498 |
+
if kid_mode:
|
| 1499 |
+
lang_examples = KID_EXAMPLE_PROMPTS.get(lang, KID_EXAMPLE_PROMPTS["en"])
|
| 1500 |
+
for dname, prompts in lang_examples.items():
|
| 1501 |
+
with st.expander(dname): # already has emoji in key
|
| 1502 |
+
for p in prompts:
|
| 1503 |
+
if st.button(p, key=f"ex_{hash(p)}", use_container_width=True):
|
| 1504 |
+
st.session_state["prompt_input"] = p
|
| 1505 |
+
else:
|
| 1506 |
+
lang_examples = EXAMPLE_PROMPTS.get(lang, EXAMPLE_PROMPTS["en"])
|
| 1507 |
+
domain_icons_de = {"natur": "\U0001f33f", "stadt": "\U0001f3d9\ufe0f", "wasser": "\U0001f30a", "gemischt": "\U0001f310"}
|
| 1508 |
+
for dname, prompts in lang_examples.items():
|
| 1509 |
+
icon = DOMAIN_ICONS.get(dname.lower(), domain_icons_de.get(dname.lower(), "\U0001f4cd"))
|
| 1510 |
+
with st.expander(f"{icon} {dname}"):
|
| 1511 |
+
for p in prompts:
|
| 1512 |
+
if st.button(p, key=f"ex_{hash(p)}", use_container_width=True):
|
| 1513 |
+
st.session_state["prompt_input"] = p
|
| 1514 |
|
| 1515 |
st.divider()
|
| 1516 |
mode_desc = {
|
|
|
|
| 1520 |
"extended_prompt": "Single LLM call with 3x token budget",
|
| 1521 |
}
|
| 1522 |
if backend == "generative":
|
| 1523 |
+
img_info = "FLUX.1-schnell / SDXL via HF API"
|
| 1524 |
+
aud_info = "MusicGen / AudioLDM2 via HF API"
|
| 1525 |
else:
|
| 1526 |
img_info = "CLIP retrieval (57 images)"
|
| 1527 |
aud_info = "CLAP retrieval (104 clips)"
|
| 1528 |
+
trans_info = "<br><b>Translation</b> opus-mt-de-en / en-de" if lang == "de" else ""
|
| 1529 |
st.markdown(
|
| 1530 |
f'<div class="sidebar-info">'
|
| 1531 |
f'<b>Text</b> HF Inference API<br>'
|
| 1532 |
f'<b>Planning</b> {mode_desc[mode]}<br>'
|
| 1533 |
f'<b>Image</b> {img_info}<br>'
|
| 1534 |
+
f'<b>Audio</b> {aud_info}{trans_info}<br><br>'
|
| 1535 |
f'<b>Metric</b> MSCI = 0.45 × s<sub>t,i</sub> + 0.45 × s<sub>t,a</sub><br><br>'
|
| 1536 |
f'<b>Models</b><br>'
|
| 1537 |
f'CLIP ViT-B/32 (coherence eval)<br>'
|
| 1538 |
f'CLAP HTSAT-unfused (coherence eval)'
|
| 1539 |
f'</div>', unsafe_allow_html=True)
|
| 1540 |
|
| 1541 |
+
# Apply CSS based on mode
|
| 1542 |
+
if kid_mode:
|
| 1543 |
+
st.markdown(KID_CSS, unsafe_allow_html=True) # kid theme (includes all needed overrides)
|
| 1544 |
+
else:
|
| 1545 |
+
st.markdown(CUSTOM_CSS, unsafe_allow_html=True) # professional dark theme
|
| 1546 |
+
|
| 1547 |
+
# Hero
|
| 1548 |
+
if kid_mode:
|
| 1549 |
+
st.markdown(
|
| 1550 |
+
f'<div class="kid-hero">'
|
| 1551 |
+
f'<div class="kid-hero-title">{L["hero_title"]}</div>'
|
| 1552 |
+
f'<div class="kid-hero-sub">{L["hero_sub"]}</div>'
|
| 1553 |
+
f'</div>', unsafe_allow_html=True)
|
| 1554 |
+
st.markdown(MASCOT_HTML, unsafe_allow_html=True)
|
| 1555 |
+
else:
|
| 1556 |
+
st.markdown(
|
| 1557 |
+
f'<div class="hero-wrap">'
|
| 1558 |
+
f'<div class="hero-title">{L["hero_title"]}</div>'
|
| 1559 |
+
f'<div class="hero-sub">{L["hero_sub"]}</div>'
|
| 1560 |
+
f'</div>', unsafe_allow_html=True)
|
| 1561 |
+
|
| 1562 |
# Prompt input
|
| 1563 |
default_prompt = st.session_state.get("prompt_input", "")
|
| 1564 |
prompt = st.text_area(
|
| 1565 |
"Scene", value=default_prompt, height=80,
|
| 1566 |
+
placeholder=L["scene_placeholder"],
|
| 1567 |
label_visibility="collapsed",
|
| 1568 |
)
|
| 1569 |
|
| 1570 |
# Button + chips
|
| 1571 |
bc1, bc2 = st.columns([1, 3])
|
| 1572 |
with bc1:
|
| 1573 |
+
go = st.button(L["generate_btn"], type="primary", use_container_width=True, disabled=not prompt.strip())
|
| 1574 |
with bc2:
|
| 1575 |
mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
|
| 1576 |
mcls = "chip-amber" if mode != "direct" else "chip-purple"
|
|
|
|
| 1579 |
bchip = '<span class="chip chip-pink"><span class="chip-dot chip-dot-pink"></span>Generative</span>'
|
| 1580 |
else:
|
| 1581 |
bchip = '<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
|
| 1582 |
+
lang_chip = ""
|
| 1583 |
+
if lang == "de":
|
| 1584 |
+
lang_chip = '<span class="chip chip-amber"><span class="chip-dot chip-dot-amber"></span>DE \u2192 EN</span>'
|
| 1585 |
+
kid_chip = ""
|
| 1586 |
+
if kid_mode:
|
| 1587 |
+
kid_chip = '<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>\U0001f476 Kid</span>'
|
| 1588 |
st.markdown(
|
| 1589 |
f'<div class="chip-row">'
|
| 1590 |
f'{bchip}'
|
| 1591 |
f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
|
| 1592 |
f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
|
| 1593 |
+
f'{lang_chip}{kid_chip}'
|
| 1594 |
f'</div>', unsafe_allow_html=True)
|
| 1595 |
|
| 1596 |
# Welcome state
|
| 1597 |
if not go and "last_result" not in st.session_state:
|
| 1598 |
+
if kid_mode:
|
| 1599 |
+
st.markdown(
|
| 1600 |
+
f'<div class="welcome" style="background:rgba(255,255,255,0.5);border-radius:24px;padding:3rem 2rem;">'
|
| 1601 |
+
f'<div class="welcome-icons">\U0001f916\u2728\U0001f3a8\u2728\U0001f3b5</div>'
|
| 1602 |
+
f'<div class="welcome-text" style="color:#334155;">{L["welcome_text"]}</div>'
|
| 1603 |
+
f'<div class="welcome-hint" style="color:#64748b;">{L["welcome_hint"]}</div>'
|
| 1604 |
+
f'</div>', unsafe_allow_html=True)
|
| 1605 |
+
else:
|
| 1606 |
+
st.markdown(
|
| 1607 |
+
f'<div class="welcome">'
|
| 1608 |
+
f'<div class="welcome-icons">\U0001f3a8 \U0001f5bc\ufe0f \U0001f50a</div>'
|
| 1609 |
+
f'<div class="welcome-text">{L["welcome_text"]}</div>'
|
| 1610 |
+
f'<div class="welcome-hint">{L["welcome_hint"]}</div>'
|
| 1611 |
+
f'</div>', unsafe_allow_html=True)
|
| 1612 |
return
|
| 1613 |
|
| 1614 |
if go and prompt.strip():
|
| 1615 |
+
st.session_state["last_result"] = run_pipeline(prompt.strip(), mode, backend, lang)
|
| 1616 |
+
st.session_state["last_result"]["kid_mode"] = kid_mode
|
| 1617 |
|
| 1618 |
if "last_result" in st.session_state:
|
| 1619 |
+
# Update kid_mode in case user toggled it after generation
|
| 1620 |
+
st.session_state["last_result"]["kid_mode"] = kid_mode
|
| 1621 |
show_results(st.session_state["last_result"])
|
| 1622 |
|
| 1623 |
|
|
|
|
| 1625 |
# Pipeline
|
| 1626 |
# ---------------------------------------------------------------------------
|
| 1627 |
|
| 1628 |
+
def run_pipeline(prompt: str, mode: str, backend: str = "generative", lang: str = "en") -> dict:
|
| 1629 |
+
R: dict = {"mode": mode, "backend": backend, "lang": lang, "original_prompt": prompt}
|
| 1630 |
t_all = time.time()
|
| 1631 |
|
| 1632 |
+
# 0) Translate German → English if needed
|
| 1633 |
+
en_prompt = prompt
|
| 1634 |
+
if lang == "de":
|
| 1635 |
+
with st.status("\u00dcbersetze ins Englische...", expanded=True) as s:
|
| 1636 |
+
t0 = time.time()
|
| 1637 |
+
en_prompt = translate_de_to_en(prompt)
|
| 1638 |
+
t_trans = time.time() - t0
|
| 1639 |
+
R["t_translate"] = t_trans
|
| 1640 |
+
R["en_prompt"] = en_prompt
|
| 1641 |
+
s.update(label=f"Translated ({t_trans:.1f}s): {en_prompt[:80]}...", state="complete")
|
| 1642 |
+
else:
|
| 1643 |
+
R["en_prompt"] = prompt
|
| 1644 |
+
|
| 1645 |
+
# 1) Text + Planning (always in English for CLIP/CLAP)
|
| 1646 |
plan_label = "Generating text..." if mode == "direct" else f"Planning ({mode}) + generating text..."
|
| 1647 |
with st.status(plan_label, expanded=True) as s:
|
| 1648 |
t0 = time.time()
|
| 1649 |
try:
|
| 1650 |
+
R["text"] = gen_text(en_prompt, mode)
|
| 1651 |
R["t_text"] = time.time() - t0
|
| 1652 |
has_plan = R["text"].get("plan") is not None
|
| 1653 |
lbl = f"Text ready ({R['t_text']:.1f}s)"
|
|
|
|
| 1656 |
s.update(label=lbl, state="complete")
|
| 1657 |
except Exception as e:
|
| 1658 |
s.update(label=f"Text failed: {e}", state="error")
|
| 1659 |
+
R["text"] = {"text": en_prompt, "image_prompt": en_prompt, "audio_prompt": en_prompt}
|
| 1660 |
R["t_text"] = time.time() - t0
|
| 1661 |
|
| 1662 |
+
# Translate generated text back to German for display
|
| 1663 |
+
if lang == "de":
|
| 1664 |
+
en_text = R["text"].get("text", "")
|
| 1665 |
+
R["text"]["text_en"] = en_text
|
| 1666 |
+
R["text"]["text"] = translate_en_to_de(en_text)
|
| 1667 |
+
|
| 1668 |
+
ip = R["text"].get("image_prompt", en_prompt)
|
| 1669 |
+
ap = R["text"].get("audio_prompt", en_prompt)
|
| 1670 |
|
| 1671 |
# 2) Image
|
| 1672 |
+
img_label = "Generating image..." if backend == "generative" else "Retrieving image..."
|
| 1673 |
with st.status(img_label, expanded=True) as s:
|
| 1674 |
t0 = time.time()
|
| 1675 |
try:
|
|
|
|
| 1720 |
R["audio"] = None
|
| 1721 |
R["t_aud"] = time.time() - t0
|
| 1722 |
|
| 1723 |
+
# 4) Coherence evaluation (always use English text for CLIP/CLAP)
|
| 1724 |
with st.status("Evaluating coherence...", expanded=True) as s:
|
| 1725 |
t0 = time.time()
|
| 1726 |
try:
|
| 1727 |
imgp = R.get("image", {}).get("path") if R.get("image") else None
|
| 1728 |
audp = R.get("audio", {}).get("path") if R.get("audio") else None
|
| 1729 |
+
eval_text = R["text"].get("text_en", R["text"]["text"]) # English for CLIP/CLAP
|
| 1730 |
+
R["coherence"] = eval_coherence(eval_text, imgp, audp)
|
| 1731 |
R["t_eval"] = time.time() - t0
|
| 1732 |
msci = R["coherence"].get("scores", {}).get("msci")
|
| 1733 |
s.update(label=f"MSCI = {msci:.4f} ({R['t_eval']:.1f}s)", state="complete")
|
|
|
|
| 1751 |
msci = sc.get("msci")
|
| 1752 |
st_i = sc.get("st_i")
|
| 1753 |
st_a = sc.get("st_a")
|
| 1754 |
+
lang = R.get("lang", "en")
|
| 1755 |
+
kid_mode = R.get("kid_mode", False)
|
| 1756 |
|
| 1757 |
+
if kid_mode:
|
| 1758 |
+
L = UI_LABELS_KID.get(lang, UI_LABELS_KID["en"])
|
| 1759 |
+
else:
|
| 1760 |
+
L = UI_LABELS.get(lang, UI_LABELS["en"])
|
| 1761 |
+
|
| 1762 |
+
# Warn banner CSS class
|
| 1763 |
+
warn_cls = "kid-warn" if kid_mode else "warn-banner"
|
| 1764 |
+
|
| 1765 |
+
# --- Score cards ---
|
| 1766 |
+
if kid_mode:
|
| 1767 |
+
st.markdown(f'<div class="kid-sec-label">{L["scores_label"]}</div>', unsafe_allow_html=True)
|
| 1768 |
+
# Kid verdict banner
|
| 1769 |
+
verdict = _kid_verdict(msci, lang)
|
| 1770 |
+
st.markdown(f'<div class="kid-verdict">{verdict}</div>', unsafe_allow_html=True)
|
| 1771 |
+
# Balloons for high coherence!
|
| 1772 |
+
if msci is not None and msci >= 0.40:
|
| 1773 |
+
st.balloons()
|
| 1774 |
+
cards = (
|
| 1775 |
+
kid_score_card("\U0001f3af Gesamt" if lang == "de" else "\U0001f3af Overall", msci, is_main=True)
|
| 1776 |
+
+ kid_score_card("\U0001f5bc\ufe0f Text \u2192 Bild" if lang == "de" else "\U0001f5bc\ufe0f Text \u2192 Image", st_i)
|
| 1777 |
+
+ kid_score_card("\U0001f50a Text \u2192 Ton" if lang == "de" else "\U0001f50a Text \u2192 Audio", st_a)
|
| 1778 |
+
+ kid_score_card("\U0001f31f Sterne" if lang == "de" else "\U0001f31f Stars", msci)
|
| 1779 |
+
)
|
| 1780 |
+
st.markdown(f'<div class="kid-scores">{cards}</div>', unsafe_allow_html=True)
|
| 1781 |
+
else:
|
| 1782 |
+
st.markdown(f'<div class="sec-label">{L["scores_label"]}</div>', unsafe_allow_html=True)
|
| 1783 |
+
cards = (
|
| 1784 |
+
score_card_html("MSCI (Overall)", msci)
|
| 1785 |
+
+ score_card_html("Text \u2192 Image", st_i)
|
| 1786 |
+
+ score_card_html("Text \u2192 Audio", st_a)
|
| 1787 |
+
+ score_card_html("Classification", msci, is_class=True)
|
| 1788 |
+
)
|
| 1789 |
+
st.markdown(f'<div class="scores-grid">{cards}</div>', unsafe_allow_html=True)
|
| 1790 |
|
| 1791 |
# Timing strip
|
| 1792 |
tt = R.get("t_total", 0)
|
| 1793 |
sep = '<span class="t-sep">|</span>'
|
| 1794 |
+
trans_timing = f'{sep}<span>Translate {R.get("t_translate", 0):.1f}s</span>' if lang == "de" else ""
|
| 1795 |
+
timing_cls = "kid-timing" if kid_mode else "timing"
|
| 1796 |
st.markdown(
|
| 1797 |
+
f'<div class="{timing_cls}">'
|
| 1798 |
f'<span class="t-total">Total {tt:.1f}s</span>{sep}'
|
| 1799 |
+
f'{trans_timing}'
|
| 1800 |
f'<span>Text {R.get("t_text", 0):.1f}s</span>{sep}'
|
| 1801 |
f'<span>Image {R.get("t_img", 0):.1f}s</span>{sep}'
|
| 1802 |
f'<span>Audio {R.get("t_aud", 0):.1f}s</span>{sep}'
|
|
|
|
| 1805 |
|
| 1806 |
st.markdown("---")
|
| 1807 |
|
| 1808 |
+
# CSS class helpers for kid/normal mode
|
| 1809 |
+
sec_cls = "kid-sec-label" if kid_mode else "sec-label"
|
| 1810 |
+
text_cls = "kid-text-card" if kid_mode else "text-card"
|
| 1811 |
+
|
| 1812 |
# Three columns: text | image | audio
|
| 1813 |
ct, ci, ca = st.columns([1.15, 1, 0.85])
|
| 1814 |
|
| 1815 |
with ct:
|
| 1816 |
+
st.markdown(f'<div class="{sec_cls}">{L["gen_text_label"]}</div>', unsafe_allow_html=True)
|
| 1817 |
txt = R.get("text", {}).get("text", "")
|
| 1818 |
text_err = R.get("text", {}).get("text_error")
|
| 1819 |
if text_err:
|
| 1820 |
+
if "credit" in text_err.lower() or "402" in text_err:
|
| 1821 |
+
st.markdown(
|
| 1822 |
+
f'<div class="{warn_cls}"><b>Text gen failed</b> — '
|
| 1823 |
+
f'HF credits depleted. Add credits at huggingface.co/settings/billing '
|
| 1824 |
+
f'or wait for free-tier reset.</div>',
|
| 1825 |
+
unsafe_allow_html=True)
|
| 1826 |
+
else:
|
| 1827 |
+
st.markdown(
|
| 1828 |
+
f'<div class="{warn_cls}"><b>Text gen failed</b> — {text_err}</div>',
|
| 1829 |
+
unsafe_allow_html=True)
|
| 1830 |
+
st.markdown(f'<div class="{text_cls}">{txt}</div>', unsafe_allow_html=True)
|
| 1831 |
+
# Show English original when in German mode
|
| 1832 |
+
if lang == "de":
|
| 1833 |
+
text_en = R.get("text", {}).get("text_en", "")
|
| 1834 |
+
if text_en and text_en != txt:
|
| 1835 |
+
with st.expander("English (original)"):
|
| 1836 |
+
st.markdown(f'<div class="{text_cls}" style="opacity:0.7">{text_en}</div>',
|
| 1837 |
+
unsafe_allow_html=True)
|
| 1838 |
|
| 1839 |
with ci:
|
| 1840 |
+
st.markdown(f'<div class="{sec_cls}">{L["gen_image_label"]}</div>', unsafe_allow_html=True)
|
| 1841 |
ii = R.get("image")
|
| 1842 |
if ii and ii.get("path"):
|
| 1843 |
ip = Path(ii["path"])
|
| 1844 |
backend = ii.get("backend", "unknown")
|
| 1845 |
|
| 1846 |
+
if backend == "retrieval" and R.get("backend") == "generative":
|
| 1847 |
+
if ii.get("credit_error"):
|
| 1848 |
+
st.markdown(
|
| 1849 |
+
f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
|
| 1850 |
+
f'using retrieval fallback.</div>',
|
| 1851 |
+
unsafe_allow_html=True)
|
| 1852 |
+
else:
|
| 1853 |
+
sim = ii.get("similarity", 0)
|
| 1854 |
+
st.markdown(
|
| 1855 |
+
f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
|
| 1856 |
+
f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
|
| 1857 |
+
unsafe_allow_html=True)
|
| 1858 |
|
| 1859 |
if ip.exists():
|
| 1860 |
st.image(str(ip), use_container_width=True)
|
| 1861 |
model = ii.get("model", "")
|
| 1862 |
if backend == "generative":
|
| 1863 |
+
cap = f"\U0001f3a8 Pixela hat gemalt mit **{model}**" if kid_mode and lang == "de" else (
|
| 1864 |
+
f"\U0001f3a8 Pixela painted with **{model}**" if kid_mode else f"Generated via **{model}**")
|
| 1865 |
+
st.caption(cap)
|
| 1866 |
else:
|
| 1867 |
sim = ii.get("similarity", 0)
|
| 1868 |
dom = ii.get("domain", "other")
|
| 1869 |
ic = DOMAIN_ICONS.get(dom, "\U0001f4cd")
|
| 1870 |
st.caption(f"{ic} {dom} \u00b7 sim **{sim:.3f}** \u00b7 Retrieved")
|
| 1871 |
else:
|
| 1872 |
+
st.info("No image." if not kid_mode else "\U0001f3a8 Kein Bild." if lang == "de" else "\U0001f3a8 No image.")
|
| 1873 |
|
| 1874 |
with ca:
|
| 1875 |
+
st.markdown(f'<div class="{sec_cls}">{L["gen_audio_label"]}</div>', unsafe_allow_html=True)
|
| 1876 |
ai = R.get("audio")
|
| 1877 |
if ai and ai.get("path"):
|
| 1878 |
ap = Path(ai["path"])
|
| 1879 |
backend = ai.get("backend", "unknown")
|
| 1880 |
|
| 1881 |
+
if backend == "retrieval" and R.get("backend") == "generative":
|
| 1882 |
+
if ai.get("credit_error"):
|
| 1883 |
+
st.markdown(
|
| 1884 |
+
f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
|
| 1885 |
+
f'using retrieval fallback.</div>',
|
| 1886 |
+
unsafe_allow_html=True)
|
| 1887 |
+
else:
|
| 1888 |
+
sim = ai.get("similarity", 0)
|
| 1889 |
+
st.markdown(
|
| 1890 |
+
f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
|
| 1891 |
+
f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
|
| 1892 |
+
unsafe_allow_html=True)
|
| 1893 |
|
| 1894 |
if ap.exists():
|
| 1895 |
st.audio(str(ap))
|
| 1896 |
model = ai.get("model", "")
|
| 1897 |
if backend == "generative":
|
| 1898 |
+
cap = f"\U0001f3b5 Soundo spielt mit **{model}**" if kid_mode and lang == "de" else (
|
| 1899 |
+
f"\U0001f3b5 Soundo plays with **{model}**" if kid_mode else f"Generated via **{model}**")
|
| 1900 |
+
st.caption(cap)
|
| 1901 |
else:
|
| 1902 |
sim = ai.get("similarity", 0)
|
| 1903 |
st.caption(f"sim **{sim:.3f}** \u00b7 Retrieved")
|
| 1904 |
else:
|
| 1905 |
+
st.info("No audio." if not kid_mode else "\U0001f3b5 Kein Audio." if lang == "de" else "\U0001f3b5 No audio.")
|
| 1906 |
|
| 1907 |
st.markdown("---")
|
| 1908 |
|
| 1909 |
+
# Expandable details (hidden in kid mode to keep it simple)
|
| 1910 |
+
if not kid_mode:
|
| 1911 |
+
with st.expander("Semantic Plan"):
|
| 1912 |
+
td = R.get("text", {})
|
| 1913 |
+
plan = td.get("plan")
|
| 1914 |
+
if plan:
|
| 1915 |
+
p1, p2 = st.columns(2)
|
| 1916 |
+
with p1:
|
| 1917 |
+
dash = "\u2014"
|
| 1918 |
+
dot = "\u00b7"
|
| 1919 |
+
scene = plan.get("scene_summary", dash)
|
| 1920 |
+
domain = plan.get("domain", dash)
|
| 1921 |
+
core = plan.get("core_semantics", {})
|
| 1922 |
+
setting = core.get("setting", dash)
|
| 1923 |
+
tod = core.get("time_of_day", dash)
|
| 1924 |
+
weather = core.get("weather", dash)
|
| 1925 |
+
subjects = ", ".join(core.get("main_subjects", []))
|
| 1926 |
+
st.markdown(f"**Scene** {scene}")
|
| 1927 |
+
st.markdown(f"**Domain** {domain}")
|
| 1928 |
+
st.markdown(f"**Setting** {setting} {dot} **Time** {tod} {dot} **Weather** {weather}")
|
| 1929 |
+
st.markdown(f"**Subjects** {subjects}")
|
| 1930 |
+
with p2:
|
| 1931 |
+
st.markdown("**Image prompt**")
|
| 1932 |
+
st.code(td.get("image_prompt", ""), language=None)
|
| 1933 |
+
st.markdown("**Audio prompt**")
|
| 1934 |
+
st.code(td.get("audio_prompt", ""), language=None)
|
|
|
|
|
|
|
|
|
|
| 1935 |
else:
|
| 1936 |
+
mode = R.get("mode", "direct")
|
| 1937 |
+
if mode == "direct":
|
| 1938 |
+
st.write("Direct mode \u2014 no semantic plan. Prompt used as-is for all modalities.")
|
| 1939 |
+
else:
|
| 1940 |
+
st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")
|
| 1941 |
+
|
| 1942 |
+
with st.expander("Generation Details"):
|
| 1943 |
+
r1, r2 = st.columns(2)
|
| 1944 |
+
with r1:
|
| 1945 |
+
ii = R.get("image")
|
| 1946 |
+
if ii:
|
| 1947 |
+
backend = ii.get("backend", "unknown")
|
| 1948 |
+
model = ii.get("model", "")
|
| 1949 |
+
if backend == "generative":
|
| 1950 |
+
st.markdown(f"**Image** generated via **{model}**")
|
| 1951 |
+
st.markdown(f"Prompt: *{R.get('text', {}).get('image_prompt', '')}*")
|
| 1952 |
+
elif ii.get("top_5"):
|
| 1953 |
+
st.markdown("**Image** (retrieval fallback)")
|
| 1954 |
+
bars = "".join(sim_bar_html(n, s) for n, s in ii["top_5"])
|
| 1955 |
+
st.markdown(bars, unsafe_allow_html=True)
|
| 1956 |
+
else:
|
| 1957 |
+
st.write("No image data.")
|
| 1958 |
+
with r2:
|
| 1959 |
+
ai = R.get("audio")
|
| 1960 |
+
if ai:
|
| 1961 |
+
backend = ai.get("backend", "unknown")
|
| 1962 |
+
model = ai.get("model", "")
|
| 1963 |
+
if backend == "generative":
|
| 1964 |
+
st.markdown(f"**Audio** generated via **{model}**")
|
| 1965 |
+
st.markdown(f"Prompt: *{R.get('text', {}).get('audio_prompt', '')}*")
|
| 1966 |
+
elif ai.get("top_5"):
|
| 1967 |
+
st.markdown("**Audio** (retrieval fallback)")
|
| 1968 |
+
bars = "".join(sim_bar_html(n, s) for n, s in ai["top_5"])
|
| 1969 |
+
st.markdown(bars, unsafe_allow_html=True)
|
| 1970 |
+
else:
|
| 1971 |
+
st.write("No audio data.")
|
| 1972 |
+
|
| 1973 |
+
with st.expander("Full Coherence Report"):
|
| 1974 |
+
if coh:
|
| 1975 |
+
st.json(coh)
|
| 1976 |
else:
|
| 1977 |
+
st.write("No data.")
|
| 1978 |
+
else:
|
| 1979 |
+
# Kid mode: simple "how it works" expander instead of technical details
|
| 1980 |
+
label_how = "\U0001f914 Wie funktioniert das?" if lang == "de" else "\U0001f914 How does it work?"
|
| 1981 |
+
with st.expander(label_how):
|
| 1982 |
+
if lang == "de":
|
| 1983 |
+
st.markdown(
|
| 1984 |
+
"1. **Textino** \U0001f916 liest deine Beschreibung und schreibt eine Geschichte\n"
|
| 1985 |
+
"2. **Pixela** \U0001f3a8 malt ein Bild, das zur Geschichte passt\n"
|
| 1986 |
+
"3. **Soundo** \U0001f3b5 erzeugt Ger\u00e4usche und Musik dazu\n"
|
| 1987 |
+
"4. Dann pr\u00fcfen wir, ob alles gut zusammenpasst! \u2b50"
|
| 1988 |
+
)
|
|
|
|
| 1989 |
else:
|
| 1990 |
+
st.markdown(
|
| 1991 |
+
"1. **Textino** \U0001f916 reads your description and writes a story\n"
|
| 1992 |
+
"2. **Pixela** \U0001f3a8 paints a picture that matches the story\n"
|
| 1993 |
+
"3. **Soundo** \U0001f3b5 creates sounds and music for it\n"
|
| 1994 |
+
"4. Then we check if everything fits together! \u2b50"
|
| 1995 |
+
)
|
|
|
|
| 1996 |
|
| 1997 |
|
| 1998 |
if __name__ == "__main__":
|
src/coherence/calibration.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Distribution Normalization for cMSCI.
|
| 3 |
+
|
| 4 |
+
Scores from different embedding spaces (CLIP vs CLAP) and different
|
| 5 |
+
pairwise channels (st_i, st_a, gram_volume) have different natural
|
| 6 |
+
distributions. Z-score normalization makes them comparable.
|
| 7 |
+
|
| 8 |
+
The ReferenceDistribution class fits mean/std from existing experiment
|
| 9 |
+
data and normalizes new scores to z-scores or percentile ranks.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Dict, List, Optional
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
from scipy import stats as sp_stats
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ReferenceDistribution:
    """
    Stores mean/std for a single score channel and provides normalization.

    Usage:
        ref = ReferenceDistribution()
        ref.fit(list_of_scores)
        z = ref.normalize(new_score)   # z-score
        p = ref.percentile(new_score)  # percentile rank [0, 1]

    Note:
        Instances rebuilt via ``from_dict``/``load`` carry only mean/std/n;
        ``percentile`` then falls back to 0.5 because the raw reference
        values are intentionally not serialized.
    """

    def __init__(self, name: str = ""):
        self.name = name  # channel label, e.g. "st_i"
        self.mean: float = 0.0
        self.std: float = 1.0
        self.n: int = 0
        # Sorted raw scores used by percentile(); populated only by fit().
        self._sorted_values: Optional[np.ndarray] = None

    def fit(self, scores: List[float]) -> None:
        """Fit mean/std (and the percentile reference) from observed scores.

        An empty input leaves the identity defaults (mean=0, std=1) so that
        ``normalize`` passes scores through unchanged instead of propagating
        NaN into every downstream z-score.
        """
        arr = np.array(scores, dtype=np.float64)
        self.n = len(arr)
        if self.n == 0:
            # np.mean([]) would emit a RuntimeWarning and yield NaN.
            self.mean = 0.0
            self.std = 1.0
            self._sorted_values = None
            return
        self.mean = float(np.mean(arr))
        # Sample std (ddof=1) needs >= 2 points; otherwise fall back to 1.0.
        self.std = float(np.std(arr, ddof=1)) if self.n > 1 else 1.0
        if self.std < 1e-10:
            # Constant channel: keep std=1 to avoid division blow-up.
            self.std = 1.0
        self._sorted_values = np.sort(arr)

    def normalize(self, score: float) -> float:
        """Z-score normalization: (score - mean) / std."""
        return float((score - self.mean) / self.std)

    def percentile(self, score: float) -> float:
        """
        Percentile rank of score within the reference distribution.

        Returns a value in [0, 1] where 0.5 = median of reference.
        Falls back to 0.5 when no reference values are available
        (unfitted instance, empty fit, or one restored from a dict/file).
        """
        if self._sorted_values is None or len(self._sorted_values) == 0:
            return 0.5
        rank = np.searchsorted(self._sorted_values, score, side="right")
        return float(rank / len(self._sorted_values))

    def to_dict(self) -> Dict:
        """Serialize summary statistics (raw values are dropped on purpose)."""
        return {
            "name": self.name,
            "mean": self.mean,
            "std": self.std,
            "n": self.n,
        }

    @classmethod
    def from_dict(cls, d: Dict) -> "ReferenceDistribution":
        """Rebuild from to_dict() output; percentile() will return 0.5."""
        obj = cls(name=d.get("name", ""))
        obj.mean = d["mean"]
        obj.std = d["std"]
        obj.n = d.get("n", 0)
        return obj

    def save(self, path: str) -> None:
        """Write summary statistics as JSON to *path*."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str) -> "ReferenceDistribution":
        """Read summary statistics from a JSON file written by save()."""
        with open(path) as f:
            return cls.from_dict(json.load(f))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class CalibrationStore:
|
| 95 |
+
"""
|
| 96 |
+
Collection of ReferenceDistributions for all score channels.
|
| 97 |
+
|
| 98 |
+
Provides save/load for the full calibration state.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(self):
|
| 102 |
+
self.distributions: Dict[str, ReferenceDistribution] = {}
|
| 103 |
+
|
| 104 |
+
def add(self, name: str, scores: List[float]) -> ReferenceDistribution:
|
| 105 |
+
ref = ReferenceDistribution(name=name)
|
| 106 |
+
ref.fit(scores)
|
| 107 |
+
self.distributions[name] = ref
|
| 108 |
+
logger.info(
|
| 109 |
+
"Calibration[%s]: mean=%.4f, std=%.4f, n=%d",
|
| 110 |
+
name, ref.mean, ref.std, ref.n,
|
| 111 |
+
)
|
| 112 |
+
return ref
|
| 113 |
+
|
| 114 |
+
def normalize(self, name: str, score: float) -> float:
|
| 115 |
+
if name not in self.distributions:
|
| 116 |
+
return score
|
| 117 |
+
return self.distributions[name].normalize(score)
|
| 118 |
+
|
| 119 |
+
def percentile(self, name: str, score: float) -> float:
|
| 120 |
+
if name not in self.distributions:
|
| 121 |
+
return 0.5
|
| 122 |
+
return self.distributions[name].percentile(score)
|
| 123 |
+
|
| 124 |
+
def save(self, path: str) -> None:
|
| 125 |
+
data = {name: ref.to_dict() for name, ref in self.distributions.items()}
|
| 126 |
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
| 127 |
+
with open(path, "w") as f:
|
| 128 |
+
json.dump(data, f, indent=2)
|
| 129 |
+
logger.info("Calibration saved to %s", path)
|
| 130 |
+
|
| 131 |
+
@classmethod
|
| 132 |
+
def load(cls, path: str) -> "CalibrationStore":
|
| 133 |
+
store = cls()
|
| 134 |
+
with open(path) as f:
|
| 135 |
+
data = json.load(f)
|
| 136 |
+
for name, d in data.items():
|
| 137 |
+
store.distributions[name] = ReferenceDistribution.from_dict(d)
|
| 138 |
+
logger.info("Calibration loaded from %s (%d channels)", path, len(store.distributions))
|
| 139 |
+
return store
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def has_channel(store: CalibrationStore, name: str) -> bool:
    """Return True when *store* holds a fitted distribution named *name*."""
    known = store.distributions
    return name in known
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def extend_calibration_with_exmcr(
    store: CalibrationStore,
    gram_coh_ia_scores: List[float],
    gram_coh_tia_scores: Optional[List[float]] = None,
) -> CalibrationStore:
    """
    Register ExMCR-derived channels on an existing calibration store.

    Args:
        store: Existing CalibrationStore to extend.
        gram_coh_ia_scores: Gram coherence of (image_clip, ExMCR(audio_clap)) pairs.
        gram_coh_tia_scores: Optional 3-way gram coherence of (text, image, ExMCR(audio)).

    Returns:
        The same store instance, modified in place.
    """
    channel_data = (
        ("gram_coh_ia_exmcr", gram_coh_ia_scores),
        ("gram_coh_tia", gram_coh_tia_scores),
    )
    for channel, scores in channel_data:
        # Empty/None score lists are skipped rather than fitted.
        if scores:
            store.add(channel, scores)
    return store
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def extend_calibration_with_uncertainty(
    store: CalibrationStore,
    uncertainty_ti_scores: List[float],
    uncertainty_ta_scores: Optional[List[float]] = None,
) -> CalibrationStore:
    """
    Register ProbVLM uncertainty channels on an existing calibration store.

    Args:
        store: Existing CalibrationStore to extend.
        uncertainty_ti_scores: Per-sample mean uncertainty for text-image (CLIP adapter).
        uncertainty_ta_scores: Per-sample mean uncertainty for text-audio (CLAP adapter).

    Returns:
        The same store instance, modified in place.
    """
    per_channel = (
        ("uncertainty_ti", uncertainty_ti_scores),
        ("uncertainty_ta", uncertainty_ta_scores),
    )
    for channel, scores in per_channel:
        if scores:
            store.add(channel, scores)
    # Combined channel needs both modalities; pairs beyond the shorter
    # list are dropped by zip.
    if uncertainty_ti_scores and uncertainty_ta_scores:
        paired = zip(uncertainty_ti_scores, uncertainty_ta_scores)
        store.add("uncertainty_mean", [0.5 * (ti + ta) for ti, ta in paired])
    return store
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def build_reference_distributions(
    rq1_results_path: str,
) -> CalibrationStore:
    """
    Build reference distributions from existing RQ1 baseline results.

    Only the baseline condition (matched image + audio) contributes, so the
    fitted distributions describe coherent pairings for st_i, st_a and msci.

    Args:
        rq1_results_path: Path to rq1_results.json

    Returns:
        CalibrationStore with fitted distributions for st_i, st_a, msci
        (plus derived gram_coh_ti / gram_coh_ta channels).
    """
    with open(rq1_results_path) as fh:
        payload = json.load(fh)

    # Collect per-channel scores from baseline records only.
    channels = {"st_i": [], "st_a": [], "msci": []}
    for record in payload["results"]:
        if record.get("condition") != "baseline":
            continue
        for key, bucket in channels.items():
            value = record.get(key)
            if value is not None:
                bucket.append(value)

    store = CalibrationStore()
    for key, bucket in channels.items():
        if bucket:
            store.add(key, bucket)

    # GRAM coherence distributions (1 - gram_volume) for gram calibration mode
    # gram_volume = sqrt(1 - cos^2), so gram_coherence = 1 - sqrt(1 - cos^2)
    for src_key, dst_key in (("st_i", "gram_coh_ti"), ("st_a", "gram_coh_ta")):
        bucket = channels[src_key]
        if bucket:
            store.add(dst_key, [1.0 - np.sqrt(max(0, 1 - s ** 2)) for s in bucket])

    return store
|
src/coherence/cmsci_engine.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Calibrated Multimodal Semantic Coherence Index (cMSCI) Engine.
|
| 3 |
+
|
| 4 |
+
Replaces fixed weighted averaging (MSCI) with a principled pipeline:
|
| 5 |
+
1. Gramian Volume: geometric coherence of embedding vectors
|
| 6 |
+
2. Distribution Normalization: z-score calibration per channel
|
| 7 |
+
3. Contrastive Margin: comparison against hard negatives
|
| 8 |
+
4. Cross-Space Alignment: Ex-MCR projects CLAP→CLIP for 3-way GRAM
|
| 9 |
+
5. Probabilistic Uncertainty: MC sampling for confidence intervals
|
| 10 |
+
|
| 11 |
+
The CalibratedCoherenceEngine runs alongside CoherenceEngine (not replacing
|
| 12 |
+
it) and returns both legacy MSCI and new cMSCI scores for comparison.
|
| 13 |
+
|
| 14 |
+
Variant progression:
|
| 15 |
+
A: MSCI (baseline, weighted cosine average)
|
| 16 |
+
B: GRAM-only (geometric, no calibration)
|
| 17 |
+
C: GRAM + z-norm (normalized geometric)
|
| 18 |
+
D: GRAM + z-norm + contrastive (calibrated geometric)
|
| 19 |
+
E: GRAM + z-norm + contrastive + Ex-MCR (3-way calibrated)
|
| 20 |
+
F: Full cMSCI (probabilistic + calibrated + 3-way)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import logging
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, List, Optional
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
from src.coherence.gram_volume import (
|
| 32 |
+
gram_volume_2d,
|
| 33 |
+
gram_volume_3d,
|
| 34 |
+
gram_volume_nd,
|
| 35 |
+
normalized_gram_coherence,
|
| 36 |
+
)
|
| 37 |
+
from src.config.settings import (
|
| 38 |
+
CMSCI_MARGIN_ALPHA,
|
| 39 |
+
CMSCI_CHANNEL_WEIGHT_TI,
|
| 40 |
+
CMSCI_CALIBRATION_MODE,
|
| 41 |
+
CMSCI_W_3D,
|
| 42 |
+
CMSCI_GAMMA,
|
| 43 |
+
)
|
| 44 |
+
from src.embeddings.aligned_embeddings import AlignedEmbedder
|
| 45 |
+
from src.embeddings.similarity import cosine_similarity
|
| 46 |
+
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class CalibratedCoherenceEngine:
|
| 51 |
+
"""
|
| 52 |
+
Uncertainty-aware, geometrically-grounded tri-modal coherence engine.
|
| 53 |
+
|
| 54 |
+
Computes cMSCI alongside legacy MSCI for comparison.
|
| 55 |
+
|
| 56 |
+
Usage:
|
| 57 |
+
engine = CalibratedCoherenceEngine()
|
| 58 |
+
result = engine.evaluate("A beach at sunset", "beach.jpg", "waves.wav")
|
| 59 |
+
print(result["cmsci"]) # Calibrated score
|
| 60 |
+
print(result["msci"]) # Legacy score (for comparison)
|
| 61 |
+
print(result["variant_scores"]) # Scores for each variant A-F
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
target_dim: int = 512,
|
| 67 |
+
calibration_path: Optional[str] = None,
|
| 68 |
+
exmcr_weights_path: Optional[str] = None,
|
| 69 |
+
bridge_path: Optional[str] = None,
|
| 70 |
+
prob_clip_adapter_path: Optional[str] = None,
|
| 71 |
+
prob_clap_adapter_path: Optional[str] = None,
|
| 72 |
+
negative_bank_enabled: bool = True,
|
| 73 |
+
):
|
| 74 |
+
self.embedder = AlignedEmbedder(target_dim=target_dim)
|
| 75 |
+
|
| 76 |
+
# Calibration store (Phase 2)
|
| 77 |
+
self._calibration = None
|
| 78 |
+
if calibration_path and Path(calibration_path).exists():
|
| 79 |
+
from src.coherence.calibration import CalibrationStore
|
| 80 |
+
self._calibration = CalibrationStore.load(calibration_path)
|
| 81 |
+
logger.info("Calibration loaded from %s", calibration_path)
|
| 82 |
+
|
| 83 |
+
# Negative bank (Phase 2)
|
| 84 |
+
self._negative_bank = None
|
| 85 |
+
if negative_bank_enabled:
|
| 86 |
+
try:
|
| 87 |
+
from src.coherence.negative_bank import NegativeBank
|
| 88 |
+
self._negative_bank = NegativeBank()
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.warning("Negative bank disabled: %s", e)
|
| 91 |
+
|
| 92 |
+
# Ex-MCR projector (Phase 3 — projects CLAP into CLIP space)
|
| 93 |
+
self._exmcr = None
|
| 94 |
+
if exmcr_weights_path:
|
| 95 |
+
from src.embeddings.space_alignment import ExMCRProjector
|
| 96 |
+
self._exmcr = ExMCRProjector(weights_path=exmcr_weights_path)
|
| 97 |
+
if self._exmcr.is_identity:
|
| 98 |
+
logger.info("Ex-MCR in identity mode (no weights)")
|
| 99 |
+
else:
|
| 100 |
+
logger.info("Ex-MCR projector active")
|
| 101 |
+
|
| 102 |
+
# Cross-Space Bridge (projects CLIP image + CLAP audio → shared 256-d)
|
| 103 |
+
self._bridge = None
|
| 104 |
+
if bridge_path and Path(bridge_path).exists():
|
| 105 |
+
from src.embeddings.cross_space_bridge import CrossSpaceBridge
|
| 106 |
+
self._bridge = CrossSpaceBridge.load(bridge_path)
|
| 107 |
+
logger.info("CrossSpaceBridge loaded from %s", bridge_path)
|
| 108 |
+
|
| 109 |
+
# Probabilistic adapters (Phase 4)
|
| 110 |
+
self._prob_clip = None
|
| 111 |
+
self._prob_clap = None
|
| 112 |
+
if prob_clip_adapter_path and Path(prob_clip_adapter_path).exists():
|
| 113 |
+
from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
|
| 114 |
+
self._prob_clip = ProbabilisticAdapter.load(prob_clip_adapter_path)
|
| 115 |
+
logger.info("CLIP probabilistic adapter loaded")
|
| 116 |
+
if prob_clap_adapter_path and Path(prob_clap_adapter_path).exists():
|
| 117 |
+
from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
|
| 118 |
+
self._prob_clap = ProbabilisticAdapter.load(prob_clap_adapter_path)
|
| 119 |
+
logger.info("CLAP probabilistic adapter loaded")
|
| 120 |
+
|
| 121 |
+
def evaluate(
|
| 122 |
+
self,
|
| 123 |
+
text: str,
|
| 124 |
+
image_path: Optional[str] = None,
|
| 125 |
+
audio_path: Optional[str] = None,
|
| 126 |
+
domain: str = "",
|
| 127 |
+
n_mc_samples: int = 100,
|
| 128 |
+
) -> Dict[str, Any]:
|
| 129 |
+
"""
|
| 130 |
+
Evaluate multimodal coherence with full cMSCI pipeline.
|
| 131 |
+
|
| 132 |
+
Returns both legacy MSCI and cMSCI scores along with all
|
| 133 |
+
intermediate computations for ablation analysis.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
text: Text prompt.
|
| 137 |
+
image_path: Path to image file.
|
| 138 |
+
audio_path: Path to audio file.
|
| 139 |
+
domain: Domain hint for negative bank (e.g., "nature").
|
| 140 |
+
n_mc_samples: Number of MC samples for uncertainty.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Dict with keys:
|
| 144 |
+
msci: Legacy MSCI score (weighted cosine average)
|
| 145 |
+
cmsci: Calibrated cMSCI score
|
| 146 |
+
scores: Raw pairwise scores (st_i, st_a, si_a)
|
| 147 |
+
gram: Gramian volume scores
|
| 148 |
+
calibration: Z-normalized scores
|
| 149 |
+
contrastive: Contrastive margin results
|
| 150 |
+
uncertainty: MC sampling uncertainty (if adapters loaded)
|
| 151 |
+
variant_scores: Scores for each variant A-F
|
| 152 |
+
"""
|
| 153 |
+
# ── Embed ──────────────────────────────────────────────
|
| 154 |
+
emb_text_clip = self.embedder.embed_text(text)
|
| 155 |
+
emb_text_clap = self.embedder.embed_text_for_audio(text) if audio_path else None
|
| 156 |
+
emb_image = self.embedder.embed_image(image_path) if image_path else None
|
| 157 |
+
emb_audio = self.embedder.embed_audio(audio_path) if audio_path else None
|
| 158 |
+
|
| 159 |
+
# ── Legacy MSCI (Variant A) ────────────────────────────
|
| 160 |
+
st_i = None
|
| 161 |
+
st_a = None
|
| 162 |
+
si_a = None
|
| 163 |
+
|
| 164 |
+
if emb_text_clip is not None and emb_image is not None:
|
| 165 |
+
st_i = float(round(cosine_similarity(emb_text_clip, emb_image), 4))
|
| 166 |
+
if emb_text_clap is not None and emb_audio is not None:
|
| 167 |
+
st_a = float(round(cosine_similarity(emb_text_clap, emb_audio), 4))
|
| 168 |
+
|
| 169 |
+
available = {}
|
| 170 |
+
if st_i is not None:
|
| 171 |
+
available["st_i"] = st_i
|
| 172 |
+
if st_a is not None:
|
| 173 |
+
available["st_a"] = st_a
|
| 174 |
+
|
| 175 |
+
weights = {"st_i": 0.45, "st_a": 0.45, "si_a": 0.10}
|
| 176 |
+
if len(available) >= 2:
|
| 177 |
+
total_w = sum(weights[k] for k in available if k in weights)
|
| 178 |
+
msci = sum(available[k] * weights[k] for k in available if k in weights) / max(total_w, 1e-6)
|
| 179 |
+
elif len(available) == 1:
|
| 180 |
+
msci = list(available.values())[0]
|
| 181 |
+
else:
|
| 182 |
+
msci = None
|
| 183 |
+
|
| 184 |
+
variant_a = msci
|
| 185 |
+
|
| 186 |
+
# ── Gramian Volume (Variant B) ─────────────────────────
|
| 187 |
+
gram_ti = None
|
| 188 |
+
gram_ta = None
|
| 189 |
+
gram_tia = None
|
| 190 |
+
gram_coherence_2way = None
|
| 191 |
+
|
| 192 |
+
if emb_text_clip is not None and emb_image is not None:
|
| 193 |
+
gram_ti = gram_volume_2d(emb_text_clip, emb_image)
|
| 194 |
+
|
| 195 |
+
if emb_text_clap is not None and emb_audio is not None:
|
| 196 |
+
gram_ta = gram_volume_2d(emb_text_clap, emb_audio)
|
| 197 |
+
|
| 198 |
+
# 2-way GRAM coherence (average of text-image and text-audio gram coherences)
|
| 199 |
+
gram_coherences = []
|
| 200 |
+
if gram_ti is not None:
|
| 201 |
+
gram_coherences.append(normalized_gram_coherence(gram_ti))
|
| 202 |
+
if gram_ta is not None:
|
| 203 |
+
gram_coherences.append(normalized_gram_coherence(gram_ta))
|
| 204 |
+
|
| 205 |
+
if gram_coherences:
|
| 206 |
+
gram_coherence_2way = float(np.mean(gram_coherences))
|
| 207 |
+
|
| 208 |
+
variant_b = gram_coherence_2way
|
| 209 |
+
|
| 210 |
+
# ── Z-Score Normalization (Variant C) ──────────────────
|
| 211 |
+
z_st_i = None
|
| 212 |
+
z_st_a = None
|
| 213 |
+
z_gram_ti = None
|
| 214 |
+
z_gram_ta = None
|
| 215 |
+
variant_c = variant_b # default to B if no calibration
|
| 216 |
+
|
| 217 |
+
# Channel weight from settings (optimized via LOO-CV)
|
| 218 |
+
w_ti = CMSCI_CHANNEL_WEIGHT_TI
|
| 219 |
+
cal_mode = CMSCI_CALIBRATION_MODE
|
| 220 |
+
|
| 221 |
+
if self._calibration is not None:
|
| 222 |
+
if st_i is not None:
|
| 223 |
+
z_st_i = self._calibration.normalize("st_i", st_i)
|
| 224 |
+
if st_a is not None:
|
| 225 |
+
z_st_a = self._calibration.normalize("st_a", st_a)
|
| 226 |
+
|
| 227 |
+
# GRAM coherence z-scores (for gram calibration mode)
|
| 228 |
+
if gram_ti is not None:
|
| 229 |
+
gram_coh_ti = normalized_gram_coherence(gram_ti)
|
| 230 |
+
z_gram_ti = self._calibration.normalize("gram_coh_ti", gram_coh_ti)
|
| 231 |
+
if gram_ta is not None:
|
| 232 |
+
gram_coh_ta = normalized_gram_coherence(gram_ta)
|
| 233 |
+
z_gram_ta = self._calibration.normalize("gram_coh_ta", gram_coh_ta)
|
| 234 |
+
|
| 235 |
+
# Select calibration mode: cosine z-scores or gram coherence z-scores
|
| 236 |
+
if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
|
| 237 |
+
z_mean = w_ti * z_gram_ti + (1.0 - w_ti) * z_gram_ta
|
| 238 |
+
else:
|
| 239 |
+
# Cosine mode (original behavior) with weighted channels
|
| 240 |
+
z_coherences = []
|
| 241 |
+
z_weights = []
|
| 242 |
+
if z_st_i is not None:
|
| 243 |
+
z_coherences.append(z_st_i)
|
| 244 |
+
z_weights.append(w_ti)
|
| 245 |
+
if z_st_a is not None:
|
| 246 |
+
z_coherences.append(z_st_a)
|
| 247 |
+
z_weights.append(1.0 - w_ti)
|
| 248 |
+
|
| 249 |
+
if z_coherences:
|
| 250 |
+
total_w = sum(z_weights)
|
| 251 |
+
z_mean = sum(z * wt for z, wt in zip(z_coherences, z_weights)) / total_w
|
| 252 |
+
else:
|
| 253 |
+
z_mean = None
|
| 254 |
+
|
| 255 |
+
if z_mean is not None:
|
| 256 |
+
# Map z-scores back to [0,1] via sigmoid for interpretability
|
| 257 |
+
variant_c = float(1.0 / (1.0 + np.exp(-z_mean)))
|
| 258 |
+
|
| 259 |
+
# ── Contrastive Margin (Variant D) ─────────────────────
|
| 260 |
+
contrastive_result = None
|
| 261 |
+
variant_d = variant_c # default to C if no negatives
|
| 262 |
+
margin_alpha = CMSCI_MARGIN_ALPHA
|
| 263 |
+
|
| 264 |
+
if self._negative_bank is not None and gram_coherence_2way is not None:
|
| 265 |
+
matched_volume = float(np.mean([v for v in [gram_ti, gram_ta] if v is not None]))
|
| 266 |
+
contrastive_result = self._negative_bank.compute_contrastive_margin(
|
| 267 |
+
matched_volume=matched_volume,
|
| 268 |
+
text_clip_emb=emb_text_clip,
|
| 269 |
+
image_emb=emb_image,
|
| 270 |
+
text_clap_emb=emb_text_clap,
|
| 271 |
+
audio_emb=emb_audio,
|
| 272 |
+
domain=domain,
|
| 273 |
+
k=5,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
if contrastive_result["n_negatives"] > 0:
|
| 277 |
+
# cMSCI_D = sigmoid(z_mean + alpha * margin)
|
| 278 |
+
# alpha amplifies the contrastive signal at the sigmoid operating point
|
| 279 |
+
margin = contrastive_result["margin"]
|
| 280 |
+
|
| 281 |
+
# Use the same calibration mode and weighting as Variant C
|
| 282 |
+
if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
|
| 283 |
+
z_mean_d = w_ti * z_gram_ti + (1.0 - w_ti) * z_gram_ta
|
| 284 |
+
else:
|
| 285 |
+
z_coherences_d = []
|
| 286 |
+
z_weights_d = []
|
| 287 |
+
if z_st_i is not None:
|
| 288 |
+
z_coherences_d.append(z_st_i)
|
| 289 |
+
z_weights_d.append(w_ti)
|
| 290 |
+
elif st_i is not None:
|
| 291 |
+
z_coherences_d.append(st_i)
|
| 292 |
+
z_weights_d.append(w_ti)
|
| 293 |
+
if z_st_a is not None:
|
| 294 |
+
z_coherences_d.append(z_st_a)
|
| 295 |
+
z_weights_d.append(1.0 - w_ti)
|
| 296 |
+
elif st_a is not None:
|
| 297 |
+
z_coherences_d.append(st_a)
|
| 298 |
+
z_weights_d.append(1.0 - w_ti)
|
| 299 |
+
|
| 300 |
+
if z_coherences_d:
|
| 301 |
+
total_wd = sum(z_weights_d)
|
| 302 |
+
z_mean_d = sum(z * wt for z, wt in zip(z_coherences_d, z_weights_d)) / total_wd
|
| 303 |
+
else:
|
| 304 |
+
z_mean_d = None
|
| 305 |
+
|
| 306 |
+
if z_mean_d is not None:
|
| 307 |
+
variant_d = float(1.0 / (1.0 + np.exp(-(z_mean_d + margin_alpha * margin))))
|
| 308 |
+
else:
|
| 309 |
+
variant_d = variant_c
|
| 310 |
+
|
| 311 |
+
# ── Cross-Space Complementarity — Variant E ──────────
|
| 312 |
+
# COMPLEMENTARITY: E = sigmoid(z_2d + w_3d * z_compl + alpha * margin)
|
| 313 |
+
# ExMCR projects CLAP audio → CLIP space, enabling measurement of
|
| 314 |
+
# image-audio complementarity (Gramian dispersion in unified space).
|
| 315 |
+
# High complementarity = image and audio contribute unique perspectives.
|
| 316 |
+
# Low complementarity = redundant cross-modal information.
|
| 317 |
+
# z_compl = z_normalize(gram_volume_ia) — positive z = more complementary.
|
| 318 |
+
# w_3d=0 recovers D exactly (safety guarantee).
|
| 319 |
+
audio_projected = None
|
| 320 |
+
variant_e = variant_d # default to D if no projector
|
| 321 |
+
z_compl = None # z-normalized complementarity (exported for optimizer)
|
| 322 |
+
gram_ia_volume = None # raw image-audio Gramian volume
|
| 323 |
+
w_3d = CMSCI_W_3D
|
| 324 |
+
|
| 325 |
+
# Reconstruct D's pre-margin z-score (z_2d) for composition
|
| 326 |
+
z_2d = None
|
| 327 |
+
margin = 0.0
|
| 328 |
+
if contrastive_result is not None and contrastive_result["n_negatives"] > 0:
|
| 329 |
+
margin = contrastive_result["margin"]
|
| 330 |
+
if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
|
| 331 |
+
z_2d = w_ti * z_gram_ti + (1.0 - w_ti) * z_gram_ta
|
| 332 |
+
elif z_st_i is not None and z_st_a is not None:
|
| 333 |
+
z_2d = w_ti * z_st_i + (1.0 - w_ti) * z_st_a
|
| 334 |
+
|
| 335 |
+
# Project audio into CLIP space via ExMCR and compute complementarity
|
| 336 |
+
if self._exmcr is not None and not self._exmcr.is_identity:
|
| 337 |
+
if emb_audio is not None:
|
| 338 |
+
audio_projected = self._exmcr.project_audio(emb_audio)
|
| 339 |
+
if emb_image is not None:
|
| 340 |
+
si_a = float(round(cosine_similarity(emb_image, audio_projected), 4))
|
| 341 |
+
# Image-audio Gramian volume = dispersion = complementarity
|
| 342 |
+
gram_ia_volume = gram_volume_2d(emb_image, audio_projected)
|
| 343 |
+
if emb_text_clip is not None and emb_image is not None and audio_projected is not None:
|
| 344 |
+
gram_tia = gram_volume_3d(emb_text_clip, emb_image, audio_projected)
|
| 345 |
+
|
| 346 |
+
# Z-normalize complementarity (volume, NOT coherence)
|
| 347 |
+
# z_compl = -z_gram_ia_coherence (flipped: high volume = high complementarity)
|
| 348 |
+
if gram_ia_volume is not None and self._calibration is not None:
|
| 349 |
+
gram_ia_coherence = normalized_gram_coherence(gram_ia_volume)
|
| 350 |
+
z_gram_ia_coh = self._calibration.normalize("gram_coh_ia_exmcr", gram_ia_coherence)
|
| 351 |
+
z_compl = -z_gram_ia_coh # flip: positive = more complementary
|
| 352 |
+
|
| 353 |
+
# Compose: E = sigmoid(z_2d + w_3d * z_compl + alpha * margin)
|
| 354 |
+
if z_2d is not None:
|
| 355 |
+
logit_e = z_2d + margin_alpha * margin
|
| 356 |
+
if z_compl is not None:
|
| 357 |
+
logit_e += w_3d * z_compl
|
| 358 |
+
variant_e = float(1.0 / (1.0 + np.exp(-logit_e)))
|
| 359 |
+
|
| 360 |
+
# ── Probabilistic Adaptive Weighting (Variant F) ──────
|
| 361 |
+
# ProbVLM drives per-sample channel weights instead of fixed w_ti.
|
| 362 |
+
# adaptive_w = (1/u_ti) / (1/u_ti + 1/u_ta) — trust more confident channel
|
| 363 |
+
# w_ti_final = (1 - gamma) * base_w + gamma * adaptive_w
|
| 364 |
+
# gamma=0 → w_ti_final = base_w → recovers E exactly (safety guarantee)
|
| 365 |
+
# MC sampling remains metadata only (confidence intervals, not scoring).
|
| 366 |
+
uncertainty_result = None
|
| 367 |
+
variant_f = variant_e # default to E
|
| 368 |
+
u_ti = None # per-channel uncertainty (exported for optimizer)
|
| 369 |
+
u_ta = None
|
| 370 |
+
adaptive_w_ti = None
|
| 371 |
+
gamma = CMSCI_GAMMA
|
| 372 |
+
|
| 373 |
+
if self._prob_clip is not None or self._prob_clap is not None:
|
| 374 |
+
mc_volumes = []
|
| 375 |
+
|
| 376 |
+
# Per-channel uncertainty from ProbVLM adapters
|
| 377 |
+
if self._prob_clip is not None and emb_text_clip is not None and emb_image is not None:
|
| 378 |
+
u_text_clip = self._prob_clip.uncertainty(emb_text_clip)
|
| 379 |
+
u_image_clip = self._prob_clip.uncertainty(emb_image)
|
| 380 |
+
u_ti = float(np.mean([u_text_clip, u_image_clip]))
|
| 381 |
+
|
| 382 |
+
# MC samples for confidence interval metadata
|
| 383 |
+
text_samples = self._prob_clip.sample(emb_text_clip, n_mc_samples)
|
| 384 |
+
image_samples = self._prob_clip.sample(emb_image, n_mc_samples)
|
| 385 |
+
for t_s, i_s in zip(text_samples, image_samples):
|
| 386 |
+
mc_volumes.append(gram_volume_2d(t_s, i_s))
|
| 387 |
+
|
| 388 |
+
if self._prob_clap is not None and emb_text_clap is not None and emb_audio is not None:
|
| 389 |
+
u_text_clap = self._prob_clap.uncertainty(emb_text_clap)
|
| 390 |
+
u_audio_clap = self._prob_clap.uncertainty(emb_audio)
|
| 391 |
+
u_ta = float(np.mean([u_text_clap, u_audio_clap]))
|
| 392 |
+
|
| 393 |
+
text_samples = self._prob_clap.sample(emb_text_clap, n_mc_samples)
|
| 394 |
+
audio_samples = self._prob_clap.sample(emb_audio, n_mc_samples)
|
| 395 |
+
for t_s, a_s in zip(text_samples, audio_samples):
|
| 396 |
+
mc_volumes.append(gram_volume_2d(t_s, a_s))
|
| 397 |
+
|
| 398 |
+
# Compute adaptive channel weight from uncertainty
|
| 399 |
+
if u_ti is not None and u_ta is not None and u_ti > 0 and u_ta > 0 and gamma > 0:
|
| 400 |
+
inv_ti = 1.0 / u_ti
|
| 401 |
+
inv_ta = 1.0 / u_ta
|
| 402 |
+
adaptive_w = inv_ti / (inv_ti + inv_ta)
|
| 403 |
+
w_ti_final = (1.0 - gamma) * w_ti + gamma * adaptive_w
|
| 404 |
+
adaptive_w_ti = float(w_ti_final)
|
| 405 |
+
|
| 406 |
+
# Recompute z_2d with adaptive weights
|
| 407 |
+
if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
|
| 408 |
+
z_2d_adaptive = w_ti_final * z_gram_ti + (1.0 - w_ti_final) * z_gram_ta
|
| 409 |
+
elif z_st_i is not None and z_st_a is not None:
|
| 410 |
+
z_2d_adaptive = w_ti_final * z_st_i + (1.0 - w_ti_final) * z_st_a
|
| 411 |
+
else:
|
| 412 |
+
z_2d_adaptive = None
|
| 413 |
+
|
| 414 |
+
if z_2d_adaptive is not None:
|
| 415 |
+
logit_f = z_2d_adaptive + margin_alpha * margin
|
| 416 |
+
if z_compl is not None:
|
| 417 |
+
logit_f += w_3d * z_compl
|
| 418 |
+
variant_f = float(1.0 / (1.0 + np.exp(-logit_f)))
|
| 419 |
+
|
| 420 |
+
# MC sampling for confidence intervals (metadata, NOT scoring)
|
| 421 |
+
if mc_volumes:
|
| 422 |
+
mc_coherences = [normalized_gram_coherence(v) for v in mc_volumes]
|
| 423 |
+
mc_mean = float(np.mean(mc_coherences))
|
| 424 |
+
mc_std = float(np.std(mc_coherences))
|
| 425 |
+
mc_ci_lower = float(np.percentile(mc_coherences, 2.5))
|
| 426 |
+
mc_ci_upper = float(np.percentile(mc_coherences, 97.5))
|
| 427 |
+
else:
|
| 428 |
+
mc_mean = mc_std = mc_ci_lower = mc_ci_upper = None
|
| 429 |
+
|
| 430 |
+
uncertainty_result = {
|
| 431 |
+
"mc_mean": round(mc_mean, 4) if mc_mean is not None else None,
|
| 432 |
+
"mc_std": round(mc_std, 4) if mc_std is not None else None,
|
| 433 |
+
"mc_ci_lower": round(mc_ci_lower, 4) if mc_ci_lower is not None else None,
|
| 434 |
+
"mc_ci_upper": round(mc_ci_upper, 4) if mc_ci_upper is not None else None,
|
| 435 |
+
"u_ti": round(u_ti, 6) if u_ti is not None else None,
|
| 436 |
+
"u_ta": round(u_ta, 6) if u_ta is not None else None,
|
| 437 |
+
"adaptive_w_ti": round(adaptive_w_ti, 4) if adaptive_w_ti is not None else None,
|
| 438 |
+
"gamma": gamma,
|
| 439 |
+
"n_samples": n_mc_samples,
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
# ── Assemble cMSCI ─────────────────────────────────────
|
| 443 |
+
# cMSCI is the highest available variant
|
| 444 |
+
cmsci = variant_f
|
| 445 |
+
active_variant = "F"
|
| 446 |
+
|
| 447 |
+
if variant_f == variant_e:
|
| 448 |
+
active_variant = "E" if variant_e != variant_d else "D"
|
| 449 |
+
if variant_e == variant_d:
|
| 450 |
+
active_variant = "D" if variant_d != variant_c else "C"
|
| 451 |
+
if variant_d == variant_c:
|
| 452 |
+
active_variant = "C" if variant_c != variant_b else "B"
|
| 453 |
+
if variant_c == variant_b:
|
| 454 |
+
active_variant = "B" if variant_b is not None else "A"
|
| 455 |
+
|
| 456 |
+
# Final cMSCI: use the most sophisticated available variant
|
| 457 |
+
if cmsci is None:
|
| 458 |
+
cmsci = msci # fallback to legacy
|
| 459 |
+
active_variant = "A"
|
| 460 |
+
|
| 461 |
+
logger.info(
|
| 462 |
+
"cMSCI = %.4f (variant %s) | MSCI = %s",
|
| 463 |
+
cmsci if cmsci is not None else 0.0,
|
| 464 |
+
active_variant,
|
| 465 |
+
msci,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
return {
|
| 469 |
+
"cmsci": round(cmsci, 4) if cmsci is not None else None,
|
| 470 |
+
"msci": round(msci, 4) if msci is not None else None,
|
| 471 |
+
"active_variant": active_variant,
|
| 472 |
+
"scores": {
|
| 473 |
+
"st_i": st_i,
|
| 474 |
+
"st_a": st_a,
|
| 475 |
+
"si_a": si_a,
|
| 476 |
+
},
|
| 477 |
+
"gram": {
|
| 478 |
+
"text_image": round(gram_ti, 4) if gram_ti is not None else None,
|
| 479 |
+
"text_audio": round(gram_ta, 4) if gram_ta is not None else None,
|
| 480 |
+
"text_image_audio": round(gram_tia, 4) if gram_tia is not None else None,
|
| 481 |
+
"coherence_2way": round(gram_coherence_2way, 4) if gram_coherence_2way is not None else None,
|
| 482 |
+
},
|
| 483 |
+
"calibration": {
|
| 484 |
+
"z_st_i": round(z_st_i, 4) if z_st_i is not None else None,
|
| 485 |
+
"z_st_a": round(z_st_a, 4) if z_st_a is not None else None,
|
| 486 |
+
"z_gram_ti": round(z_gram_ti, 4) if z_gram_ti is not None else None,
|
| 487 |
+
"z_gram_ta": round(z_gram_ta, 4) if z_gram_ta is not None else None,
|
| 488 |
+
"z_compl": round(z_compl, 4) if z_compl is not None else None,
|
| 489 |
+
"gram_ia_volume": round(gram_ia_volume, 4) if gram_ia_volume is not None else None,
|
| 490 |
+
"u_ti": round(u_ti, 6) if u_ti is not None else None,
|
| 491 |
+
"u_ta": round(u_ta, 6) if u_ta is not None else None,
|
| 492 |
+
"adaptive_w_ti": round(adaptive_w_ti, 4) if adaptive_w_ti is not None else None,
|
| 493 |
+
"cal_mode": cal_mode if self._calibration is not None else None,
|
| 494 |
+
"w_ti": w_ti,
|
| 495 |
+
"w_3d": w_3d,
|
| 496 |
+
"gamma": gamma,
|
| 497 |
+
"margin_alpha": CMSCI_MARGIN_ALPHA if contrastive_result else None,
|
| 498 |
+
},
|
| 499 |
+
"contrastive": contrastive_result,
|
| 500 |
+
"uncertainty": uncertainty_result,
|
| 501 |
+
"variant_scores": {
|
| 502 |
+
"A_msci": round(variant_a, 4) if variant_a is not None else None,
|
| 503 |
+
"B_gram": round(variant_b, 4) if variant_b is not None else None,
|
| 504 |
+
"C_gram_znorm": round(variant_c, 4) if variant_c is not None else None,
|
| 505 |
+
"D_gram_znorm_contrastive": round(variant_d, 4) if variant_d is not None else None,
|
| 506 |
+
"E_gram_znorm_contrastive_exmcr": round(variant_e, 4) if variant_e is not None else None,
|
| 507 |
+
"F_full_cmsci": round(variant_f, 4) if variant_f is not None else None,
|
| 508 |
+
},
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
def evaluate_batch(
    self,
    items: List[Dict[str, str]],
    n_mc_samples: int = 100,
) -> List[Dict[str, Any]]:
    """
    Run evaluate() over a sequence of (text, image, audio) triples.

    Args:
        items: Dicts carrying "text", "image_path", "audio_path", "domain".
        n_mc_samples: Monte Carlo sample count forwarded to each evaluate() call.

    Returns:
        One result dict (as produced by evaluate()) per input item, in order.
    """
    return [
        self.evaluate(
            text=entry.get("text", ""),
            image_path=entry.get("image_path"),
            audio_path=entry.get("audio_path"),
            domain=entry.get("domain", ""),
            n_mc_samples=n_mc_samples,
        )
        for entry in items
    ]
|
src/coherence/gram_volume.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gramian Volume Scoring for Multimodal Coherence.
|
| 3 |
+
|
| 4 |
+
The Gramian volume measures the geometric dispersion of embedding vectors.
|
| 5 |
+
For n L2-normalized vectors, the Gramian matrix G has G_ij = <vi, vj>.
|
| 6 |
+
|
| 7 |
+
volume = sqrt(det(G))
|
| 8 |
+
|
| 9 |
+
Properties:
|
| 10 |
+
- Identical vectors → det(G) = 0 → volume = 0 (perfect alignment)
|
| 11 |
+
- Mutually orthogonal unit vectors → det(G) = 1 → volume = 1 (max dispersion)
|
| 12 |
+
- Coherence = 1 - volume → [0, 1] where 1 = perfect alignment
|
| 13 |
+
|
| 14 |
+
For 2 unit vectors:
|
| 15 |
+
det(G) = 1 - cos²(θ) = sin²(θ)
|
| 16 |
+
volume = |sin(θ)|
|
| 17 |
+
coherence = 1 - |sin(θ)| ≈ cos(θ) for small angles
|
| 18 |
+
|
| 19 |
+
For 3 unit vectors:
|
| 20 |
+
det(G) = 1 - cos²(a) - cos²(b) - cos²(c) + 2·cos(a)·cos(b)·cos(c)
|
| 21 |
+
where a, b, c are pairwise angles
|
| 22 |
+
This captures the full tri-modal geometric relationship in one number.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _normalize(v: np.ndarray, eps: float = 1e-12) -> np.ndarray:
|
| 31 |
+
"""L2-normalize a vector."""
|
| 32 |
+
v = v.astype(np.float64).squeeze()
|
| 33 |
+
norm = np.linalg.norm(v) + eps
|
| 34 |
+
return v / norm
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def gram_volume_2d(v1: np.ndarray, v2: np.ndarray) -> float:
|
| 38 |
+
"""
|
| 39 |
+
Gramian volume for 2 vectors (area of parallelogram).
|
| 40 |
+
|
| 41 |
+
For unit vectors: volume = |sin(θ)| where θ is the angle between them.
|
| 42 |
+
Range: [0, 1] — 0 when identical, 1 when orthogonal.
|
| 43 |
+
"""
|
| 44 |
+
v1_n = _normalize(v1)
|
| 45 |
+
v2_n = _normalize(v2)
|
| 46 |
+
cos_sim = np.clip(np.dot(v1_n, v2_n), -1.0, 1.0)
|
| 47 |
+
# det(G) = 1 - cos²(θ)
|
| 48 |
+
det_g = 1.0 - cos_sim ** 2
|
| 49 |
+
return float(np.sqrt(max(det_g, 0.0)))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def gram_volume_3d(
|
| 53 |
+
v1: np.ndarray, v2: np.ndarray, v3: np.ndarray,
|
| 54 |
+
) -> float:
|
| 55 |
+
"""
|
| 56 |
+
Gramian volume for 3 vectors (volume of parallelepiped).
|
| 57 |
+
|
| 58 |
+
For unit vectors with pairwise cosines a, b, c:
|
| 59 |
+
det(G) = 1 - a² - b² - c² + 2abc
|
| 60 |
+
|
| 61 |
+
Range: [0, 1] — 0 when all collinear, 1 when mutually orthogonal.
|
| 62 |
+
"""
|
| 63 |
+
v1_n = _normalize(v1)
|
| 64 |
+
v2_n = _normalize(v2)
|
| 65 |
+
v3_n = _normalize(v3)
|
| 66 |
+
|
| 67 |
+
a = np.dot(v1_n, v2_n)
|
| 68 |
+
b = np.dot(v1_n, v3_n)
|
| 69 |
+
c = np.dot(v2_n, v3_n)
|
| 70 |
+
|
| 71 |
+
det_g = 1.0 - a**2 - b**2 - c**2 + 2.0 * a * b * c
|
| 72 |
+
return float(np.sqrt(max(det_g, 0.0)))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def gram_volume_nd(*vectors: np.ndarray) -> float:
|
| 76 |
+
"""
|
| 77 |
+
Gramian volume for n vectors (general case).
|
| 78 |
+
|
| 79 |
+
Builds the Gram matrix G_ij = <vi, vj> from L2-normalized vectors
|
| 80 |
+
and returns sqrt(det(G)).
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
*vectors: Variable number of numpy arrays (embeddings).
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
Gramian volume in [0, 1] for unit vectors.
|
| 87 |
+
"""
|
| 88 |
+
n = len(vectors)
|
| 89 |
+
if n == 0:
|
| 90 |
+
return 0.0
|
| 91 |
+
if n == 1:
|
| 92 |
+
return 0.0
|
| 93 |
+
if n == 2:
|
| 94 |
+
return gram_volume_2d(vectors[0], vectors[1])
|
| 95 |
+
if n == 3:
|
| 96 |
+
return gram_volume_3d(vectors[0], vectors[1], vectors[2])
|
| 97 |
+
|
| 98 |
+
normed = [_normalize(v) for v in vectors]
|
| 99 |
+
G = np.zeros((n, n), dtype=np.float64)
|
| 100 |
+
for i in range(n):
|
| 101 |
+
for j in range(i, n):
|
| 102 |
+
dot = np.dot(normed[i], normed[j])
|
| 103 |
+
G[i, j] = dot
|
| 104 |
+
G[j, i] = dot
|
| 105 |
+
|
| 106 |
+
det_g = np.linalg.det(G)
|
| 107 |
+
return float(np.sqrt(max(det_g, 0.0)))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def normalized_gram_coherence(volume: float, n_vectors: int = 2) -> float:
|
| 111 |
+
"""
|
| 112 |
+
Map Gramian volume to coherence score in [0, 1].
|
| 113 |
+
|
| 114 |
+
1 = perfect alignment (volume = 0, all vectors identical)
|
| 115 |
+
0 = maximum dispersion (volume = 1, mutually orthogonal)
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
volume: Gramian volume (output of gram_volume_* functions).
|
| 119 |
+
n_vectors: Number of vectors used (for documentation; mapping is the same).
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Coherence score in [0, 1].
|
| 123 |
+
"""
|
| 124 |
+
return float(max(0.0, min(1.0, 1.0 - volume)))
|
src/coherence/negative_bank.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Contrastive Negative Bank for cMSCI Calibration.
|
| 3 |
+
|
| 4 |
+
Computes contrastive margins by comparing a matched (text, image, audio)
|
| 5 |
+
triple against hard-negative alternatives from the embedding indexes.
|
| 6 |
+
|
| 7 |
+
A positive contrastive margin means the matched triple has tighter
|
| 8 |
+
geometric coherence than mismatched alternatives — the defining
|
| 9 |
+
property of a well-calibrated metric.
|
| 10 |
+
|
| 11 |
+
Contrastive margin:
|
| 12 |
+
margin = mean(neg_volumes) - matched_volume
|
| 13 |
+
> 0 → matched triple is more coherent than negatives (good)
|
| 14 |
+
≤ 0 → metric cannot distinguish matched from mismatched (bad)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Tuple
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
from src.coherence.gram_volume import gram_volume_2d, gram_volume_3d, normalized_gram_coherence
|
| 26 |
+
from src.embeddings.similarity import l2_normalize
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class NegativeBank:
    """
    Loads pre-computed embedding indexes and provides hard negatives.

    Hard negatives are embeddings with high individual similarity to the
    query but from a different domain — the most challenging cases for
    the coherence metric.
    """

    def __init__(
        self,
        image_index_path: str = "data/embeddings/image_index.npz",
        audio_index_path: str = "data/embeddings/audio_index.npz",
    ):
        # Per-modality index state; left as None when the index file is missing.
        self._image_ids: Optional[np.ndarray] = None
        self._image_embs: Optional[np.ndarray] = None
        self._image_domains: Optional[np.ndarray] = None
        self._audio_ids: Optional[np.ndarray] = None
        self._audio_embs: Optional[np.ndarray] = None
        self._audio_domains: Optional[np.ndarray] = None

        self._load_index(image_index_path, "image")
        self._load_index(audio_index_path, "audio")

    def _load_index(self, path: str, modality: str) -> None:
        """Load one .npz embedding index; a missing file disables that modality."""
        p = Path(path)
        if not p.exists():
            logger.warning("Index not found: %s — %s negatives disabled", path, modality)
            return

        data = np.load(path, allow_pickle=True)
        # Tolerate both naming schemes: ids/paths and embs/embeddings.
        ids = data["ids"] if "ids" in data else data.get("paths", np.array([]))
        embs = data["embs"] if "embs" in data else data.get("embeddings", np.array([]))
        domains = data["domains"] if "domains" in data else np.array(["other"] * len(ids))

        if modality == "image":
            self._image_ids = ids
            self._image_embs = embs.astype(np.float32)
            self._image_domains = domains
            logger.info("Loaded image index: %d entries", len(ids))
        else:
            self._audio_ids = ids
            self._audio_embs = embs.astype(np.float32)
            self._audio_domains = domains
            logger.info("Loaded audio index: %d entries", len(ids))

    @property
    def has_images(self) -> bool:
        # True when the image index loaded and is non-empty.
        return self._image_embs is not None and len(self._image_embs) > 0

    @property
    def has_audio(self) -> bool:
        # True when the audio index loaded and is non-empty.
        return self._audio_embs is not None and len(self._audio_embs) > 0

    @staticmethod
    def _top_k_negatives(
        embs: np.ndarray,
        domains: np.ndarray,
        text_emb: np.ndarray,
        exclude_domain: str,
        k: int,
    ) -> List[np.ndarray]:
        """
        Shared top-k search used by both modalities: return the k embeddings
        with highest cosine similarity to the query that are NOT in the
        excluded (matched) domain.
        """
        text_n = l2_normalize(text_emb.squeeze())
        sims = embs @ text_n

        # Mask out the matched domain so only cross-domain entries qualify.
        if exclude_domain:
            mask = np.array([d != exclude_domain for d in domains])
        else:
            mask = np.ones(len(sims), dtype=bool)

        sims_masked = np.where(mask, sims, -np.inf)
        top_k_idx = np.argsort(sims_masked)[-k:][::-1]

        # Drop -inf entries (fewer than k eligible candidates).
        return [embs[i] for i in top_k_idx if sims_masked[i] > -np.inf]

    def get_hard_negative_images(
        self,
        text_emb: np.ndarray,
        exclude_domain: str = "",
        k: int = 5,
    ) -> List[np.ndarray]:
        """
        Get top-k hardest negative images (high text similarity but wrong domain).

        Args:
            text_emb: CLIP text embedding for the query.
            exclude_domain: Domain to exclude (the correct domain).
            k: Number of negatives to return.

        Returns:
            List of image embeddings (hard negatives).
        """
        if not self.has_images:
            return []
        return self._top_k_negatives(
            self._image_embs, self._image_domains, text_emb, exclude_domain, k
        )

    def get_hard_negative_audio(
        self,
        text_emb: np.ndarray,
        exclude_domain: str = "",
        k: int = 5,
    ) -> List[np.ndarray]:
        """
        Get top-k hardest negative audio (high text similarity but wrong domain).

        Args:
            text_emb: CLAP text embedding for the query.
            exclude_domain: Domain to exclude.
            k: Number of negatives to return.

        Returns:
            List of audio embeddings (hard negatives).
        """
        if not self.has_audio:
            return []
        return self._top_k_negatives(
            self._audio_embs, self._audio_domains, text_emb, exclude_domain, k
        )

    def compute_contrastive_margin(
        self,
        matched_volume: float,
        text_clip_emb: np.ndarray,
        image_emb: np.ndarray,
        text_clap_emb: Optional[np.ndarray] = None,
        audio_emb: Optional[np.ndarray] = None,
        domain: str = "",
        k: int = 5,
    ) -> Dict[str, float]:
        """
        Compute contrastive margin against hard negatives.

        For each hard negative, computes the 2-vector gram volume of the
        (text, negative) pair and averages.
        Margin = mean(neg_volumes) - matched_volume.

        A positive margin means the matched triple is geometrically tighter
        than hard-negative alternatives.

        Args:
            matched_volume: Gram volume of the matched (text, image, audio) triple.
            text_clip_emb: CLIP text embedding (for finding negative images).
            image_emb: CLIP image embedding of the matched image.
            text_clap_emb: CLAP text embedding (for finding negative audio).
            audio_emb: CLAP audio embedding of the matched audio.
            domain: Domain of the matched prompt (excluded from negatives).
            k: Number of hard negatives per modality.

        Returns:
            Dict with margin, mean_neg_volume, n_negatives.
        """
        neg_volumes = []

        # Image negatives: replace matched image with hard negative
        neg_images = self.get_hard_negative_images(text_clip_emb, domain, k)
        for neg_img in neg_images:
            vol = gram_volume_2d(text_clip_emb, neg_img)
            neg_volumes.append(vol)

        # Audio negatives: replace matched audio with hard negative
        if text_clap_emb is not None:
            neg_audios = self.get_hard_negative_audio(text_clap_emb, domain, k)
            for neg_aud in neg_audios:
                vol = gram_volume_2d(text_clap_emb, neg_aud)
                neg_volumes.append(vol)

        if not neg_volumes:
            # No eligible negatives: report a neutral (zero) margin.
            return {
                "margin": 0.0,
                "mean_neg_volume": matched_volume,
                "n_negatives": 0,
            }

        mean_neg = float(np.mean(neg_volumes))
        margin = mean_neg - matched_volume

        return {
            "margin": float(margin),
            "mean_neg_volume": mean_neg,
            "n_negatives": len(neg_volumes),
        }
|
src/config/settings.py
CHANGED
|
@@ -106,3 +106,47 @@ DRIFT_ASYMMETRY_THRESHOLD = 0.15 # |st_i - st_a| gap to flag drift
|
|
| 106 |
RERATING_FRACTION = 0.20
|
| 107 |
KAPPA_ACCEPTABLE_THRESHOLD = 0.70
|
| 108 |
ALPHA_ACCEPTABLE_THRESHOLD = 0.667
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
RERATING_FRACTION = 0.20
|
| 107 |
KAPPA_ACCEPTABLE_THRESHOLD = 0.70
|
| 108 |
ALPHA_ACCEPTABLE_THRESHOLD = 0.667
|
| 109 |
+
|
| 110 |
+
# ---------------------------------------------------------------------------
# cMSCI (Calibrated Multimodal Semantic Coherence Index)
# ---------------------------------------------------------------------------

# Calibration store (fitted from RQ1 baseline data)
CMSCI_CALIBRATION_PATH = PROJECT_ROOT / "artifacts" / "cmsci_calibration.json"

# Ex-MCR cross-space alignment (CLAP → CLIP projection)
EXMCR_WEIGHTS_PATH = PROJECT_ROOT / "models" / "exmcr" / "ex_clap.pt"

# Cross-Space Bridge (CLIP image + CLAP audio → shared 256-d bridge space)
BRIDGE_WEIGHTS_PATH = PROJECT_ROOT / "models" / "bridge" / "bridge_best.pt"

# Probabilistic adapters (ProbVLM-style uncertainty)
PROB_CLIP_ADAPTER_PATH = PROJECT_ROOT / "models" / "prob_adapters" / "clip_adapter.pt"
PROB_CLAP_ADAPTER_PATH = PROJECT_ROOT / "models" / "prob_adapters" / "clap_adapter.pt"

# Full pipeline optimized parameters (via LOO-CV on RQ3 human ratings)
# Full-sample rho=0.608 (p=0.0004), LOO-CV rho=0.546 (p=0.0018), overfit gap=0.001
# Selected in 87% of LOO folds (26/30) — highly stable
CMSCI_MARGIN_ALPHA = 16  # Margin scaling factor (amplifies contrastive signal)
CMSCI_CHANNEL_WEIGHT_TI = 0.90  # Text-image channel weight (1 - w for text-audio)
CMSCI_CALIBRATION_MODE = "gram"  # "cosine" (z-norm cosine sims) or "gram" (z-norm gram coherences)

# Variant E: ExMCR cross-modal complementarity (w_3d=0 recovers D exactly)
# ExMCR projects CLAP audio → CLIP space; complementarity = Gramian dispersion
# High complementarity = image and audio contribute unique perspectives (rewarded)
CMSCI_W_3D = 0.45  # Weight for z-normalized IA complementarity
# Variant F: ProbVLM adaptive channel weighting (gamma=0 recovers E exactly)
CMSCI_GAMMA = 0.10  # Mixing ratio: w_final = (1-gamma)*base_w + gamma*adaptive_w

# Contrastive negative bank
CMSCI_NEGATIVE_K = 5  # Number of hard negatives per modality
CMSCI_NEGATIVE_BANK_ENABLED = True  # Enable/disable contrastive calibration

# MC sampling for uncertainty estimation
CMSCI_MC_SAMPLES = 100  # Number of Monte Carlo samples for Variant F

# Probabilistic adapter training (early-stopped AdamW-style loop; see prob_adapter_trainer)
PROB_ADAPTER_EPOCHS = 100
PROB_ADAPTER_LR = 1e-4
PROB_ADAPTER_BATCH_SIZE = 32
PROB_ADAPTER_PATIENCE = 15
|
src/embeddings/prob_adapter_trainer.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training Loop for ProbVLM-Style Probabilistic Adapters.
|
| 3 |
+
|
| 4 |
+
Trains lightweight post-hoc adapters on top of frozen CLIP/CLAP encoders.
|
| 5 |
+
Each adapter learns to predict uncertainty (Generalized Gaussian parameters)
|
| 6 |
+
for a single embedding space.
|
| 7 |
+
|
| 8 |
+
Two adapters to train:
|
| 9 |
+
1. CLIP adapter: trained on (image_embedding, text_embedding) pairs
|
| 10 |
+
2. CLAP adapter: trained on (audio_embedding, text_embedding) pairs
|
| 11 |
+
|
| 12 |
+
Training data:
|
| 13 |
+
- Our 57 images paired with text descriptions (CLIP pairs)
|
| 14 |
+
- Our 104 audio files paired with text descriptions (CLAP pairs)
|
| 15 |
+
- All 30 RQ1 prompts × matched media as additional pairs
|
| 16 |
+
|
| 17 |
+
Loss:
|
| 18 |
+
L = L1(mu, target) + GenGaussLoss(mu, alpha, beta, target)
|
| 19 |
+
|
| 20 |
+
GenGaussLoss:
|
| 21 |
+
-log p(target | mu, alpha, beta) ∝ log(alpha) - log(beta) + (|target - mu| / alpha)^beta
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import logging
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Dict, List, Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
| 39 |
+
TORCH_AVAILABLE = True
|
| 40 |
+
except ImportError:
|
| 41 |
+
TORCH_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class EmbeddingPairDataset(Dataset):
    """
    Dataset of (input_embedding, target_embedding) pairs.

    Wraps two aligned numpy arrays as float32 tensors for DataLoader use.
    """

    def __init__(self, inputs: np.ndarray, targets: np.ndarray):
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch required")
        # Explicit check (not `assert`) so validation survives `python -O`.
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets must have the same length: "
                f"{len(inputs)} != {len(targets)}"
            )
        self.inputs = torch.tensor(inputs, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.inputs[idx], self.targets[idx]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GenGaussNLL(nn.Module):
    """
    Negative log-likelihood loss for a Generalized Gaussian distribution.

    Full NLL:
        -log p(x | mu, alpha, beta) = log(2*alpha) + log(Gamma(1/beta)/beta) + (|x - mu| / alpha)^beta

    With the constant terms dropped, the training objective reduces to:
        L = log(alpha) + (|target - mu| / alpha)^beta
    """

    def forward(
        self,
        mu: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        # Guard against a collapsed scale parameter before dividing.
        scale = torch.clamp(alpha, min=1e-6)
        abs_err = torch.abs(target - mu)
        per_element = (abs_err / scale).pow(beta) + torch.log(scale)
        return per_element.mean()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def train_prob_adapter(
    input_embeddings: np.ndarray,
    target_embeddings: np.ndarray,
    epochs: int = 100,
    lr: float = 1e-4,
    batch_size: int = 32,
    val_split: float = 0.15,
    patience: int = 15,
    output_path: Optional[str] = None,
    adapter_name: str = "adapter",
) -> ProbabilisticAdapter:
    """
    Train a ProbabilisticAdapter on paired embeddings.

    Loss is L1 on the predicted location plus the Generalized Gaussian NLL
    (GenGaussNLL), so the adapter learns both a point estimate and a
    per-dimension uncertainty.

    Args:
        input_embeddings: Source embeddings [N, 512] (e.g. image CLIP or audio CLAP).
        target_embeddings: Target embeddings [N, 512] (e.g. text CLIP or text CLAP).
        epochs: Maximum training epochs.
        lr: Learning rate.
        batch_size: Batch size.
        val_split: Fraction for validation.
        patience: Early stopping patience (epochs without val improvement).
        output_path: If set, save best model here; the best checkpoint is
            reloaded before returning. If None, the weights returned are
            those from the LAST epoch run, not necessarily the best.
        adapter_name: Name for logging.

    Returns:
        Trained ProbabilisticAdapter, in eval mode.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for training")

    # Prefer CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    # Build dataset; validation split is at least 1 sample, seeded for reproducibility.
    dataset = EmbeddingPairDataset(input_embeddings, target_embeddings)
    n_val = max(1, int(len(dataset) * val_split))
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )

    # drop_last only when there is more than one batch, so tiny datasets
    # still yield at least one training batch per epoch.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=len(train_ds) > batch_size)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # Build model; input_dim inferred from the data so 512 is not hard-coded here.
    input_dim = input_embeddings.shape[1]
    adapter = ProbabilisticAdapter(input_dim=input_dim).to(device)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    l1_loss = nn.L1Loss()
    gg_loss = GenGaussNLL()

    best_val_loss = float("inf")
    patience_counter = 0

    logger.info(
        "Training %s adapter: %d train, %d val, %d epochs, device=%s",
        adapter_name, n_train, n_val, epochs, device,
    )

    for epoch in range(epochs):
        # --- Train phase ---
        adapter.train()
        train_losses = []
        for inp, tgt in train_loader:
            inp, tgt = inp.to(device), tgt.to(device)
            optimizer.zero_grad()

            # Combined objective: L1 anchors mu to the target; the NLL
            # shapes alpha/beta around the residuals.
            mu, alpha, beta = adapter(inp)
            loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
            loss.backward()
            # Clip gradients — the pow(beta) term in the NLL can explode.
            torch.nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
            optimizer.step()
            train_losses.append(loss.item())

        # Cosine LR schedule stepped once per epoch.
        scheduler.step()

        # --- Validation phase ---
        adapter.eval()
        val_losses = []
        with torch.no_grad():
            for inp, tgt in val_loader:
                inp, tgt = inp.to(device), tgt.to(device)
                mu, alpha, beta = adapter(inp)
                loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
                val_losses.append(loss.item())

        avg_train = np.mean(train_losses)
        # Empty val loader degrades to +inf so early stopping never fires spuriously.
        avg_val = np.mean(val_losses) if val_losses else float("inf")

        # Log the first epoch and every 10th thereafter.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                " [%s] Epoch %d/%d: train=%.4f, val=%.4f",
                adapter_name, epoch + 1, epochs, avg_train, avg_val,
            )

        # Early stopping on validation loss; best weights are only persisted
        # when output_path is provided.
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            if output_path:
                adapter.save(output_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(" [%s] Early stopping at epoch %d", adapter_name, epoch + 1)
                break

    # Reload the best checkpoint if one was saved; otherwise return the
    # last-epoch weights on CPU.
    if output_path and Path(output_path).exists():
        adapter = ProbabilisticAdapter.load(output_path)
        adapter = adapter.to(device)
    else:
        adapter = adapter.cpu()

    adapter.eval()
    logger.info(" [%s] Training complete. Best val_loss=%.4f", adapter_name, best_val_loss)
    return adapter
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def build_training_pairs_from_index(
    embedding_index_path: str,
    text_embedder_fn,
    modality: str = "image",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build (media_embedding, text_embedding) pairs from an embedding index.

    For each media file in the index, a caption is derived from the
    filename/domain metadata and embedded via ``text_embedder_fn``.

    Args:
        embedding_index_path: Path to image_index.npz or audio_index.npz.
        text_embedder_fn: Function that takes text -> np.ndarray embedding.
        modality: "image" for CLIP text, "audio" for CLAP text.

    Returns:
        (media_embeddings, text_embeddings) both shape [N, 512].
    """
    archive = np.load(embedding_index_path, allow_pickle=True)
    # Support both naming conventions for the stored arrays.
    ids = archive["ids"] if "ids" in archive else archive.get("paths", np.array([]))
    embs = archive["embs"] if "embs" in archive else archive.get("embeddings", np.array([]))
    domains = archive["domains"] if "domains" in archive else np.array(["other"] * len(ids))

    def _caption_for(file_id, domain) -> str:
        # Filename stem -> rough human-readable caption.
        text = Path(str(file_id)).stem.replace("_", " ").replace("-", " ")
        # Strip known pipeline prefixes (checked in order against the
        # progressively-stripped caption).
        for junk in ["fs ", "wm ", "proc "]:
            if text.lower().startswith(junk):
                text = text[len(junk):]
        # Prepend domain context when it is informative.
        return text if domain == "other" else f"{domain}: {text}"

    media_vecs = []
    text_vecs = []

    for idx, (file_id, domain) in enumerate(zip(ids, domains)):
        caption = _caption_for(file_id, domain)
        try:
            encoded = text_embedder_fn(caption)
        except Exception as exc:
            logger.warning("Skipping %s: %s", file_id, exc)
            continue
        media_vecs.append(embs[idx])
        text_vecs.append(encoded)

    return np.array(media_vecs), np.array(text_vecs)
|
src/embeddings/probabilistic_adapter.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ProbVLM-Style Probabilistic Adapter for Uncertainty Estimation.
|
| 3 |
+
|
| 4 |
+
Converts point embeddings into distributions (Generalized Gaussian)
|
| 5 |
+
following the BayesCap approach from ProbVLM.
|
| 6 |
+
|
| 7 |
+
Each adapter takes a frozen embedding and predicts:
|
| 8 |
+
mu: Shift from the input embedding (residual)
|
| 9 |
+
alpha: Scale parameter (controls spread)
|
| 10 |
+
beta: Shape parameter (controls tail behavior)
|
| 11 |
+
|
| 12 |
+
These define a Generalized Gaussian distribution:
|
| 13 |
+
p(x) ∝ exp(-(|x - mu| / alpha)^beta)
|
| 14 |
+
|
| 15 |
+
MC sampling from this distribution produces N embedding samples,
|
| 16 |
+
which propagate uncertainty through the Gramian volume computation.
|
| 17 |
+
|
| 18 |
+
Architecture: BayesCap_MLP
|
| 19 |
+
input → Linear(d, hidden) → ReLU → Dropout
|
| 20 |
+
→ Linear(hidden, hidden) → ReLU → Dropout
|
| 21 |
+
→ Three heads: mu_head, alpha_head, beta_head
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import logging
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Dict, Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
TORCH_AVAILABLE = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
TORCH_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _check_torch():
    """Fail fast with a clear error when torch could not be imported."""
    if TORCH_AVAILABLE:
        return
    raise ImportError("PyTorch required for ProbabilisticAdapter")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ProbabilisticAdapter(nn.Module):
    """
    BayesCap-style adapter that maps point embeddings to distributions.

    Takes a frozen embedding (from CLIP or CLAP) and predicts
    Generalized Gaussian parameters: (mu, alpha, beta).

    The adapter is lightweight (~0.5M params) and trains in minutes
    on small datasets.
    """

    def __init__(
        self,
        input_dim: int = 512,
        hidden_dim: int = 256,
        num_layers: int = 3,
        dropout: float = 0.1,
    ):
        # NOTE(review): num_layers must be >= 2 — the heads below take
        # hidden_dim features, but with num_layers == 1 the backbone is an
        # empty Sequential that passes input_dim features straight through.
        _check_torch()
        super().__init__()

        self.input_dim = input_dim

        # Shared backbone: (num_layers - 1) blocks of Linear -> ReLU -> Dropout.
        layers = []
        in_d = input_dim
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(in_d, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_d = hidden_dim
        self.backbone = nn.Sequential(*layers)

        # Three output heads, one per distribution parameter.
        self.mu_head = nn.Linear(hidden_dim, input_dim)
        self.alpha_head = nn.Linear(hidden_dim, input_dim)
        self.beta_head = nn.Linear(hidden_dim, input_dim)

        # Kept so save()/load() can reconstruct an identical architecture.
        self.config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "dropout": dropout,
        }

    def forward(
        self, embedding: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict distribution parameters from a point embedding.

        Args:
            embedding: Input embedding [batch, input_dim].

        Returns:
            mu: Location parameter [batch, input_dim] (embedding + residual)
            alpha: Scale parameter [batch, input_dim] (> 0, via softplus)
            beta: Shape parameter [batch, input_dim] (> 0, via softplus)
        """
        h = self.backbone(embedding)

        # mu: residual added to the input, so mu stays anchored to the
        # original embedding even with an untrained head.
        mu = embedding + self.mu_head(h)

        # alpha, beta: strictly positive via softplus + epsilon floor.
        alpha = F.softplus(self.alpha_head(h)) + 1e-6
        beta = F.softplus(self.beta_head(h)) + 1e-6

        return mu, alpha, beta

    def sample(
        self,
        embedding: np.ndarray,
        n_samples: int = 100,
    ) -> np.ndarray:
        """
        Draw Monte Carlo samples from the predicted distribution.

        Uses the reparameterization trick for Generalized Gaussian:
            x = mu + alpha * sign(u) * |u|^(1/beta)
        where u ~ Uniform(-1, 1)

        NOTE(review): this uniform-based transform is an approximation —
        exact GGD sampling draws magnitudes from a Gamma distribution.
        Confirm the approximation is intentional.

        Args:
            embedding: Input embedding, shape (dim,) or (1, dim).
            n_samples: Number of MC samples.

        Returns:
            Samples array, shape (n_samples, dim), L2-normalized per row.
        """
        _check_torch()
        # Side effect: switches this module (and submodules) to eval mode,
        # disabling dropout for deterministic parameter prediction.
        self.eval()

        emb = embedding.squeeze()
        if emb.ndim == 1:
            emb = emb[np.newaxis, :]

        with torch.no_grad():
            x = torch.tensor(emb, dtype=torch.float32)
            mu, alpha, beta = self.forward(x)

            # Expand for sampling: [1, dim] -> [n_samples, dim].
            # NOTE(review): expand assumes a batch of exactly one embedding;
            # a larger batch would fail here — confirm callers pass one.
            mu = mu.expand(n_samples, -1)
            alpha = alpha.expand(n_samples, -1)
            beta = beta.expand(n_samples, -1)

            # Reparameterized sampling from Generalized Gaussian.
            u = torch.rand_like(mu) * 2 - 1  # Uniform(-1, 1)
            sign = torch.sign(u)
            # +1e-8 avoids 0^(1/beta) gradients/NaNs at |u| == 0.
            samples = mu + alpha * sign * (torch.abs(u) + 1e-8).pow(1.0 / beta)

            # L2 normalize samples (stay on unit sphere with CLIP/CLAP).
            samples = F.normalize(samples, p=2, dim=-1)

        return samples.cpu().numpy()

    def uncertainty(self, embedding: np.ndarray) -> float:
        """
        Compute scalar aleatoric uncertainty for an embedding.

        Returns the mean predicted alpha (scale parameter) across dimensions.
        High alpha → high uncertainty → wide distribution.

        Args:
            embedding: Input embedding, shape (dim,) or (1, dim).

        Returns:
            Scalar uncertainty value (mean alpha).
        """
        _check_torch()
        # Side effect: switches the module to eval mode (disables dropout).
        self.eval()

        emb = embedding.squeeze()
        if emb.ndim == 1:
            emb = emb[np.newaxis, :]

        with torch.no_grad():
            x = torch.tensor(emb, dtype=torch.float32)
            _, alpha, _ = self.forward(x)
            return float(alpha.mean().item())

    def save(self, path: str) -> None:
        """Save adapter weights to ``path`` and config to ``path`` with .json suffix."""
        _check_torch()
        import json
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), p)
        # Config sidecar lets load() rebuild the exact architecture.
        config_path = p.with_suffix(".json")
        with config_path.open("w") as f:
            json.dump(self.config, f, indent=2)
        logger.info("Saved ProbabilisticAdapter to %s", path)

    @classmethod
    def load(cls, path: str) -> "ProbabilisticAdapter":
        """Load adapter from saved weights + the .json config sidecar written by save()."""
        _check_torch()
        import json
        p = Path(path)
        config_path = p.with_suffix(".json")
        with config_path.open("r") as f:
            config = json.load(f)
        model = cls(**config)
        # weights_only=True restricts torch.load to tensor data (no pickled code).
        state_dict = torch.load(p, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        logger.info("Loaded ProbabilisticAdapter from %s", path)
        return model
|
src/embeddings/space_alignment.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ex-MCR Cross-Space Alignment: CLAP Audio → CLIP Space.
|
| 3 |
+
|
| 4 |
+
Ex-MCR (Ex-Modal Contrastive Retrieval) projects CLAP audio embeddings
|
| 5 |
+
INTO CLIP space while keeping CLIP embeddings unchanged. This lets us
|
| 6 |
+
compute meaningful image-audio similarity and full 3-way Gramian volume.
|
| 7 |
+
|
| 8 |
+
Architecture decision: Ex-MCR over C-MCR because:
|
| 9 |
+
- Ex-MCR keeps CLIP embeddings frozen (no recomputation needed)
|
| 10 |
+
- C-MCR projects BOTH spaces into a new space (breaks everything)
|
| 11 |
+
|
| 12 |
+
The projector is a lightweight MLP:
|
| 13 |
+
CLAP 512-d → Linear(512, 512) → ReLU → Linear(512, 512) → L2 norm
|
| 14 |
+
|
| 15 |
+
If Ex-MCR weights are not available, falls back to an untrained identity
|
| 16 |
+
projection (which is equivalent to not using the projector).
|
| 17 |
+
|
| 18 |
+
CLAP compatibility note:
|
| 19 |
+
Our project uses `laion/clap-htsat-unfused`.
|
| 20 |
+
Ex-MCR uses `laion_clap_fullset_fusion` (different model).
|
| 21 |
+
If projections are poor with our CLAP, switch to the fusion model.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import logging
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
TORCH_AVAILABLE = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
TORCH_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ExMCRProjector:
    """
    Projects CLAP audio embeddings into CLIP space.

    Usage:
        proj = ExMCRProjector("models/exmcr/ex_clap.pt")
        audio_in_clip = proj.project_audio(clap_embedding)  # now comparable to CLIP

    Without trained weights (or without torch) the projector is an L2-normalizing
    identity passthrough.
    """

    def __init__(
        self,
        weights_path: Optional[str] = None,
        device: str = "cpu",
    ):
        """
        Args:
            weights_path: Path to Ex-MCR CLAP→CLIP projection weights (.pt).
                If None or file doesn't exist, uses identity (passthrough).
            device: Torch device for inference.
        """
        # _identity_mode stays True unless _load_weights succeeds below.
        self._model = None
        self._device = device
        self._identity_mode = True

        if weights_path and Path(weights_path).exists() and TORCH_AVAILABLE:
            self._load_weights(weights_path)
        elif weights_path and not Path(weights_path).exists():
            logger.warning(
                "Ex-MCR weights not found: %s — using identity projection", weights_path
            )

    def _load_weights(self, path: str) -> None:
        """Load Ex-MCR projection head from saved weights.

        Sniffs the state-dict key format to rebuild a matching
        Linear -> ReLU -> Linear MLP; falls back to identity mode when the
        format is unrecognized.
        """
        # weights_only=True restricts torch.load to tensor data (no pickled code).
        state_dict = torch.load(path, map_location=self._device, weights_only=True)

        # Detect architecture from state dict keys
        # Ex-MCR uses: layers.0.weight, layers.0.bias, layers.2.weight, layers.2.bias
        # or: 0.weight, 0.bias, 2.weight, 2.bias
        keys = list(state_dict.keys())

        # Build matching MLP
        if any("layers" in k for k in keys):
            # Format: layers.0.weight etc.
            in_dim = state_dict["layers.0.weight"].shape[1]
            hidden_dim = state_dict["layers.0.weight"].shape[0]
            out_dim = state_dict["layers.2.weight"].shape[0]
            model = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            )
            # Rename keys ("layers.0.weight" -> "0.weight") to match Sequential.
            new_state = {}
            for k, v in state_dict.items():
                new_key = k.replace("layers.", "")
                new_state[new_key] = v
            model.load_state_dict(new_state)
        elif any(k.startswith("0.") for k in keys):
            # Format: 0.weight, 0.bias, 2.weight, 2.bias (Sequential)
            in_dim = state_dict["0.weight"].shape[1]
            hidden_dim = state_dict["0.weight"].shape[0]
            out_dim = state_dict["2.weight"].shape[0]
            model = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            )
            model.load_state_dict(state_dict)
        else:
            # Generic: try to infer a 2-linear MLP from first/last weight shapes.
            # NOTE(review): load_state_dict will raise here if the checkpoint's
            # key names don't match nn.Sequential's — confirm this fallback is
            # exercised with a compatible format.
            weight_keys = [k for k in keys if "weight" in k]
            if len(weight_keys) >= 2:
                first_w = state_dict[weight_keys[0]]
                last_w = state_dict[weight_keys[-1]]
                in_dim = first_w.shape[1]
                hidden_dim = first_w.shape[0]
                out_dim = last_w.shape[0]
                model = nn.Sequential(
                    nn.Linear(in_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, out_dim),
                )
                model.load_state_dict(state_dict)
            else:
                # Leave _identity_mode=True — projector stays a passthrough.
                logger.warning("Unrecognized Ex-MCR weight format — using identity")
                return

        model.to(self._device)
        model.eval()
        self._model = model
        self._identity_mode = False
        logger.info(
            "Ex-MCR projector loaded: %d → %d → %d (from %s)",
            in_dim, hidden_dim, out_dim, path,
        )

    @property
    def is_identity(self) -> bool:
        """True if projector is passthrough (no trained weights loaded)."""
        return self._identity_mode

    def project_audio(self, clap_embedding: np.ndarray) -> np.ndarray:
        """
        Project CLAP audio embedding into CLIP space.

        Args:
            clap_embedding: CLAP audio embedding, shape (512,) or (N, 512).

        Returns:
            Projected embedding in CLIP space, L2-normalized.
            1-D in, 1-D out; batched in, batched out.
        """
        if self._identity_mode:
            # Passthrough: just L2-normalize (epsilon guards zero vectors).
            emb = clap_embedding.squeeze().astype(np.float32)
            norm = np.linalg.norm(emb) + 1e-12
            return emb / norm

        # NOTE(review): effectively unreachable — _identity_mode is only set
        # False inside _load_weights, which only runs when TORCH_AVAILABLE.
        if not TORCH_AVAILABLE:
            return clap_embedding.squeeze().astype(np.float32)

        # Remember whether the caller passed a single embedding so we can
        # return the same rank.
        was_1d = clap_embedding.ndim == 1 or (
            clap_embedding.ndim == 2 and clap_embedding.shape[0] == 1
        )
        emb = clap_embedding.squeeze()
        if emb.ndim == 1:
            emb = emb[np.newaxis, :]

        with torch.no_grad():
            x = torch.tensor(emb, dtype=torch.float32, device=self._device)
            projected = self._model(x)
            projected = F.normalize(projected, p=2, dim=-1)
        result = projected.cpu().numpy()

        if was_1d:
            return result.squeeze(0)
        return result

    def project_audio_batch(self, clap_embeddings: np.ndarray) -> np.ndarray:
        """
        Batch projection of CLAP audio embeddings into CLIP space.

        Args:
            clap_embeddings: Shape (N, 512).

        Returns:
            Projected embeddings in CLIP space, shape (N, 512), L2-normalized.
        """
        if self._identity_mode:
            # Passthrough: row-wise L2 normalization only.
            norms = np.linalg.norm(clap_embeddings, axis=1, keepdims=True) + 1e-12
            return (clap_embeddings / norms).astype(np.float32)

        # NOTE(review): effectively unreachable — see project_audio.
        if not TORCH_AVAILABLE:
            norms = np.linalg.norm(clap_embeddings, axis=1, keepdims=True) + 1e-12
            return (clap_embeddings / norms).astype(np.float32)

        with torch.no_grad():
            x = torch.tensor(clap_embeddings, dtype=torch.float32, device=self._device)
            projected = self._model(x)
            projected = F.normalize(projected, p=2, dim=-1)
        return projected.cpu().numpy()
|