# Anora-api / main.py
# rairo's picture
# Update main.py
# 2efb4c9 verified
# ============================================================
# main.py — AI Partner Cultural Simulator Backend (FULL)
# REST + WebSocket | Gemini + Firebase + Azure Pronunciation
# HuggingFace-compatible (Port 7860)
# ============================================================
import os
import json
import uuid
import math
import base64
import tempfile
import subprocess
import logging
import traceback
from datetime import datetime
from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO, emit
import firebase_admin
from firebase_admin import credentials, db, auth
from google import genai
import azure.cognitiveservices.speech as speechsdk
# ------------------------------
# Logging
# ------------------------------
# basicConfig's default handler writes to stderr (not stdout); presumably the
# hosting container log captures both streams.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------
# App Init
# ------------------------------
app = Flask(__name__)
CORS(app)  # permissive CORS for the REST API
socketio = SocketIO(
    app,
    async_mode="eventlet",
    cors_allowed_origins="*",
)
# ============================================================
# 1) ENV & CLIENT INITIALIZATION (same naming as your template)
# ============================================================
# --- Firebase Initialization ---
# Fail fast at import time: the module re-raises so the process never starts
# without a working Realtime Database connection.
try:
    # FIREBASE holds the service-account JSON text itself (it is json.loads'd).
    credentials_json_string = os.environ.get("FIREBASE")
    if not credentials_json_string:
        raise ValueError("The 'FIREBASE' environment variable is not set.")
    credentials_json = json.loads(credentials_json_string)

    firebase_db_url = os.environ.get("Firebase_DB")
    if not firebase_db_url:
        raise ValueError("The 'Firebase_DB' environment variable must be set.")

    cred = credentials.Certificate(credentials_json)
    firebase_admin.initialize_app(cred, options={"databaseURL": firebase_db_url})
    # Root reference; every path used in this module is relative to it.
    db_ref = db.reference("/")
    logger.info("Firebase Admin SDK initialized successfully.")
except Exception as e:
    logger.critical(f"FATAL: Error initializing Firebase: {e}")
    logger.critical(traceback.format_exc())
    raise
# --- Gemini Initialization ---
# Fails fast at import time (re-raises) when the API key is missing or the
# client cannot be constructed.
try:
    gemini_api_key = os.environ.get("Gemini")
    if not gemini_api_key:
        raise ValueError("The 'Gemini' environment variable is not set.")
    client = genai.Client(api_key=gemini_api_key)
    # Optional override so the model can be switched without a code change;
    # the default is unchanged.
    MODEL_NAME = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
    logger.info(f"Gemini client initialized for model: {MODEL_NAME}")
except Exception as e:
    logger.critical(f"FATAL: Error initializing Gemini: {e}")
    logger.critical(traceback.format_exc())
    raise
# --- Azure Speech ---
# Unlike Firebase/Gemini these are not required at startup, but both are
# passed to speechsdk.SpeechConfig by _azure_pronunciation_assess.
AZURE_SPEECH_KEY = os.environ.get("AZURE_SPEECH_KEY")
AZURE_SPEECH_REGION = os.environ.get("AZURE_SPEECH_REGION")
if not (AZURE_SPEECH_KEY and AZURE_SPEECH_REGION):
    # Surface the misconfiguration at boot instead of only on the first
    # pronunciation request, which is expected to fail without them.
    logger.warning(
        "AZURE_SPEECH_KEY and/or AZURE_SPEECH_REGION not set; "
        "pronunciation scoring will not work."
    )
# ============================================================
# 2) LANGUAGE PACKS
# ============================================================
from korean import KOREAN_PACK
from english import ENGLISH_PACK
from japanese import JAPANESE_PACK
from german import GERMAN_PACK
# Locale code (the "language" value clients send) -> language pack module data.
# Each pack is read elsewhere via pack["language"] and pack["scenarios"].
LANGUAGE_PACKS = dict(
    [
        ("ko-KR", KOREAN_PACK),
        ("en-US", ENGLISH_PACK),
        ("ja-JP", JAPANESE_PACK),
        ("de-DE", GERMAN_PACK),
    ]
)
# ============================================================
# 3) CORE HELPERS (same style as your existing app)
# ============================================================
def now_iso():
    """Return the current UTC time as an ISO-8601 string ending in "Z".

    The format is identical to ``datetime.utcnow().isoformat() + "Z"`` (naive
    timestamp, no "+00:00" offset) but avoids ``utcnow()``, which is
    deprecated since Python 3.12.
    """
    from datetime import timezone

    # Drop tzinfo so isoformat() does not emit "+00:00" ahead of our "Z".
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
def verify_token(auth_header):
    """Return the Firebase uid carried by an ``Authorization: Bearer <token>`` header.

    Returns None when the header is missing, uses another scheme, carries an
    empty token, or the token fails ``auth.verify_id_token`` (that failure is
    logged as a warning).
    """
    prefix = "Bearer "
    if not auth_header or not auth_header.startswith(prefix):
        return None
    # Slice the scheme off rather than split("Bearer "), which would truncate
    # a token that happens to contain that substring.
    token = auth_header[len(prefix):].strip()
    if not token:
        return None
    try:
        return auth.verify_id_token(token)["uid"]
    except Exception as e:
        logger.warning(f"Token verification failed: {e}")
        return None
