atc-tts-mos / backend /session_manager.py
aether-raider
made minor changes
5df8bb5
# backend/session_manager.py
import json
import random
import time
import uuid
from typing import List, Dict, Any, Optional
from .models import Clip, get_display_model_name
class SessionManager:
    """Manage evaluation sessions, collected responses, and export.

    All state lives in memory on the instance:

    * ``sessions`` maps a session id to the session dict built by
      :meth:`create_session`.
    * ``responses`` maps a response type (``"mos"``, ``"ab"``,
      ``"feedback"``) to the list of response dicts saved so far, across
      all sessions.
    """

    # Number of clips sampled per model for the MOS section.
    MOS_CLIPS_PER_MODEL = 3
    # Upper bound on the number of pairs built per A/B comparison type.
    MAX_AB_PAIRS = 6
    # Speaker labels that are eligible for A/B pairing.
    AB_SPEAKERS = ("male", "female")
    # Rating dimensions collected for every MOS clip.
    RATING_DIMENSIONS = (
        "clarity",
        "pronunciation",
        "prosody",
        "naturalness",
        "overall",
    )
    # Comparison type -> session key holding the pairs presented for it.
    AB_PAIR_KEYS = {
        "model_vs_model": "ab_model_pairs",
        "gender_vs_gender": "ab_gender_pairs",
    }

    def __init__(self, data_manager):
        """Store the clip source and start with no sessions or responses.

        Args:
            data_manager: Object exposing ``load_clips()``, which returns the
                clips to draw from when a session is created.
        """
        self.data_manager = data_manager
        self.sessions: Dict[str, Dict[str, Any]] = {}
        self.responses: Dict[str, List[Dict[str, Any]]] = {
            "mos": [],
            "ab": [],
            "feedback": [],
        }

    # --------------------------
    # Session creation
    # --------------------------
    def create_session(self, rng: Optional[random.Random] = None) -> Dict[str, Any]:
        """Create, register, and return a new evaluation session.

        Args:
            rng: Optional random generator used for every random choice.
                When omitted, a fresh OS-seeded generator is used so that
                sessions created in the same clock tick still differ.

        Returns:
            The session dict, also stored in ``self.sessions``, with keys
            ``session_id``, ``created_at``, ``mos_clips``, ``ab_model_pairs``,
            ``ab_gender_pairs`` and ``completed``.
        """
        session_id = str(uuid.uuid4())
        # Materialise once: the clips are traversed several times below.
        clips = list(self.data_manager.load_clips())
        if rng is None:
            rng = random.Random()

        mos_clips = self._select_mos_clips(clips, rng)

        # Group by content (exercise + transcript) so that compared clips
        # always speak the same text.
        content_groups: Dict[Any, List[Clip]] = {}
        for clip in clips:
            key = (clip.exercise, clip.exercise_id, clip.transcript)
            content_groups.setdefault(key, []).append(clip)

        # Index the content groups by exercise id so each exercise can pick
        # one of its groups without rescanning every key.
        groups_by_exercise: Dict[Any, List[List[Clip]]] = {}
        for key, group in content_groups.items():
            groups_by_exercise.setdefault(key[1], []).append(group)

        exercise_ids = list(groups_by_exercise)

        # --- Model vs Model (same gender, same exercise) ---
        rng.shuffle(exercise_ids)
        ab_model_pairs = self._build_ab_pairs(
            exercise_ids, groups_by_exercise, rng, self._pick_model_pair
        )

        # --- Gender vs Gender (same model, same exercise) ---
        rng.shuffle(exercise_ids)
        ab_gender_pairs = self._build_ab_pairs(
            exercise_ids, groups_by_exercise, rng, self._pick_gender_pair
        )

        session_data: Dict[str, Any] = {
            "session_id": session_id,
            "created_at": time.time(),
            "mos_clips": mos_clips,
            "ab_model_pairs": ab_model_pairs,
            "ab_gender_pairs": ab_gender_pairs,
            "completed": False,
        }
        self.sessions[session_id] = session_data
        return session_data

    def _select_mos_clips(self, clips: List[Any], rng: random.Random) -> List[Any]:
        """Sample up to ``MOS_CLIPS_PER_MODEL`` clips from every model."""
        clips_by_model: Dict[Any, List[Any]] = {}
        for clip in clips:
            clips_by_model.setdefault(clip.model, []).append(clip)

        selected: List[Any] = []
        for model_clips in clips_by_model.values():
            count = min(self.MOS_CLIPS_PER_MODEL, len(model_clips))
            selected.extend(rng.sample(model_clips, count))
        return selected

    def _build_ab_pairs(
        self,
        exercise_ids: List[Any],
        groups_by_exercise: Dict[Any, List[List[Any]]],
        rng: random.Random,
        pick_pair,
    ) -> List[Any]:
        """Build at most ``MAX_AB_PAIRS`` pairs, at most one per exercise.

        For each exercise one content group is chosen at random, its clips
        are indexed as ``{model: {speaker: [clips]}}`` and ``pick_pair`` is
        asked for a ``(clip_a, clip_b)`` tuple (or ``None`` to skip).
        """
        pairs: List[Any] = []
        for exercise_id in exercise_ids:
            if len(pairs) >= self.MAX_AB_PAIRS:
                break
            group = rng.choice(groups_by_exercise[exercise_id])
            by_model: Dict[Any, Dict[Any, List[Any]]] = {}
            for clip in group:
                by_model.setdefault(clip.model, {}).setdefault(clip.speaker, []).append(clip)
            pair = pick_pair(by_model, rng)
            if pair is not None:
                pairs.append(pair)
        return pairs

    def _pick_model_pair(self, by_model: Dict[Any, Dict[Any, List[Any]]], rng: random.Random):
        """Pick two clips of the same speaker from two different models."""
        # A speaker is usable only if at least two models have clips for it.
        usable_speakers = [
            speaker
            for speaker in self.AB_SPEAKERS
            if sum(1 for speakers in by_model.values() if speakers.get(speaker)) >= 2
        ]
        if not usable_speakers:
            return None
        speaker = rng.choice(usable_speakers)
        candidates = [m for m, speakers in by_model.items() if speakers.get(speaker)]
        model_a, model_b = rng.sample(candidates, 2)
        return (
            rng.choice(by_model[model_a][speaker]),
            rng.choice(by_model[model_b][speaker]),
        )

    def _pick_gender_pair(self, by_model: Dict[Any, Dict[Any, List[Any]]], rng: random.Random):
        """Pick a (male, female) clip pair from a single model."""
        candidates = [
            m
            for m, speakers in by_model.items()
            if speakers.get("male") and speakers.get("female")
        ]
        if not candidates:
            return None
        speakers = by_model[rng.choice(candidates)]
        return (rng.choice(speakers["male"]), rng.choice(speakers["female"]))

    # --------------------------
    # Response storage helpers
    # --------------------------
    def save_response(self, response_type: str, response: Dict[str, Any]):
        """Append ``response`` to the list stored under ``response_type``.

        A ``"timestamp"`` key is added to the passed dict (in place) when it
        is missing. Unknown response types get a new list.
        """
        if "timestamp" not in response:
            response["timestamp"] = time.time()
        self.responses.setdefault(response_type, []).append(response)

    def save_mos_rating(
        self,
        session: Dict[str, Any],
        clip_id: str,
        model: str,
        clarity: Optional[int],
        pronunciation: Optional[int],
        prosody: Optional[int],
        naturalness: Optional[int],
        overall: Optional[int],
        comment: str,
        gender_mismatch: bool,
    ) -> None:
        """Save a single MOS rating for ``clip_id``.

        Does nothing when ``session`` is falsy. Each rating is stored as an
        ``int`` or ``None``. ``model`` is accepted for interface
        compatibility but is not stored in the response.
        """
        if not session:
            return
        mos_response = {
            "session_id": session["session_id"],
            "clip_id": clip_id,
            "clarity": int(clarity) if clarity is not None else None,
            "pronunciation": int(pronunciation) if pronunciation is not None else None,
            "prosody": int(prosody) if prosody is not None else None,
            "naturalness": int(naturalness) if naturalness is not None else None,
            "overall": int(overall) if overall is not None else None,
            "comment": comment or "",
            "gender_mismatch": bool(gender_mismatch),
            "timestamp": time.time(),
        }
        self.save_response("mos", mos_response)

    def save_ab_rating(
        self,
        session: Dict[str, Any],
        clip_a_id: str,
        clip_b_id: str,
        comparison_type: str,
        choice: str,
        comment: str,
        gender_mismatch_a: bool,
        gender_mismatch_b: bool,
    ) -> None:
        """Save a single A/B comparison. Does nothing when ``session`` is falsy."""
        if not session:
            return
        ab_response = {
            "session_id": session["session_id"],
            "clip_a_id": clip_a_id,
            "clip_b_id": clip_b_id,
            "comparison_type": comparison_type,
            "choice": choice,
            "comment": comment or "",
            "gender_mismatch_a": bool(gender_mismatch_a),
            "gender_mismatch_b": bool(gender_mismatch_b),
            "timestamp": time.time(),
        }
        self.save_response("ab", ab_response)

    @staticmethod
    def _to_int_or_none(value: Any) -> Optional[int]:
        """Return ``int(value)``, or ``None`` for falsy or non-numeric input."""
        if not value:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    # --------------------------
    # Bulk processing from JS JSON
    # --------------------------
    def process_mos_data(
        self,
        session: Dict[str, Any],
        mos_data_json: str,
    ) -> None:
        """Turn the MOS JSON payload into one response per presented clip.

        The payload is expected to be a JSON object mapping clip id to a
        ratings object. Every clip in ``session["mos_clips"]`` produces one
        entry in ``self.responses["mos"]``, rated or not. A missing,
        malformed or non-numeric rating is stored as ``None`` so a single
        bad value cannot abort the rest of the batch.
        """
        print(f"[DEBUG] process_mos_data called with JSON: '{mos_data_json}'")
        print(f"[DEBUG] Session ID: {session.get('session_id') if session else 'None'}")
        if not session or not mos_data_json:
            print(f"[DEBUG] Skipping MOS processing - session: {session is not None}, data length: {len(mos_data_json) if mos_data_json else 0}")
            return
        try:
            ratings_data = json.loads(mos_data_json)
        except (json.JSONDecodeError, TypeError) as e:
            print(f"[WARN] Failed to parse MOS data JSON: {e}")
            return
        if not isinstance(ratings_data, dict):
            print(f"[WARN] Error processing MOS data: expected a JSON object, got {type(ratings_data).__name__}")
            return
        try:
            # Get all clips that were presented to the user
            presented_clips = session.get("mos_clips", [])
            presented_clip_ids = {clip.id for clip in presented_clips}
            print(f"[DEBUG] Presented {len(presented_clip_ids)} MOS clips to user")
            print(f"[DEBUG] Received ratings for {len(ratings_data)} clips")
            # Process all presented clips, whether rated or not
            for clip in presented_clips:
                clip_id = clip.id
                ratings = ratings_data.get(clip_id)
                if not isinstance(ratings, dict):
                    # Missing or malformed entry: record the clip as unrated.
                    ratings = {}
                scores = {
                    dim: self._to_int_or_none(ratings.get(dim))
                    for dim in self.RATING_DIMENSIONS
                }
                mos_response = {
                    "session_id": session["session_id"],
                    "clip_id": clip_id,
                    **scores,
                    "comment": ratings.get("comment", ""),
                    "gender_mismatch": ratings.get("gender_mismatch", False),
                    "timestamp": time.time(),
                }
                self.save_response("mos", mos_response)
                # Log whether this clip was rated or not
                has_ratings = any(score is not None for score in scores.values())
                status = "rated" if has_ratings else "not rated"
                print(f"[INFO] Processed MOS clip {clip_id} ({status})")
        except Exception as e:
            # Best effort: a broken session must not crash the caller.
            print(f"[WARN] Error processing MOS data: {e}")

    def process_ab_data(
        self,
        session: Dict[str, Any],
        ab_data_json: str,
    ) -> None:
        """Turn the A/B JSON payload into one response per presented pair.

        The payload is expected to be a JSON object whose values are
        comparison objects carrying ``clip_a_id``, ``clip_b_id`` and
        ``comparison_type``. For every comparison type that occurs in the
        payload, all pairs presented for that type are saved to
        ``self.responses["ab"]`` (with ``choice`` ``None`` when unrated).
        Pairs already saved for this session are skipped, and non-object
        entries in the payload are ignored.
        """
        print(f"[DEBUG] process_ab_data called with JSON: '{ab_data_json}'")
        print(f"[DEBUG] Session ID: {session.get('session_id') if session else 'None'}")
        if not session or not ab_data_json:
            print(f"[DEBUG] Skipping AB processing - session: {session is not None}, data length: {len(ab_data_json) if ab_data_json else 0}")
            return
        try:
            comparisons_data = json.loads(ab_data_json)
        except (json.JSONDecodeError, TypeError) as e:
            print(f"[WARN] Failed to parse A/B data JSON: {e}")
            return
        if not isinstance(comparisons_data, dict):
            print(f"[WARN] Error processing A/B data: expected a JSON object, got {type(comparisons_data).__name__}")
            return
        try:
            print(f"[DEBUG] Received ratings for {len(comparisons_data)} comparisons")
            submitted = [c for c in comparisons_data.values() if isinstance(c, dict)]

            # Build a set of already saved pairs to avoid duplicates
            session_id = session["session_id"]
            existing_pairs = {
                (r.get("clip_a_id"), r.get("clip_b_id"))
                for r in self.responses.get("ab", [])
                if r.get("session_id") == session_id
            }
            print(f"[DEBUG] Already have {len(existing_pairs)} AB comparisons saved for this session")

            # Collect ALL comparison types present in the data, in first-seen
            # order so the order of saved responses is deterministic.
            comparison_types: List[Any] = []
            for comp_data in submitted:
                comparison_type = comp_data.get("comparison_type")
                if comparison_type and comparison_type not in comparison_types:
                    comparison_types.append(comparison_type)
            print(f"[DEBUG] Found comparison types in data: {comparison_types}")

            # Gather the presented pairs of every known comparison type
            all_presented_pairs = []
            for comparison_type in comparison_types:
                pair_key = (
                    self.AB_PAIR_KEYS.get(comparison_type)
                    if isinstance(comparison_type, str)
                    else None
                )
                if pair_key is None:
                    print(f"[WARN] Unknown comparison type: {comparison_type}")
                    continue
                pairs = session.get(pair_key, [])
                print(f"[DEBUG] Processing {comparison_type.replace('_', '-')} pairs: {len(pairs)} pairs presented")
                all_presented_pairs.extend(
                    (clip_a, clip_b, comparison_type) for clip_a, clip_b in pairs
                )

            # Process all presented pairs
            for clip_a, clip_b, comparison_type in all_presented_pairs:
                clip_a_id = clip_a.id
                clip_b_id = clip_b.id
                # Skip if we've already saved this pair
                if (clip_a_id, clip_b_id) in existing_pairs:
                    print(f"[DEBUG] Skipping duplicate comparison: {clip_a_id} vs {clip_b_id}")
                    continue
                # JS sends numeric keys ("1", "2", etc.), so the user's
                # rating is located by clip ids rather than by key.
                comparison = next(
                    (
                        comp_data
                        for comp_data in submitted
                        if comp_data.get("clip_a_id") == clip_a_id
                        and comp_data.get("clip_b_id") == clip_b_id
                    ),
                    {},
                )
                ab_response = {
                    "session_id": session_id,
                    "clip_a_id": clip_a_id,
                    "clip_b_id": clip_b_id,
                    "comparison_type": comparison_type,
                    "choice": comparison.get("choice"),  # Can be None if not rated
                    "comment": comparison.get("comment", ""),
                    # Support both model_vs_model (gender_mismatch_a/b)
                    # and gender_vs_gender (gender_mismatch_male/female)
                    "gender_mismatch_a": comparison.get("gender_mismatch_a", False)
                    or comparison.get("gender_mismatch_male", False),
                    "gender_mismatch_b": comparison.get("gender_mismatch_b", False)
                    or comparison.get("gender_mismatch_female", False),
                    "timestamp": time.time(),
                }
                self.save_response("ab", ab_response)
                existing_pairs.add((clip_a_id, clip_b_id))  # Mark as saved
                status = "rated" if comparison.get("choice") else "not rated"
                print(f"[INFO] Processed A/B comparison {clip_a_id} vs {clip_b_id} ({status})")
        except Exception as e:
            # Best effort: a broken session must not crash the caller.
            print(f"[WARN] Error processing A/B data: {e}")

    # --------------------------
    # Export
    # --------------------------
    @staticmethod
    def _find_pair(pairs: List[Any], clip_a_id: Any, clip_b_id: Any):
        """Return the first ``(clip_a, clip_b)`` in ``pairs`` with these ids, else ``None``."""
        for pair_a, pair_b in pairs:
            if pair_a.id == clip_a_id and pair_b.id == clip_b_id:
                return pair_a, pair_b
        return None

    def export_session(self, session_id: str) -> Dict[str, Any]:
        """Build a fully annotated export dict for a given session.

        Responses are joined with the clip metadata stored in the session.
        Responses whose clip (or pair) is not part of the session are left
        out. An unknown ``session_id`` yields an export with empty rating
        lists.

        Returns:
            Dict with keys ``session_metadata``, ``mos_ratings``,
            ``ab_comparisons`` and ``overall_feedback``.
        """
        session = self.sessions.get(session_id, {})

        # Create detailed MOS responses with full clip metadata
        detailed_mos_responses = []
        session_mos_clips = {clip.id: clip for clip in session.get("mos_clips", [])}
        for r in self.responses.get("mos", []):
            if r.get("session_id") != session_id:
                continue
            clip_id = r.get("clip_id")
            clip = session_mos_clips.get(clip_id)
            if clip is None:
                continue
            detailed_mos_responses.append(
                {
                    # Session metadata
                    "session_id": session_id,
                    "response_timestamp": r.get("timestamp", time.time()),
                    # Full clip metadata
                    "clip_id": clip_id,
                    "exercise": clip.exercise,
                    "exercise_id": clip.exercise_id,
                    "transcript": clip.transcript,
                    "model": clip.model,  # Original model name
                    "display_model": get_display_model_name(clip.model),  # Anonymized name
                    "speaker": clip.speaker,
                    # MOS ratings
                    "clarity": r.get("clarity"),
                    "pronunciation": r.get("pronunciation"),
                    "prosody": r.get("prosody"),
                    "naturalness": r.get("naturalness"),
                    "overall": r.get("overall"),
                    "comment": r.get("comment", ""),
                    # Quality control flag: True if user flagged wrong gender
                    "gender_mismatch": r.get("gender_mismatch", False),
                    # Response type
                    "evaluation_type": "mos_rating",
                }
            )

        # Create detailed A/B responses with full clip metadata
        detailed_ab_responses = []
        pairs_by_type = {
            "model_vs_model": session.get("ab_model_pairs", []),
            "gender_vs_gender": session.get("ab_gender_pairs", []),
        }
        for r in self.responses.get("ab", []):
            if r.get("session_id") != session_id:
                continue
            comparison_type = r.get("comparison_type")
            # Only the two known comparison types can be resolved to clips.
            if comparison_type == "model_vs_model" or comparison_type == "gender_vs_gender":
                candidate_pairs = pairs_by_type[comparison_type]
            else:
                continue
            found = self._find_pair(candidate_pairs, r.get("clip_a_id"), r.get("clip_b_id"))
            if found is None:
                continue
            clip_a, clip_b = found
            detailed_ab_responses.append(
                {
                    # Session metadata
                    "session_id": session_id,
                    "response_timestamp": r.get("timestamp", time.time()),
                    # Comparison metadata
                    "comparison_type": comparison_type,
                    "choice": r.get("choice"),
                    "comment": r.get("comment", ""),
                    # Clip A metadata
                    "clip_a_id": clip_a.id,
                    "clip_a_exercise": clip_a.exercise,
                    "clip_a_exercise_id": clip_a.exercise_id,
                    "clip_a_transcript": clip_a.transcript,
                    "clip_a_model": clip_a.model,
                    "clip_a_display_model": get_display_model_name(clip_a.model),
                    "clip_a_speaker": clip_a.speaker,
                    # Clip B metadata
                    "clip_b_id": clip_b.id,
                    "clip_b_exercise": clip_b.exercise,
                    "clip_b_exercise_id": clip_b.exercise_id,
                    "clip_b_transcript": clip_b.transcript,
                    "clip_b_model": clip_b.model,
                    "clip_b_display_model": get_display_model_name(clip_b.model),
                    "clip_b_speaker": clip_b.speaker,
                    # Quality control flags: True if the clip has the wrong gender
                    "gender_mismatch_a": r.get("gender_mismatch_a", False),
                    "gender_mismatch_b": r.get("gender_mismatch_b", False),
                    # Response type
                    "evaluation_type": "ab_comparison",
                }
            )

        return {
            "session_metadata": {
                "session_id": session_id,
                "created_at": session.get("created_at"),
                "completed": session.get("completed", False),
                "exported_at": time.time(),
                "total_mos_ratings": len(detailed_mos_responses),
                "total_ab_comparisons": len(detailed_ab_responses),
            },
            "mos_ratings": detailed_mos_responses,
            "ab_comparisons": detailed_ab_responses,
            "overall_feedback": [
                r
                for r in self.responses.get("feedback", [])
                if r.get("session_id") == session_id
            ],
        }