# ustwo-api / src/stage4/main.py
# Uploaded by: asdfasdfqrqwer
# Deployed from GitHub 2026-04-23T03:56:31Z (commit c857b85, 25.5 kB)
import json
import logging
import random
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy.orm import Session
from .database import init_db, get_db, SessionLocal
from . import models
# Root logging: INFO and above, formatted as "<time> <logger> <level> <message>".
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI(title="UsTwo API", version="0.3.0")

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests (Starlette's CORSMiddleware
# echoes the caller's Origin in that case — verify for the installed version).
# Fine for an open demo API; restrict origins before adding cookie/auth sessions.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Uploaded audio is stored here; the directory is created at import time.
UPLOAD_DIR = Path("data") / "samples"
UPLOAD_DIR.mkdir(exist_ok=True, parents=True)

# Accepted upload suffixes (compared lower-cased) and the size cap.
ALLOWED_EXTENSIONS = {".wav", ".m4a", ".mp3", ".ogg"}
MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # 50 MB
@app.on_event("startup")
def on_startup():
    """Startup hook: run init_db(), then seed the demo call records."""
    # Order matters: presumably init_db creates the tables the seeding step
    # inserts into — keep init_db first.
    for step in (init_db, _seed_demo_calls):
        step()
def _seed_demo_calls():
    """Ensure the demo Call rows used by test scenarios exist.

    For each demo ID without a row, insert a Call with an empty audio_path and
    status "uploaded". analyze() then serves these either from a pre-computed
    ``<call_id>_stage2.json`` under data/ or data/samples/, or, when neither
    that file nor audio exists, from the mock Stage 2 generator.
    """
    # Use an explicitly managed session. The previous ``next(get_db())`` left
    # the dependency generator unfinished, so any teardown get_db performs after
    # its ``yield`` (presumably closing the session) did not run deterministically.
    db = SessionLocal()
    try:
        demo_ids = ["test_happy", "test_tense", "test_mixed", "test_neutral"]
        for call_id in demo_ids:
            existing = db.query(models.Call).filter_by(id=call_id).first()
            if existing is None:
                db.add(models.Call(id=call_id, audio_path="", status="uploaded"))
        db.commit()
    finally:
        db.close()
# ─── Health ─────────────────────────────────────────────────
@app.get("/api/health")
def health():
    """Liveness check: always reports "ok" with the current UTC time (ISO 8601)."""
    now_utc = datetime.now(timezone.utc)
    return {"status": "ok", "timestamp": now_utc.isoformat()}
# ─── Upload ─────────────────────────────────────────────────
@app.post("/api/upload")
async def upload_audio(file: UploadFile = File(...), db: Session = Depends(get_db)):
    """Accept an audio upload, store it under UPLOAD_DIR and register a Call row.

    Rejected with HTTP 400:
    - a missing filename;
    - an extension not in ALLOWED_EXTENSIONS;
    - an empty body;
    - a body larger than MAX_UPLOAD_BYTES.

    Returns:
        {"status": "success", "call_id": <new id>, "filename": <stored file name>}
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")
    # Only the suffix of the client-supplied name is used. The stored name is
    # generated server-side from call_id.
    ext = Path(file.filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))  # stable order in message
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {ext}. Allowed: {allowed}",
        )

    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without buffering an arbitrarily large body in memory.
    content = await file.read(MAX_UPLOAD_BYTES + 1)
    if not content:
        raise HTTPException(status_code=400, detail="Empty file")
    if len(content) > MAX_UPLOAD_BYTES:
        max_mb = MAX_UPLOAD_BYTES // (1024 * 1024)
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Max: {max_mb}MB",
        )

    call_id = f"call_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
    save_path = UPLOAD_DIR / f"{call_id}{ext}"
    save_path.write_bytes(content)

    try:
        db.add(models.Call(id=call_id, audio_path=str(save_path)))
        db.commit()
    except Exception:
        # Don't leave an orphaned audio file behind when the DB insert fails.
        db.rollback()
        save_path.unlink(missing_ok=True)
        raise
    return {"status": "success", "call_id": call_id, "filename": save_path.name}
# ─── Pipeline runner ─────────────────────────────────────────
def _run_full_pipeline(audio_path: str, call_id: str):
    """Run Stage 1 -> Stage 2 -> Stage 3 on a real audio file.

    Stages:
    - Stage 1: diarization + ASR.
    - Stage 2: emotion analysis.
    - Stage 3: character reactions, garden update and recap. It runs
      rule-based here (use_llm=False).

    Returns:
        Tuple ``(stage3_output, stage1_output, stage2_output)``.

    Raises:
        ImportError: if a stage module cannot be imported (e.g. missing ML
            dependencies). The background worker catches this and falls back
            to mock data.
    """
    from src.stage1.process import process as run_stage1
    from src.stage2.process import process as run_stage2
    from src.stage3.process import process as run_stage3

    logger.info("Pipeline start: %s (%s)", call_id, audio_path)

    transcript_out = run_stage1(audio_path)  # diarization + ASR
    logger.info(
        "Stage 1 done: %d segments, %.1fs",
        len(transcript_out.segments),
        transcript_out.processing_info.processing_time_sec,
    )

    emotion_out = run_stage2(transcript_out)  # emotion analysis
    logger.info(
        "Stage 2 done: %d emotions, speakers=%s",
        len(emotion_out.emotions),
        list(emotion_out.speaker_summaries.keys()),
    )

    # Rule-based fallback for now: Stage 3 runs without the LLM.
    result_out = run_stage3(emotion_out, segments=transcript_out.segments, use_llm=False)
    logger.info("Stage 3 done: %s", call_id)
    return result_out, transcript_out, emotion_out
_MOCK_SCENARIOS = [
# (sp0_emotion, sp1_emotion) β€” cycles through for consistent demo results
("joy", "joy"),
("joy", "sadness"),
("anger", "sadness"),
("neutral", "neutral"),
("surprise", "joy"),
]
_mock_index = 0
# Per-scenario segment sequences of (speaker_id, emotion, confidence), keyed by
# the (speaker_0, speaker_1) pair from _MOCK_SCENARIOS. They give the Emotional
# Landscape 8-10 alternating segments. Built once at import, not on every call.
_MOCK_SEGMENT_PATTERNS = {
    ("joy", "joy"): [
        ("speaker_0", "joy", 0.85), ("speaker_1", "joy", 0.78),
        ("speaker_0", "surprise", 0.65), ("speaker_1", "joy", 0.82),
        ("speaker_0", "joy", 0.88), ("speaker_1", "neutral", 0.70),
        ("speaker_0", "joy", 0.80), ("speaker_1", "joy", 0.75),
    ],
    ("joy", "sadness"): [
        ("speaker_0", "joy", 0.80), ("speaker_1", "neutral", 0.72),
        ("speaker_0", "joy", 0.75), ("speaker_1", "sadness", 0.68),
        ("speaker_0", "neutral", 0.70), ("speaker_1", "sadness", 0.78),
        ("speaker_0", "joy", 0.82), ("speaker_1", "sadness", 0.65),
    ],
    ("anger", "sadness"): [
        ("speaker_0", "neutral", 0.72), ("speaker_1", "neutral", 0.70),
        ("speaker_0", "anger", 0.78), ("speaker_1", "sadness", 0.74),
        ("speaker_0", "anger", 0.82), ("speaker_1", "fear", 0.65),
        ("speaker_0", "neutral", 0.68), ("speaker_1", "sadness", 0.80),
        ("speaker_0", "anger", 0.75), ("speaker_1", "sadness", 0.72),
    ],
    ("neutral", "neutral"): [
        ("speaker_0", "neutral", 0.85), ("speaker_1", "neutral", 0.82),
        ("speaker_0", "neutral", 0.78), ("speaker_1", "joy", 0.60),
        ("speaker_0", "neutral", 0.80), ("speaker_1", "neutral", 0.75),
        ("speaker_0", "joy", 0.62), ("speaker_1", "neutral", 0.80),
    ],
    ("surprise", "joy"): [
        ("speaker_0", "neutral", 0.72), ("speaker_1", "joy", 0.75),
        ("speaker_0", "surprise", 0.80), ("speaker_1", "joy", 0.82),
        ("speaker_0", "surprise", 0.85), ("speaker_1", "surprise", 0.70),
        ("speaker_0", "joy", 0.78), ("speaker_1", "joy", 0.80),
    ],
}


def _generate_mock_stage2(call_id: str):
    """Build a deterministic mock Stage 2 output for demos without ML dependencies.

    Each call takes the next (speaker_0, speaker_1) emotion pair from
    ``_MOCK_SCENARIOS``, round-robin via the module-level ``_mock_index``.
    Consecutive demo runs therefore differ but are reproducible.

    Args:
        call_id: ID stamped onto the returned Stage2Output.

    Returns:
        Stage2Output with 8-10 alternating speaker_0/speaker_1 EmotionResults
        and one SpeakerSummary per speaker.
    """
    global _mock_index
    from src.common.schemas import Stage2Output, EmotionResult, SpeakerSummary

    # NOTE(review): the counter is not lock-protected. Concurrent callers
    # (request handlers and background workers) may occasionally get the same scenario.
    sp0_emotion, sp1_emotion = _MOCK_SCENARIOS[_mock_index % len(_MOCK_SCENARIOS)]
    _mock_index += 1

    def _fixed_summary(emotion: str) -> SpeakerSummary:
        # The literal {emotion: 0.70, "neutral": 0.30} would collapse to
        # {"neutral": 0.30} for a neutral speaker (duplicate key, last one wins),
        # leaving a distribution that sums to 0.3. Give neutral the whole mass.
        if emotion == "neutral":
            distribution = {"neutral": 1.0}
        else:
            distribution = {emotion: 0.70, "neutral": 0.30}
        return SpeakerSummary(
            dominant_emotion=emotion,
            emotion_distribution=distribution,
            avg_confidence=0.82,
        )

    pattern = _MOCK_SEGMENT_PATTERNS.get(
        (sp0_emotion, sp1_emotion), _MOCK_SEGMENT_PATTERNS[("neutral", "neutral")]
    )
    emotions = [
        EmotionResult(
            speaker_id=spk, segment_id=i,
            audio_emotion=emo, audio_confidence=conf,
            # The text channel is made slightly less confident, floored at 0.5.
            text_emotion=emo, text_confidence=max(0.5, conf - 0.1),
            fused_emotion=emo, fused_confidence=conf,
        )
        for i, (spk, emo, conf) in enumerate(pattern)
    ]
    return Stage2Output(
        call_id=call_id,
        emotions=emotions,
        speaker_summaries={
            "speaker_0": _fixed_summary(sp0_emotion),
            "speaker_1": _fixed_summary(sp1_emotion),
        },
    )
# ─── Background pipeline worker ──────────────────────────────
def _quick_recap(call_id: str, audio_path: str, db: Session):
    """Phase 1 of the pipeline: a fast recap built from the transcript alone.

    Steps:
    1. Transcribe *audio_path* with the OpenAI Whisper API.
    2. Build a recap card with ``generate_recap_from_transcript`` (Claude).
    3. Store a partial AnalysisResult (neutral characters, zero garden growth).
    4. Move the call to status 'preview' and commit on *db*.

    Returns early without doing anything when OPENAI_API_KEY is not set.
    Errors from the API calls or the database propagate to the caller.
    """
    import os

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        logger.info("OPENAI_API_KEY not set, skipping quick recap for %s", call_id)
        return

    from openai import OpenAI
    from src.stage3.recap_generator import generate_recap_from_transcript

    logger.info("Quick recap start: %s", call_id)

    # Whisper API transcription
    client = OpenAI(api_key=api_key)
    with open(audio_path, "rb") as audio_file:
        whisper_result = client.audio.transcriptions.create(
            model="whisper-1",
            file=audio_file,
        )
    text = whisper_result.text
    logger.info("Whisper API done: %d chars", len(text))

    # Recap card generated from the transcript text
    card = generate_recap_from_transcript(text)

    # Placeholder character reactions until Phase 2 supplies real ones.
    placeholder_reactions = [
        {"speaker_id": speaker, "solo_state": "neutral", "pair_state": "sitting_together"}
        for speaker in ("speaker_0", "speaker_1")
    ]
    partial = {
        "call_id": call_id,
        "character_reactions": placeholder_reactions,
        "garden_update": {"growth_delta": 0, "total_level": 1, "mood": "happy"},
        "recap_card": card.model_dump(),
    }

    db.add(models.AnalysisResult(
        call_id=call_id,
        stage3_json=json.dumps(partial),
        blue_emotion="neutral",
        pink_emotion="neutral",
        garden_delta=0,
    ))
    call = db.query(models.Call).filter(models.Call.id == call_id).first()
    if call:
        call.status = "preview"
    db.commit()
    logger.info("Quick recap done: %s β€” '%s'", call_id, card.title)
def _run_pipeline_background(call_id: str, audio_path: str):
    """Run the 2-phase analysis pipeline for one call (used as a thread target).

    Phase 1: quick recap via ``_quick_recap``. It is best-effort: a failure is
             logged and does not block Phase 2.
    Phase 2: full ML pipeline (Stage 1 -> 2 -> 3). It falls back to mock Stage 2
             data when ML dependencies cannot be imported. On success the call's
             status becomes 'done'. If a result row already exists (e.g. Phase 1's
             preview), its recap card is kept and the rest is replaced.

    Any other error marks the call 'error' with the exception text. The function
    opens its own session via SessionLocal and always closes it.
    """
    db = SessionLocal()
    try:
        call = db.query(models.Call).filter(models.Call.id == call_id).first()
        if not call:
            return
        call.status = "analyzing"
        db.commit()

        # Phase 1: Quick recap (best-effort; failure must not block Phase 2).
        try:
            _quick_recap(call_id, audio_path, db)
        except Exception as e:
            logger.warning("Quick recap failed for %s: %s", call_id, e)
            # Roll back half-written Phase 1 state. This keeps the session usable
            # and stops a stray partial row from being committed with Phase 2.
            db.rollback()

        # Phase 2: Full ML pipeline, with mock fallback when ML deps are missing.
        stage1_out = None
        try:
            logger.info("Starting full pipeline for %s", call_id)
            stage3_result, stage1_out, stage2_out = _run_full_pipeline(audio_path, call_id)
            pipeline_mode = "full"
        except ImportError as e:
            logger.warning("ML deps missing (%s), falling back to mock", e)
            from src.stage3.process import process as stage3_process
            stage2_out = _generate_mock_stage2(call_id)
            stage3_result = stage3_process(stage2_out, use_llm=False)
            pipeline_mode = "mock"

        call = db.query(models.Call).filter(models.Call.id == call_id).first()
        if stage1_out is not None:
            # Duration is only known when Stage 1 actually ran.
            call.duration = stage1_out.duration

        result_dict = stage3_result.model_dump()
        result_dict["emotions"] = [e.model_dump() for e in stage2_out.emotions]
        result_dict["stage2_output"] = stage2_out.model_dump()

        from src.stage3.character_mapping import select_representative_emotion
        summaries = stage2_out.speaker_summaries or {}
        blue_emo = select_representative_emotion(summaries["speaker_0"]) if "speaker_0" in summaries else None
        pink_emo = select_representative_emotion(summaries["speaker_1"]) if "speaker_1" in summaries else None
        garden_delta = result_dict.get("garden_update", {}).get("growth_delta", 0)

        # _quick_recap returns silently (no row) when OPENAI_API_KEY is unset.
        # Decide "preview" by whether a row actually exists, not by "no exception".
        existing = db.query(models.AnalysisResult).filter(
            models.AnalysisResult.call_id == call_id
        ).first()
        has_preview = existing is not None
        if existing is not None:
            # Keep the Phase 1 recap card; replace everything else with Phase 2 output.
            preview_recap = json.loads(existing.stage3_json or "{}").get("recap_card")
            if preview_recap is not None:
                result_dict["recap_card"] = preview_recap
            existing.stage3_json = json.dumps(result_dict)
            existing.blue_emotion = blue_emo
            existing.pink_emotion = pink_emo
            existing.garden_delta = garden_delta
        else:
            db.add(models.AnalysisResult(
                call_id=call_id,
                stage3_json=json.dumps(result_dict),
                blue_emotion=blue_emo,
                pink_emotion=pink_emo,
                garden_delta=garden_delta,
            ))

        _update_garden(db, result_dict)
        call.status = "done"
        db.commit()
        logger.info("Background pipeline done: %s (mode=%s, preview=%s)", call_id, pipeline_mode, has_preview)
    except Exception as e:
        logger.error("Background pipeline failed for %s: %s", call_id, e, exc_info=True)
        db.rollback()
        try:
            call = db.query(models.Call).filter(models.Call.id == call_id).first()
            if call:
                call.status = "error"
                call.error_message = str(e)
                db.commit()
        except Exception:
            db.rollback()
    finally:
        db.close()
# ─── Analyze ────────────────────────────────────────────────
def _analyze_from_stage2(db: Session, call, call_id: str, stage2, pipeline_mode: str) -> dict:
    """Run Stage 3 synchronously on a Stage 2 output, persist it, and build the response.

    Shared by analyze()'s two synchronous paths: a pre-computed Stage 2 JSON
    file, and mock data when no audio file is available. Stage 3 uses the LLM
    only when ANTHROPIC_API_KEY is set. Marks *call* 'done' and commits.
    """
    import os
    from src.stage3.process import process as stage3_process
    from src.stage3.character_mapping import select_representative_emotion

    has_api_key = bool(os.environ.get("ANTHROPIC_API_KEY"))
    stage3_result = stage3_process(stage2, use_llm=has_api_key)
    result_dict = stage3_result.model_dump()
    result_dict["emotions"] = [e.model_dump() for e in stage2.emotions]

    summaries = stage2.speaker_summaries or {}
    blue_emo = select_representative_emotion(summaries["speaker_0"]) if "speaker_0" in summaries else None
    pink_emo = select_representative_emotion(summaries["speaker_1"]) if "speaker_1" in summaries else None
    db.add(models.AnalysisResult(
        call_id=call_id,
        stage3_json=json.dumps(result_dict),
        blue_emotion=blue_emo,
        pink_emotion=pink_emo,
        garden_delta=result_dict.get("garden_update", {}).get("growth_delta", 0),
    ))
    _update_garden(db, result_dict)
    call.status = "done"
    db.commit()
    return {
        "status": "done",
        "call_id": call_id,
        "pipeline_mode": pipeline_mode,
        "result": result_dict,
    }


@app.post("/api/analyze", status_code=202)
def analyze(call_id: str, db: Session = Depends(get_db)):
    """Start (or short-circuit) analysis of a call.

    Resolution order:
    1. Unknown call -> 404.
    2. Status 'analyzing' or 'preview' -> report that work is in progress.
    3. Status 'done' with a stored result -> return that result.
    4. A pre-computed Stage 2 JSON exists (data/ first, then data/samples/) ->
       run Stage 3 synchronously (pipeline_mode 'stage2_json'). A JSON file
       that fails validation -> 400.
    5. No usable audio file -> run Stage 3 on mock Stage 2 data synchronously
       (pipeline_mode 'mock').
    6. Otherwise claim the call and start the background pipeline.

    The route's default status code is 202 for every response. Poll
    GET /api/analyze/{call_id}/status for progress. When the status is 'done',
    the result is also available at GET /api/calls/{call_id}.
    """
    call = db.query(models.Call).filter(models.Call.id == call_id).first()
    if not call:
        raise HTTPException(status_code=404, detail="Call not found")
    if call.status in ("analyzing", "preview"):
        return {"status": "analyzing", "call_id": call_id, "message": "Already in progress"}
    if call.status == "done":
        # Already analyzed: return the stored result.
        existing = db.query(models.AnalysisResult).filter(
            models.AnalysisResult.call_id == call_id
        ).first()
        if existing:
            return {
                "status": "done",
                "call_id": call_id,
                "result": json.loads(existing.stage3_json),
            }
    audio_path = call.audio_path

    # Pre-computed Stage 2 JSON (Seungjae's pipeline output or test data).
    stage2_json_path = Path("data") / f"{call_id}_stage2.json"
    if not stage2_json_path.exists():
        stage2_json_path = Path("data/samples") / f"{call_id}_stage2.json"
    if stage2_json_path.exists():
        # Real Stage 2 output: run Stage 3 only.
        from src.common.schemas import Stage2Output
        try:
            stage2 = Stage2Output.model_validate_json(stage2_json_path.read_text(encoding="utf-8"))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid Stage 2 JSON: {e}") from e
        return _analyze_from_stage2(db, call, call_id, stage2, "stage2_json")

    if not audio_path or not Path(audio_path).exists():
        # No audio file and no Stage 2 JSON: run on mock data synchronously.
        return _analyze_from_stage2(db, call, call_id, _generate_mock_stage2(call_id), "mock")

    # Claim the call before starting the worker. A repeated request arriving
    # before the thread's own commit then sees 'analyzing' instead of starting a
    # duplicate pipeline (duplicate result rows, double garden growth).
    call.status = "analyzing"
    db.commit()

    thread = threading.Thread(
        target=_run_pipeline_background,
        args=(call_id, audio_path),
        daemon=True,
    )
    try:
        thread.start()
    except RuntimeError as e:
        # The worker could not be spawned. Release the claim so the call can be
        # retried instead of staying 'analyzing' forever.
        call.status = "error"
        call.error_message = str(e)
        db.commit()
        raise
    return {"status": "analyzing", "call_id": call_id, "message": "Pipeline started"}
@app.get("/api/analyze/{call_id}/status")
def analyze_status(call_id: str, db: Session = Depends(get_db)):
    """Report analysis progress for a call.

    - 'done' / 'preview': the stored result is included, or None if no row exists.
    - 'error': the stored error message is included.
    - Any other status: only status and call_id are returned.
    - Unknown call: 404.
    """
    call = db.query(models.Call).filter(models.Call.id == call_id).first()
    if not call:
        raise HTTPException(status_code=404, detail="Call not found")

    status = call.status
    if status == "error":
        return {
            "status": "error",
            "call_id": call_id,
            "error": call.error_message,
        }
    if status not in ("done", "preview"):
        return {"status": status, "call_id": call_id}

    row = db.query(models.AnalysisResult).filter(
        models.AnalysisResult.call_id == call_id
    ).first()
    payload = json.loads(row.stage3_json) if row else None
    return {
        "status": status,
        "call_id": call_id,
        "result": payload,
    }
# ─── Calls (history) ────────────────────────────────────────
@app.get("/api/calls")
def list_calls(db: Session = Depends(get_db)):
    """List the 50 most recent analysis results, newest first.

    Emotions and the recap title are only exposed once the call's status is
    'done'. For preview/in-progress calls those fields are None.
    """

    def title_of(raw_json: str | None) -> str | None:
        """Return recap_card.title from stored Stage 3 JSON if it is a non-blank string."""
        if not raw_json:
            return None
        try:
            payload = json.loads(raw_json)
        except (ValueError, TypeError):
            return None
        if not isinstance(payload, dict):
            return None
        recap = payload.get("recap_card")
        if not isinstance(recap, dict):
            return None
        title = recap.get("title")
        if isinstance(title, str) and title.strip():
            return title
        return None

    query = (
        db.query(models.AnalysisResult, models.Call.status)
        .join(models.Call, models.AnalysisResult.call_id == models.Call.id)
        .order_by(models.AnalysisResult.created_at.desc())
        .limit(50)
    )
    items = []
    for row in query.all():
        analysis = row.AnalysisResult
        finished = row.status == "done"
        items.append({
            "call_id": analysis.call_id,
            "blue_emotion": analysis.blue_emotion if finished else None,
            "pink_emotion": analysis.pink_emotion if finished else None,
            "garden_delta": analysis.garden_delta,
            "created_at": analysis.created_at.isoformat() if analysis.created_at else None,
            "status": row.status,
            "recap_title": title_of(analysis.stage3_json) if finished else None,
        })
    return items
@app.get("/api/calls/{call_id}")
def get_call(call_id: str, db: Session = Depends(get_db)):
    """Return the stored analysis result for one call, or 404 if there is none."""
    row = (
        db.query(models.AnalysisResult)
        .filter(models.AnalysisResult.call_id == call_id)
        .first()
    )
    if not row:
        raise HTTPException(status_code=404, detail="Call not found")
    created = row.created_at
    return {
        "call_id": row.call_id,
        "result": json.loads(row.stage3_json),
        "created_at": created.isoformat() if created else None,
    }
# ─── Check-ins ──────────────────────────────────────────────
# Accepted values for CheckInCreate.level.
VALID_LEVELS = {"deep", "warm", "growing", "different", "learning"}


class CheckInCreate(BaseModel):
    """Request body for POST /api/checkins."""

    iso_date: str
    score: int       # must be within 0-100 (see validate_fields)
    level: str       # must be one of VALID_LEVELS
    my_mood: str
    partner_guess: str

    def validate_fields(self):
        """Raise HTTPException(400) for the first failing rule: score range, then level."""
        rules = (
            (not 0 <= self.score <= 100, f"Score must be 0-100, got {self.score}"),
            (self.level not in VALID_LEVELS, f"Invalid level: {self.level}"),
        )
        for failed, message in rules:
            if failed:
                raise HTTPException(status_code=400, detail=message)
@app.post("/api/checkins")
def create_checkin(data: CheckInCreate, db: Session = Depends(get_db)):
    """Validate and store a check-in; every check-in also counts toward the garden."""
    data.validate_fields()
    record = models.CheckIn(
        iso_date=data.iso_date,
        score=data.score,
        level=data.level,
        my_mood=data.my_mood,
        partner_guess=data.partner_guess,
    )
    db.add(record)
    _update_garden_from_checkin(db, data.score)
    db.commit()
    db.refresh(record)  # load the database-assigned id
    return {"status": "success", "id": record.id}
@app.get("/api/checkins")
def list_checkins(db: Session = Depends(get_db)):
    """Return up to 100 check-ins, newest first."""
    plain_fields = ("id", "iso_date", "score", "level", "my_mood", "partner_guess")
    rows = (
        db.query(models.CheckIn)
        .order_by(models.CheckIn.created_at.desc())
        .limit(100)
        .all()
    )
    out = []
    for rec in rows:
        entry = {name: getattr(rec, name) for name in plain_fields}
        entry["created_at"] = rec.created_at.isoformat() if rec.created_at else None
        out.append(entry)
    return out
# ─── Garden ─────────────────────────────────────────────────
# Garden level thresholds per design-system.md §4.6
def _compute_level(count: int) -> int:
if count >= 25: return 5
if count >= 15: return 4
if count >= 8: return 3
if count >= 3: return 2
return 1
def _get_or_create_garden(db: Session) -> models.GardenState:
    """Return the single GardenState row (id=1), adding it to *db* if absent.

    A newly created row is flushed but not committed; committing is the caller's job.
    """
    existing = db.query(models.GardenState).filter(models.GardenState.id == 1).first()
    if existing:
        return existing
    fresh = models.GardenState(id=1)
    db.add(fresh)
    db.flush()
    return fresh
def _update_garden(db: Session, result_dict: dict):
    """Count an analysis as one garden interaction (uncommitted).

    The mood comes from ``result_dict["garden_update"]["mood"]``; if that key
    is missing it defaults to "happy".
    """
    garden = _get_or_create_garden(db)
    new_count = garden.interaction_count + 1
    garden.interaction_count = new_count
    garden.total_level = _compute_level(new_count)
    garden_update = result_dict.get("garden_update", {})
    garden.last_mood = garden_update.get("mood", "happy")
def _update_garden_from_checkin(db: Session, score: int):
    """Count a check-in as one garden interaction (uncommitted).

    A score of 60 or more sets the mood to "happy"; lower scores set "recovering".
    """
    garden = _get_or_create_garden(db)
    new_count = garden.interaction_count + 1
    garden.interaction_count = new_count
    garden.total_level = _compute_level(new_count)
    garden.last_mood = "recovering" if score < 60 else "happy"
@app.get("/api/garden")
def get_garden(db: Session = Depends(get_db)):
    """Return the garden's interaction count, level and last mood."""
    garden = _get_or_create_garden(db)
    exposed = ("interaction_count", "total_level", "last_mood")
    return {attr: getattr(garden, attr) for attr in exposed}
@app.put("/api/garden/interact")
def garden_interact(positive: bool = True, db: Session = Depends(get_db)):
    """Record one manual garden interaction, commit it, and return the new state.

    ``positive`` selects the mood: "happy" when true, "recovering" otherwise.
    """
    garden = _get_or_create_garden(db)
    new_count = garden.interaction_count + 1
    garden.interaction_count = new_count
    garden.total_level = _compute_level(new_count)
    if positive:
        garden.last_mood = "happy"
    else:
        garden.last_mood = "recovering"
    db.commit()
    exposed = ("interaction_count", "total_level", "last_mood")
    return {attr: getattr(garden, attr) for attr in exposed}