def verify_admin(auth_header):
    """Return the caller's uid if they are an admin.

    Raises:
        PermissionError: "Invalid or missing user token" when the header does
            not verify, or "Admin access required" when ``users/<uid>`` is
            missing or lacks a truthy ``is_admin``.
    """
    uid = verify_token(auth_header)
    if not uid:
        raise PermissionError("Invalid or missing user token")
    profile = db_ref.child(f"users/{uid}").get() or {}
    if profile.get("is_admin", False):
        return uid
    raise PermissionError("Admin access required")
def require_user():
    """Return the uid from the current request's Authorization header.

    Raises:
        PermissionError: "Unauthorized" when the token is missing or invalid.
    """
    header = request.headers.get("Authorization")
    uid = verify_token(header)
    if uid:
        return uid
    raise PermissionError("Unauthorized")
def get_user(uid):
    """Read the ``users/<uid>`` profile record (callers treat a falsy result as absent)."""
    user_path = f"users/{uid}"
    return db_ref.child(user_path).get()
def update_user(uid, payload: dict):
    """Merge ``payload`` into the ``users/<uid>`` record (a partial update, not a replace)."""
    user_node = db_ref.child(f"users/{uid}")
    user_node.update(payload)
# ============================================================
# 4) CREDITS (same vibe as Pitch Helper)
# ============================================================
START_SESSION_COST = 1
PRACTICE_ATTEMPT_COST = 1
PER_MINUTE_COST = 2


def charge(uid, amount, reason="charge"):
    """Deduct ``amount`` credits from ``uid`` and append a ledger entry.

    The balance at ``users/<uid>/credits`` is changed inside a Firebase
    transaction, so concurrent charges (REST calls and socket events can
    overlap) cannot both pass the balance check and spend the same credits.

    Args:
        uid: Firebase user id.
        amount: Credits to deduct; converted with ``int()``, must be >= 0.
        reason: Label stored with the ``credit_ledger/<uid>`` entry.

    Returns:
        ``{"deducted": <int>, "remaining": <int>}``.

    Raises:
        ValueError: If ``amount`` is negative (which would otherwise *add*
            credits) or the balance is lower than ``amount``
            ("Insufficient credits").
    """
    amount = int(amount)
    if amount < 0:
        raise ValueError("Charge amount must be non-negative")

    def _deduct(current):
        # Receives the stored balance (None when unset); Firebase may call it
        # again with fresh data if another writer changed the value meanwhile.
        balance = int(current or 0)
        if balance < amount:
            # Raising aborts the transaction and propagates to our caller.
            raise ValueError("Insufficient credits")
        return balance - amount

    new_total = db_ref.child(f"users/{uid}/credits").transaction(_deduct)

    # Audit trail of every credit change.
    db_ref.child(f"credit_ledger/{uid}").push().set({
        "ts": now_iso(),
        "delta": -amount,
        "reason": reason,
        "balance": new_total,
    })
    return {"deducted": amount, "remaining": new_total}
# ============================================================
# 5) SESSION HELPERS
# ============================================================
def create_session(uid, language, scenario_id, title):
    """Create, persist and return a new active session for ``uid``.

    Written to ``sessions/<uid>/<sessionId>`` with every meter at 50, no
    turns, no end time and an empty struggle-word map.
    """
    new_id = str(uuid.uuid4())
    record = dict(
        sessionId=new_id,
        userId=uid,
        language=language,
        scenarioId=scenario_id,
        title=title,
        meters=dict(respect=50, influence=50, trust=50),
        turns=[],
        createdAt=now_iso(),
        endedAt=None,
        status="active",
        struggleWords={},  # word -> {"avg": float, "count": int} rolling average
    )
    db_ref.child(f"sessions/{uid}/{new_id}").set(record)
    return record
def get_session(uid, session_id):
    """Read ``sessions/<uid>/<session_id>`` (callers treat a falsy result as not found)."""
    node = db_ref.child(f"sessions/{uid}/{session_id}")
    return node.get()
def update_session(uid, session_id, payload: dict):
    """Merge ``payload`` into ``sessions/<uid>/<session_id>`` (partial update)."""
    node = db_ref.child(f"sessions/{uid}/{session_id}")
    node.update(payload)
def list_sessions(uid):
    """Return all of ``uid``'s sessions as a list, newest ``createdAt`` first.

    Sessions without ``createdAt`` sort as the empty string (i.e. last).
    """
    by_id = db_ref.child(f"sessions/{uid}").get() or {}
    return sorted(
        by_id.values(),
        key=lambda sess: sess.get("createdAt", ""),
        reverse=True,
    )
