atc-tts-mos / backend /models.py
aether-raider
testing sample audios fix
ef76c0c
# backend/models.py
from dataclasses import dataclass
from typing import Any, Optional
# Anonymization table: internal model key -> label shown in place of the
# real model name.
MODEL_NAME_MAP = {
    # Coqui XTTS
    "xtts": "Model A",
    # Sesame CSM
    "csm": "Model B",
    # Orpheus
    "orpheus": "Model C",
}
def get_display_model_name(internal_name: str) -> str:
    """Return the anonymized display label for an internal model name.

    Names present in ``MODEL_NAME_MAP`` are translated to their mapped label;
    any other name is returned upper-cased.
    """
    # Computed up front (as the original did) so non-string input fails the
    # same way regardless of whether the key is in the map.
    fallback_label = internal_name.upper()
    return MODEL_NAME_MAP.get(internal_name, fallback_label)
# File extension -> MIME type used when embedding an audio file as a data URL.
_AUDIO_MIME_TYPES = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
}


def _audio_file_to_data_url(path):
    """Read the audio file at *path* and return it as a base64 data URL.

    Returns None when the extension is not a supported audio type, when the
    file does not exist, or when it cannot be read.
    """
    import base64

    # Case-insensitive match so "clip.WAV" is treated like "clip.wav".
    lowered = path.lower()
    mime = next(
        (m for ext, m in _AUDIO_MIME_TYPES.items() if lowered.endswith(ext)),
        None,
    )
    if mime is None:
        return None

    try:
        with open(path, "rb") as f:
            audio_bytes = f.read()
    except FileNotFoundError:
        # A missing file is an expected case; stay silent like the original
        # existence check did.
        return None
    except (OSError, ValueError) as e:
        print(f"[WARN] Failed to convert file to base64 URL: {e}")
        return None

    b64 = base64.b64encode(audio_bytes).decode("ascii")
    # Label the payload with the file's real type rather than always "wav".
    return f"data:{mime};base64,{b64}"


def _audio_array_to_data_url(array, sr):
    """Encode ``(array, sr)`` as a WAV base64 data URL.

    Returns None when ``soundfile`` is not installed or encoding fails.
    """
    try:
        import base64
        import io

        import numpy as np

        try:
            import soundfile as sf
        except ImportError:
            # soundfile is optional; without it there is no encoder available.
            return None

        buf = io.BytesIO()
        sf.write(buf, np.array(array), int(sr), format="WAV")
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"data:audio/wav;base64,{b64}"
    except Exception as e:
        # Best-effort conversion: any encoder/array failure is reported and
        # turned into None instead of propagating to the caller.
        print(f"[WARN] Failed to convert audio tuple to base64 URL: {e}")
        return None


def audio_to_base64_url(audio_data):
    """Convert audio data to a base64 data URL for HTML audio elements.

    Args:
        audio_data: One of
            * a string that already starts with ``"data:audio/"`` (returned
              unchanged),
            * a path string ending in .wav/.mp3/.flac/.ogg (file is read and
              embedded with the MIME type matching its extension),
            * a 2-tuple ``(array, sample_rate)`` (encoded as WAV via
              ``soundfile``).

    Returns:
        The data URL string, or None when the input is not one of the
        supported forms or the conversion fails.
    """
    if isinstance(audio_data, str):
        if audio_data.startswith("data:audio/"):
            return audio_data
        return _audio_file_to_data_url(audio_data)

    if isinstance(audio_data, tuple) and len(audio_data) == 2:
        array, sr = audio_data
        return _audio_array_to_data_url(array, sr)

    return None
# Data models
@dataclass
class Clip:
    """A single audio clip record together with its descriptive metadata."""

    # Identifier of the clip.
    id: str
    # Name of the model associated with the clip.
    model: str
    # Speaker voice: "male" or "female".
    speaker: str
    # Exercise the clip belongs to, and that exercise's identifier.
    exercise: str
    exercise_id: str
    # Text associated with the clip.
    transcript: str
    # Either a string URL or an (array, sample_rate) tuple.
    audio_url: Any
    # Clip length in seconds; None when not provided.
    duration_s: Optional[float] = None
@dataclass
class MOSResponse:
    """One rating submission for a single clip within a session.

    The five rating fields are plain integers; no scale or range is enforced
    by this class.
    """

    # Session the rating belongs to and the clip being rated.
    session_id: str
    clip_id: str
    # Per-dimension integer ratings.
    clarity: int
    pronunciation: int
    prosody: int
    naturalness: int
    overall: int
    # Optional free-text remark; empty by default.
    comment: str = ""
    # Set when the clip's voice was flagged as the wrong gender.
    gender_mismatch: bool = False
@dataclass
class ABResponse:
    """One A/B comparison submission between two clips within a session."""

    # Session the comparison belongs to.
    session_id: str
    # Identifiers of the two clips being compared.
    clip_a_id: str
    clip_b_id: str
    # Kind of comparison: "model_vs_model" or "gender_vs_gender".
    comparison_type: str
    # Selected option: "A", "B", or "tie".
    choice: str
    # Optional free-text remark; empty by default.
    comment: str = ""
    # Set when clip A's voice was flagged as the wrong gender.
    gender_mismatch_a: bool = False
    # Set when clip B's voice was flagged as the wrong gender.
    gender_mismatch_b: bool = False