# ============================================================
# 6) GEMINI — CULTURAL TURN EVALUATOR (JSON output)
# ============================================================
def _safe_json(text: str, fallback: dict):
try:
cleaned = (text or "").strip().lstrip("```json").rstrip("```").strip()
return json.loads(cleaned)
except Exception:
return fallback
def _normalize_evaluation(raw, fallback):
    """Coerce a parsed evaluator reply into the shape evaluate_turn documents.

    Non-dict replies yield ``fallback``. Each meter delta that is missing or
    not int-convertible becomes 0 and is clamped to the [-10, 10] range the
    prompt requests. ``feedback`` must be a string (else ""). 
    ``checkpoint_required`` is True only for JSON ``true`` or the string
    "true"; a plain ``bool("false")`` would be truthy.
    """
    if not isinstance(raw, dict):
        return fallback
    deltas = raw.get("meter_delta")
    if not isinstance(deltas, dict):
        deltas = {}
    clean_deltas = {}
    for meter in ("respect", "influence", "trust"):
        try:
            value = int(deltas.get(meter, 0))
        except (TypeError, ValueError, OverflowError):
            value = 0
        clean_deltas[meter] = max(-10, min(10, value))
    feedback = raw.get("feedback")
    checkpoint = raw.get("checkpoint_required")
    if isinstance(checkpoint, str):
        checkpoint = checkpoint.strip().lower() == "true"
    return {
        "meter_delta": clean_deltas,
        "feedback": feedback if isinstance(feedback, str) else "",
        "checkpoint_required": checkpoint is True,
    }


def evaluate_turn(language_pack, scenario, transcript_turn, user_title):
    """Ask Gemini to grade one user turn against the scenario's cultural rules.

    Args:
        language_pack: Pack dict; its "language" value is shown to the model.
        scenario: Scenario dict; "name" and "rules" are embedded in the prompt.
        transcript_turn: What the user said.
        user_title: Immersion title shown to the model.

    Returns:
        ``{"meter_delta": {"respect", "influence", "trust": int in [-10, 10]},
        "feedback": str, "checkpoint_required": bool}``. On any Gemini or
        parsing failure, a neutral fallback (zero deltas, "Evaluation
        unavailable.") is returned instead of raising.
    """
    # The transcript is JSON-encoded so quotes or newlines in user speech
    # cannot break out of the "USER SAID" section of the prompt.
    prompt = f"""
You are a cultural authority evaluator and business communication coach.
IMMERSION TITLE: {user_title}
LANGUAGE: {language_pack["language"]}
SCENARIO: {scenario["name"]}
EXPECTATIONS (rules):
{json.dumps(scenario["rules"], ensure_ascii=False, indent=2)}
USER SAID:
{json.dumps(transcript_turn, ensure_ascii=False)}
Return STRICT JSON ONLY:
{{
"meter_delta": {{
"respect": <int between -10 and 10>,
"influence": <int between -10 and 10>,
"trust": <int between -10 and 10>
}},
"feedback": "<one or two sentences, culturally grounded>",
"checkpoint_required": <true|false>
}}
"""
    fallback = {
        "meter_delta": {"respect": 0, "influence": 0, "trust": 0},
        "feedback": "Evaluation unavailable.",
        "checkpoint_required": False,
    }
    try:
        resp = client.models.generate_content(
            model=MODEL_NAME,
            contents=prompt,
            # JSON mode: ask for a bare JSON body instead of fenced/prose output.
            config={"response_mime_type": "application/json"},
        )
        parsed = _safe_json(resp.text, fallback)
    except Exception as e:
        logger.error(f"Gemini evaluate_turn failed: {e}")
        return fallback
    return _normalize_evaluation(parsed, fallback)
# ============================================================
# 7) AUDIO SANITIZER (Azure requirement)
# ============================================================
def sanitize_audio(raw_path, timeout=120):
    """Convert ``raw_path`` to 16 kHz mono 16-bit PCM WAV with ffmpeg.

    Args:
        raw_path: Input audio file (any format ffmpeg can read).
        timeout: Seconds to allow ffmpeg before giving up, so a malformed
            upload cannot hang the worker indefinitely.

    Returns:
        Path of the converted file (``raw_path + "_clean.wav"``); the caller
        owns and must delete it.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero.
        subprocess.TimeoutExpired: ffmpeg exceeded ``timeout``.
        OSError: ffmpeg could not be started (e.g. not installed).
        On any of these, a partially written output file is removed first,
        since the caller never receives its path to clean it up.
    """
    clean_path = raw_path + "_clean.wav"
    cmd = [
        "ffmpeg", "-y", "-v", "error",
        "-i", raw_path,
        "-ac", "1",
        "-ar", "16000",
        "-acodec", "pcm_s16le",
        clean_path,
    ]
    try:
        subprocess.run(cmd, check=True, timeout=timeout)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        try:
            if os.path.exists(clean_path):
                os.remove(clean_path)
        except OSError:
            pass  # best-effort cleanup; the original error is what matters
        raise
    return clean_path
def _azure_pronunciation_assess(reference_text, lang, wav_path):
    """Run Azure pronunciation assessment of ``wav_path`` against ``reference_text``.

    Args:
        reference_text: What the speaker was expected to say.
        lang: Recognition locale handed to Azure (e.g. "en-US").
        wav_path: Audio file to assess (as produced by sanitize_audio).

    Returns:
        ``{"success", "score", "fluency", "completeness", "recognized_text",
        "word_details"}`` where ``word_details`` is a list of
        ``{"word", "score", "error"}`` (word-level, 100-point scale). When
        Azure does not report ``RecognizedSpeech``, all scores are 0, the
        word list is empty and ``recognized_text`` is the result text or
        "No match".
    """
    speech_cfg = speechsdk.SpeechConfig(subscription=AZURE_SPEECH_KEY, region=AZURE_SPEECH_REGION)
    speech_cfg.speech_recognition_language = lang
    assessment_cfg = speechsdk.PronunciationAssessmentConfig(
        reference_text=reference_text,
        grading_system=speechsdk.PronunciationAssessmentGradingSystem.HundredMark,
        granularity=speechsdk.PronunciationAssessmentGranularity.Word,
        enable_miscue=True,
    )
    recognizer = speechsdk.SpeechRecognizer(
        speech_config=speech_cfg,
        audio_config=speechsdk.audio.AudioConfig(filename=wav_path),
    )
    assessment_cfg.apply_to(recognizer)
    outcome = recognizer.recognize_once_async().get()

    if outcome.reason == speechsdk.ResultReason.RecognizedSpeech:
        assessment = speechsdk.PronunciationAssessmentResult(outcome)
        words = [
            {"word": w.word, "score": w.accuracy_score, "error": w.error_type}
            for w in assessment.words
        ]
        return {
            "success": True,
            "score": assessment.accuracy_score,
            "fluency": assessment.fluency_score,
            "completeness": assessment.completeness_score,
            "recognized_text": outcome.text,
            "word_details": words,
        }

    # Anything other than recognized speech: report a zeroed result.
    return {
        "success": False,
        "score": 0,
        "fluency": 0,
        "completeness": 0,
        "recognized_text": getattr(outcome, "text", "") or "No match",
        "word_details": [],
    }
# ============================================================
# 8) AUTH + PROFILE ENDPOINTS (ported from your template)
# ============================================================
@app.route("/api/auth/signup", methods=["POST"])
def signup():
    """Create a Firebase Auth user and its ``users/<uid>`` profile.

    Body: ``{"email", "password", "displayName"?}``. New profiles start with
    30 credits and ``is_admin`` False.

    Returns 201 with the profile, 400 when email/password are missing or
    creation fails, 409 when the error text contains "EMAIL_EXISTS".
    """
    try:
        payload = request.get_json() or {}
        email, password = payload.get("email"), payload.get("password")
        display_name = payload.get("displayName")
        if not (email and password):
            return jsonify({"error": "Email and password are required"}), 400

        created = auth.create_user(email=email, password=password, display_name=display_name)
        profile = {
            "email": email,
            "displayName": display_name,
            "credits": 30,
            "is_admin": False,
            "createdAt": now_iso(),
        }
        db_ref.child(f"users/{created.uid}").set(profile)
        return jsonify({"success": True, "uid": created.uid, **profile}), 201
    except Exception as exc:
        logger.error(f"Signup failed: {exc}")
        message = str(exc)
        if "EMAIL_EXISTS" in message:
            return jsonify({"error": "An account with this email already exists."}), 409
        return jsonify({"error": message}), 400
@app.route("/api/auth/social-signin", methods=["POST"])
def social_signin():
    """Return the caller's profile, creating it on first social sign-in.

    Returns 401 for an invalid token, 200 with the existing profile, 201
    with a freshly created one (30 credits, non-admin, email and display
    name copied from Firebase Auth), or 500 if creation fails.
    """
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Invalid or expired token"}), 401

    profile_ref = db_ref.child(f"users/{uid}")
    existing = profile_ref.get()
    if existing:
        return jsonify({"uid": uid, **existing}), 200

    try:
        account = auth.get_user(uid)
        profile = {
            "email": account.email,
            "displayName": account.display_name,
            "credits": 30,
            "is_admin": False,
            "createdAt": now_iso(),
        }
        profile_ref.set(profile)
        return jsonify({"success": True, "uid": uid, **profile}), 201
    except Exception as exc:
        logger.error(f"Failed to create profile for social user {uid}: {exc}")
        return jsonify({"error": f"Failed to create user profile: {str(exc)}"}), 500
@app.route("/api/user/profile", methods=["GET"])
def get_user_profile():
    """Return the caller's profile (401 bad token, 404 no profile record)."""
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Invalid or expired token"}), 401
    profile = db_ref.child(f"users/{uid}").get()
    if profile:
        return jsonify({"uid": uid, **profile}), 200
    return jsonify({"error": "User not found"}), 404
@app.route("/api/user/profile", methods=["PATCH"])
def update_user_profile():
    """Update the caller's editable profile fields.

    Only string ``displayName`` / ``preferredLanguage`` values are accepted
    (whitespace-stripped); other keys are ignored. Returns 401 for a bad
    token, 400 when no accepted field is present, otherwise 200 with the
    refreshed profile.
    """
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Invalid or expired token"}), 401

    body = request.get_json() or {}
    changes = {}
    for field in ("displayName", "preferredLanguage"):
        if field in body and isinstance(body[field], str):
            changes[field] = body[field].strip()
    if not changes:
        return jsonify({"error": "No valid fields to update"}), 400

    update_user(uid, changes)
    refreshed = get_user(uid) or {}
    return jsonify({"success": True, "uid": uid, **refreshed}), 200
# ============================================================
# 9) CREDITS + ADMIN ENDPOINTS (same as template)
# ============================================================
@app.route("/api/user/request-credits", methods=["POST"])
def request_credits():
    """Record a pending credit top-up request for the caller.

    Body: ``{"requested_credits": <positive integer>}``.

    Returns 200 with the new ``requestId``, 400 when the amount is missing
    or not a positive integer, 401 without a valid token, 500 if the write
    fails.
    """
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Unauthorized"}), 401
    try:
        data = request.get_json() or {}
        if "requested_credits" not in data:
            return jsonify({"error": "requested_credits is required"}), 400

        raw_amount = data["requested_credits"]
        try:
            # bool is an int subclass; reject it so `true` is not read as 1.
            if isinstance(raw_amount, bool):
                raise ValueError(raw_amount)
            requested = int(raw_amount)
        except (TypeError, ValueError, OverflowError):
            requested = 0
        if requested <= 0:
            # A zero/negative request, once approved, would remove credits.
            return jsonify({"error": "requested_credits must be a positive integer"}), 400

        req_ref = db_ref.child("credit_requests").push()
        req_ref.set({
            "requestId": req_ref.key,
            "userId": uid,
            "requested_credits": requested,
            "status": "pending",
            "requestedAt": now_iso(),
        })
        return jsonify({"success": True, "requestId": req_ref.key}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/admin/credit_requests", methods=["GET"])
def list_credit_requests():
    """Return every credit request as a JSON list (admin only; 403 otherwise)."""
    try:
        verify_admin(request.headers.get("Authorization"))
        all_requests = db_ref.child("credit_requests").get() or {}
        return jsonify([*all_requests.values()]), 200
    except PermissionError as err:
        return jsonify({"error": str(err)}), 403
    except Exception as err:
        return jsonify({"error": str(err)}), 500
class _CreditRequestAlreadyProcessed(Exception):
    """Raised inside the status transaction when a request is no longer pending."""


@app.route("/api/admin/credit_requests/<string:request_id>", methods=["PUT"])
def process_credit_request(request_id):
    """Approve or decline a pending credit request (admin only).

    Body: ``{"decision": "approved" | "declined"}``.

    The request is first claimed by atomically moving its ``status`` from
    "pending" to the decision, so a repeated or concurrent call cannot
    approve (and credit) the same request twice. On approval the credits
    are added via a transaction on ``users/<userId>/credits`` and a ledger
    entry is written. If adding credits fails, the claim is released back to
    "pending" so the request can be retried.

    Returns 200 on success, 400 for an invalid decision, 403 for
    non-admins, 404 for an unknown request, 409 if it was already
    processed, 500 otherwise.
    """
    try:
        admin_uid = verify_admin(request.headers.get("Authorization"))
        req_ref = db_ref.child(f"credit_requests/{request_id}")
        req_data = req_ref.get()
        if not req_data:
            return jsonify({"error": "Credit request not found"}), 404

        decision = (request.get_json() or {}).get("decision")
        if decision not in ["approved", "declined"]:
            return jsonify({"error": 'Decision must be "approved" or "declined"'}), 400

        if decision == "approved":
            # Read/convert before claiming so bad request data cannot leave a stale claim.
            user_id = req_data["userId"]
            amount = int(req_data.get("requested_credits", 0))

        def _claim(current_status):
            # request_credits writes "pending"; a missing status is treated the same.
            if current_status not in (None, "pending"):
                raise _CreditRequestAlreadyProcessed(current_status)
            return decision

        status_ref = req_ref.child("status")
        status_ref.transaction(_claim)

        if decision == "approved":
            try:
                new_total = db_ref.child(f"users/{user_id}/credits").transaction(
                    lambda current: int(current or 0) + amount
                )
            except Exception:
                status_ref.set("pending")
                raise
            db_ref.child(f"credit_ledger/{user_id}").push().set({
                "ts": now_iso(),
                "delta": amount,
                "reason": "admin_credit_approval",
                "balance": new_total,
                "processedBy": admin_uid,
            })

        req_ref.update({
            "status": decision,
            "processedBy": admin_uid,
            "processedAt": now_iso(),
        })
        return jsonify({"success": True, "message": f"Request {decision}."}), 200
    except _CreditRequestAlreadyProcessed as e:
        return jsonify({"error": f"Credit request already {e.args[0]}"}), 409
    except PermissionError as e:
        return jsonify({"error": str(e)}), 403
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/admin/users/<string:uid>/credits", methods=["PUT"])
def admin_update_credits(uid):
    """Adjust ``uid``'s credits by ``add_credits`` (admin only; may be negative).

    Body: ``{"add_credits": <integer>}``. The balance changes through a
    transaction on ``users/<uid>/credits`` so it cannot overwrite a
    concurrent charge. The ledger entry records the delta, the resulting
    balance and the acting admin.

    Returns 200 with ``new_total_credits``, 400 for a missing or non-integer
    amount, 403 for non-admins, 404 for an unknown user, 500 otherwise.
    """
    try:
        admin_uid = verify_admin(request.headers.get("Authorization"))
        add_credits = (request.get_json() or {}).get("add_credits")
        if add_credits is None:
            return jsonify({"error": "add_credits is required"}), 400
        try:
            # bool is an int subclass; reject it so `true` is not read as 1.
            if isinstance(add_credits, bool):
                raise ValueError(add_credits)
            delta = int(add_credits)
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "add_credits must be an integer"}), 400

        if not db_ref.child(f"users/{uid}").get():
            return jsonify({"error": "User not found"}), 404

        new_total = db_ref.child(f"users/{uid}/credits").transaction(
            lambda current: int(current or 0) + delta
        )
        db_ref.child(f"credit_ledger/{uid}").push().set({
            "ts": now_iso(),
            "delta": delta,
            "reason": "admin_manual_adjust",
            "balance": new_total,
            "processedBy": admin_uid,
        })
        return jsonify({"success": True, "new_total_credits": new_total}), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 403
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# ============================================================
# 10) SESSION ENDPOINTS (your new app core)
# ============================================================
@app.route("/api/session/start", methods=["POST"])
def start_session():
    """Start a scenario session for the caller.

    Body: ``{"language": <LANGUAGE_PACKS key>, "scenarioId": str}``.
    Charges START_SESSION_COST credits, creates the session record, and
    returns it with the dynamic variables the client forwards to its
    ElevenLabs agent.

    Returns 200 on success, 400 for bad input or any failure (including
    insufficient credits), 401 without a valid token.
    """
    try:
        uid = require_user()
        body = request.get_json() or {}
        language, scenario_id = body.get("language"), body.get("scenarioId")
        if not (language and scenario_id):
            return jsonify({"error": "language and scenarioId are required"}), 400
        if language not in LANGUAGE_PACKS:
            return jsonify({"error": f"Unsupported language: {language}"}), 400

        pack = LANGUAGE_PACKS[language]
        scenario = pack["scenarios"].get(scenario_id)
        if not scenario:
            return jsonify({"error": "Invalid scenarioId"}), 400

        title = scenario["title"]
        credit_info = charge(uid, START_SESSION_COST, reason="start_session")
        session = create_session(uid, language, scenario_id, title)
        return jsonify({
            "session": session,
            "dynamicVariables": {
                "title": title,
                "language": pack["language"],
                "scenarioName": scenario["name"],
                "scenarioId": scenario_id,
            },
            "credits": credit_info,
        }), 200
    except PermissionError as err:
        return jsonify({"error": str(err)}), 401
    except Exception as err:
        logger.error(f"start_session failed: {err}")
        logger.error(traceback.format_exc())
        return jsonify({"error": str(err)}), 400
def _apply_meter_deltas(meters, deltas):
    """Return a copy of ``meters`` with ``deltas`` applied, each meter clamped to [0, 100].

    A delta that is missing or not int-convertible counts as 0 (as does a
    non-dict ``deltas``). Each delta is clamped to the [-10, 10] range the
    evaluator is asked to produce.
    """
    if not isinstance(deltas, dict):
        deltas = {}
    updated = {}
    for name, value in meters.items():
        try:
            step = int(deltas.get(name, 0))
        except (TypeError, ValueError, OverflowError):
            step = 0
        step = max(-10, min(10, step))
        updated[name] = max(0, min(100, int(value) + step))
    return updated


@app.route("/api/session/turn", methods=["POST"])
def submit_turn():
    """Evaluate one user turn in an active session and update its meters.

    Body: ``{"sessionId": str, "transcript": non-empty str}``.

    Returns 200 with ``meters``, ``feedback`` and ``checkpointRequired``;
    400 for bad input, an inactive session, or an unknown
    language/scenario; 401 without a valid token; 404 for an unknown
    session.
    """
    try:
        uid = require_user()
        data = request.get_json() or {}
        session_id = data.get("sessionId")
        transcript = data.get("transcript")
        if not session_id or not isinstance(transcript, str) or not transcript.strip():
            return jsonify({"error": "sessionId and transcript are required"}), 400

        session = get_session(uid, session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        if session.get("status") != "active":
            return jsonify({"error": "Session is not active"}), 400

        pack = LANGUAGE_PACKS.get(session.get("language"))
        scenario = (pack or {}).get("scenarios", {}).get(session.get("scenarioId"))
        if not scenario:
            return jsonify({"error": "Session references an unknown language or scenario"}), 400

        title = session.get("title", scenario.get("title"))
        result = evaluate_turn(pack, scenario, transcript, title)
        if not isinstance(result, dict):
            result = {}

        meters = _apply_meter_deltas(
            session.get("meters") or {"respect": 50, "influence": 50, "trust": 50},
            result.get("meter_delta"),
        )
        feedback = result.get("feedback", "")
        checkpoint = result.get("checkpoint_required", False)
        if isinstance(checkpoint, str):
            # bool("false") is True; interpret string booleans explicitly.
            checkpoint = checkpoint.strip().lower() == "true"

        turns = list(session.get("turns") or [])
        turns.append({
            "id": str(uuid.uuid4()),
            "at": now_iso(),
            "text": transcript,
            "feedback": feedback,
        })
        update_session(uid, session_id, {
            "meters": meters,
            "turns": turns,
        })
        return jsonify({
            "meters": meters,
            "feedback": feedback,
            "checkpointRequired": bool(checkpoint),
        }), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 401
    except Exception as e:
        logger.error(f"submit_turn failed: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"error": str(e)}), 400
@app.route("/api/session/end", methods=["POST"])
def end_session():
    """Complete an active session and bill it per started minute.

    Body: ``{"sessionId": str, "durationSeconds": non-negative number}``.
    Cost is ``ceil(durationSeconds / 60) * PER_MINUTE_COST`` credits.

    Returns 200 with ``cost`` and the remaining credits; 400 for invalid
    input, an already-ended session, or any failure (including insufficient
    credits); 401 without a valid token; 404 for an unknown session.
    """
    try:
        uid = require_user()
        data = request.get_json() or {}
        session_id = data.get("sessionId")
        duration = data.get("durationSeconds")
        # bool is an int subclass, so exclude it explicitly.
        if not session_id or isinstance(duration, bool) or not isinstance(duration, (int, float)):
            return jsonify({"error": "sessionId and durationSeconds are required"}), 400
        if not math.isfinite(duration) or duration < 0:
            # A negative duration would produce a negative cost, i.e. a credit
            # grant instead of a charge.
            return jsonify({"error": "durationSeconds must be a non-negative finite number"}), 400

        session = get_session(uid, session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        if session.get("status") != "active":
            # Ending a session twice would bill the same minutes twice.
            return jsonify({"error": "Session is not active"}), 400

        cost = math.ceil(float(duration) / 60.0) * PER_MINUTE_COST
        credit_info = charge(uid, cost, reason="session_minutes")
        update_session(uid, session_id, {
            "status": "completed",
            "endedAt": now_iso(),
            "durationSeconds": duration,
            "minuteCost": cost,
        })
        return jsonify({
            "status": "completed",
            "cost": cost,
            "credits": credit_info,
        }), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 401
    except Exception as e:
        logger.error(f"end_session failed: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"error": str(e)}), 400
@app.route("/api/sessions", methods=["GET"])
def api_list_sessions():
    """List the caller's sessions, newest first (401 without a valid token)."""
    try:
        return jsonify(list_sessions(require_user())), 200
    except PermissionError as err:
        return jsonify({"error": str(err)}), 401
    except Exception as err:
        return jsonify({"error": str(err)}), 500
@app.route("/api/sessions/<string:session_id>", methods=["GET"])
def api_get_session(session_id):
    """Return one of the caller's sessions (401 bad token, 404 unknown id)."""
    try:
        uid = require_user()
        found = get_session(uid, session_id)
        if found:
            return jsonify(found), 200
        return jsonify({"error": "Session not found"}), 404
    except PermissionError as err:
        return jsonify({"error": str(err)}), 401
    except Exception as err:
        return jsonify({"error": str(err)}), 500
# ============================================================
# 11) WEBSOCKET — PRONUNCIATION (practice + live turn)
# ============================================================
def _update_struggle_words(session, word_details):
"""
Rolling average per word.
session["struggleWords"] shape:
{ "word": {"avg": float, "count": int} }
"""
struggle = session.get("struggleWords", {}) or {}
for wd in word_details:
w = (wd.get("word") or "").strip()
if not w:
continue
score = float(wd.get("score") or 0)
entry = struggle.get(w, {"avg": 0.0, "count": 0})
n = int(entry.get("count", 0))
avg = float(entry.get("avg", 0.0))
new_avg = (avg * n + score) / (n + 1)
struggle[w] = {"avg": new_avg, "count": n + 1}
return struggle
@socketio.on("practice_pronunciation")
def practice_pronunciation(data):
    """Score a free practice attempt against a supplied reference text.

    Expects a dict with ``text`` (reference), ``audio`` (base64, optionally a
    data URL) and optional ``lang`` (default "en-US"). When ``chargeCredits``
    is truthy, ``authToken`` (Firebase ID token) is required, and
    PRACTICE_ATTEMPT_COST credits are charged only after the audio has been
    decoded and converted, so malformed uploads are not billed.

    Emits ``pronunciation_result`` with the Azure assessment (plus
    ``credits`` when charged), or ``{"success": False, "score": 0,
    "recognized_text": <reason>}`` on failure.
    """
    raw_path = None
    clean_path = None
    try:
        if not isinstance(data, dict):
            data = {}
        ref_text = data.get("text")
        lang = data.get("lang", "en-US")
        audio = data.get("audio")
        if not ref_text or not audio:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing text/audio"})
            return

        uid = None
        if data.get("chargeCredits", False):
            auth_token = data.get("authToken")
            if not auth_token:
                emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing authToken"})
                return
            uid = verify_token(f"Bearer {auth_token}")
            if not uid:
                emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Invalid authToken"})
                return

        # Accept both bare base64 and data URLs ("data:audio/...;base64,<payload>").
        audio_b64 = audio.split(",")[1] if "," in audio else audio
        audio_bytes = base64.b64decode(audio_b64)
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
            f.write(audio_bytes)
            raw_path = f.name
        clean_path = sanitize_audio(raw_path)

        credit_info = None
        if uid:
            try:
                credit_info = charge(uid, PRACTICE_ATTEMPT_COST, reason="practice_attempt")
            except ValueError:
                # charge() reports a low balance with ValueError; other failures
                # fall through to the generic server-error path below.
                emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Insufficient credits"})
                return

        result = _azure_pronunciation_assess(ref_text, lang, clean_path)
        if credit_info:
            result["credits"] = credit_info
        emit("pronunciation_result", result)
    except Exception as e:
        logger.error(f"practice_pronunciation failed: {e}")
        logger.error(traceback.format_exc())
        emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Server Error"})
    finally:
        # Remove each temp file independently so one failure does not leak the other.
        for path in (raw_path, clean_path):
            try:
                if path and os.path.exists(path):
                    os.remove(path)
            except OSError:
                pass  # best-effort cleanup
@socketio.on("live_pronunciation_turn")
def live_pronunciation_turn(data):
    """Score one spoken line inside an active session and track struggle words.

    Expects a dict with ``authToken`` (Firebase ID token), ``sessionId``,
    ``text`` (reference phrase or checkpoint line), ``audio`` (base64,
    optionally a data URL) and optional ``lang``. ``lang`` defaults to the
    session's language, then "en-US".

    Emits ``pronunciation_result``: the Azure assessment plus
    ``struggle_top`` (up to 8 words with the lowest rolling average), or
    ``{"success": False, "score": 0, "recognized_text": <reason>}``.
    Persisting the updated struggle words is best-effort. If that write
    fails, the score is still returned.
    """
    raw_path = None
    clean_path = None
    try:
        if not isinstance(data, dict):
            data = {}
        auth_token = data.get("authToken")
        session_id = data.get("sessionId")
        ref_text = data.get("text")
        audio = data.get("audio")
        if not auth_token or not session_id or not ref_text or not audio:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing fields"})
            return

        uid = verify_token(f"Bearer {auth_token}")
        if not uid:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Invalid authToken"})
            return
        session = get_session(uid, session_id)
        if not session:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Session not found"})
            return
        if session.get("status") != "active":
            # Same rule as the REST turn endpoint: ended sessions take no new turns.
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Session is not active"})
            return

        # Score in the session's own language unless the client overrides it.
        lang = data.get("lang") or session.get("language") or "en-US"

        audio_b64 = audio.split(",")[1] if "," in audio else audio
        audio_bytes = base64.b64decode(audio_b64)
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
            f.write(audio_bytes)
            raw_path = f.name
        clean_path = sanitize_audio(raw_path)
        result = _azure_pronunciation_assess(ref_text, lang, clean_path)

        struggle = session.get("struggleWords") or {}
        word_details = result.get("word_details") or []
        if word_details:
            struggle = _update_struggle_words(session, word_details)
            try:
                update_session(uid, session_id, {"struggleWords": struggle})
            except Exception as e:
                logger.error(f"Failed to save struggle words for session {session_id}: {e}")

        # Lowest 8 rolling averages = the words the user struggles with most.
        result["struggle_top"] = sorted(
            (
                {"word": w, "avg": v.get("avg", 0.0), "count": v.get("count", 0)}
                for w, v in struggle.items()
            ),
            key=lambda item: item["avg"],
        )[:8]
        emit("pronunciation_result", result)
    except Exception as e:
        logger.error(f"live_pronunciation_turn failed: {e}")
        logger.error(traceback.format_exc())
        emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Server Error"})
    finally:
        # Remove each temp file independently so one failure does not leak the other.
        for path in (raw_path, clean_path):
            try:
                if path and os.path.exists(path):
                    os.remove(path)
            except OSError:
                pass  # best-effort cleanup
# ============================================================
# 12) MAIN
# ============================================================
if __name__ == "__main__":
    # 7860 is the Hugging Face Spaces default; PORT lets other hosts override it.
    socketio.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))