# ============================================================
# main.py — AI Partner Cultural Simulator Backend (FULL)
# REST + WebSocket | Gemini + Firebase + Azure Pronunciation
# HuggingFace-compatible (Port 7860)
# ============================================================

import os
import json
import uuid
import math
import base64
import tempfile
import subprocess
import logging
import traceback
from datetime import datetime

from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_socketio import SocketIO, emit

import firebase_admin
from firebase_admin import credentials, db, auth

from google import genai

import azure.cognitiveservices.speech as speechsdk

# ------------------------------
# Logging (basicConfig logs to stderr; the host captures console output)
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------
# App Init
# ------------------------------
app = Flask(__name__)
CORS(app)
# NOTE(review): cors_allowed_origins="*" accepts WebSocket connections from any
# origin — confirm this is intended for the production deployment.
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="eventlet")

# ============================================================
# 1) ENV & CLIENT INITIALIZATION (same naming as your template)
# ============================================================

# --- Firebase Initialization ---
# Reads the service-account JSON from the FIREBASE env var and the Realtime
# Database URL from Firebase_DB. Any failure is logged and re-raised so the
# process does not start without a database handle.
try:
    credentials_json_string = os.environ.get("FIREBASE")
    if not credentials_json_string:
        raise ValueError("The 'FIREBASE' environment variable is not set.")
    credentials_json = json.loads(credentials_json_string)

    firebase_db_url = os.environ.get("Firebase_DB")
    if not firebase_db_url:
        raise ValueError("The 'Firebase_DB' environment variable must be set.")

    cred = credentials.Certificate(credentials_json)
    firebase_admin.initialize_app(cred, {"databaseURL": firebase_db_url})
    db_ref = db.reference()  # root reference used by every helper below
    logger.info("Firebase Admin SDK initialized successfully.")
except Exception as e:
    logger.critical(f"FATAL: Error initializing Firebase: {e}")
    logger.critical(traceback.format_exc())
    raise

# --- Gemini Initialization ---
# Reads the API key from the "Gemini" env var; failure is fatal (re-raised).
try:
    gemini_api_key = os.environ.get("Gemini")
    if not gemini_api_key:
        raise ValueError("The 'Gemini' environment variable is not set.")
    client = genai.Client(api_key=gemini_api_key)
    MODEL_NAME = "gemini-2.0-flash"
    logger.info(f"Gemini client initialized for model: {MODEL_NAME}")
except Exception as e:
    logger.critical(f"FATAL: Error initializing Gemini: {e}")
    logger.critical(traceback.format_exc())
    raise

# --- Azure Speech ---
AZURE_SPEECH_KEY = os.environ.get("AZURE_SPEECH_KEY")
AZURE_SPEECH_REGION = os.environ.get("AZURE_SPEECH_REGION")
if not AZURE_SPEECH_KEY or not AZURE_SPEECH_REGION:
    # Not fatal: only the pronunciation features need Azure.
    logger.warning(
        "AZURE_SPEECH_KEY / AZURE_SPEECH_REGION not set; "
        "pronunciation assessment will fail until they are configured."
    )

# ============================================================
# 2) LANGUAGE PACKS
# ============================================================
from korean import KOREAN_PACK
from english import ENGLISH_PACK
from japanese import JAPANESE_PACK
from german import GERMAN_PACK

# Locale code -> language pack. Handlers below read pack["language"] and
# pack["scenarios"][scenario_id] ("name", "title", "rules").
LANGUAGE_PACKS = {
    "ko-KR": KOREAN_PACK,
    "en-US": ENGLISH_PACK,
    "ja-JP": JAPANESE_PACK,
    "de-DE": GERMAN_PACK,
}

# ============================================================
# 3) CORE HELPERS (same style as your existing app)
# ============================================================


def now_iso():
    """Return the current UTC time as an ISO-8601 string with a 'Z' suffix."""
    return datetime.utcnow().isoformat() + "Z"


def verify_token(auth_header):
    """Return the Firebase uid for an 'Authorization: Bearer <token>' header.

    Returns None when the header is missing, is not a string, lacks the
    'Bearer ' prefix, or the token fails verification (the failure is logged).
    """
    prefix = "Bearer "
    if not isinstance(auth_header, str) or not auth_header.startswith(prefix):
        return None
    # Slice off the prefix instead of split("Bearer ")[1], which would
    # truncate the token at any later occurrence of the separator.
    token = auth_header[len(prefix):]
    try:
        return auth.verify_id_token(token)["uid"]
    except Exception as e:
        logger.warning(f"Token verification failed: {e}")
        return None


def verify_admin(auth_header):
    """Return the caller's uid if the token is valid and the user is an admin.

    Raises:
        PermissionError: token invalid/missing, or users/<uid>.is_admin is
            not truthy.
    """
    uid = verify_token(auth_header)
    if not uid:
        raise PermissionError("Invalid or missing user token")
    user_data = db_ref.child(f"users/{uid}").get()
    if not user_data or not user_data.get("is_admin", False):
        raise PermissionError("Admin access required")
    return uid


def require_user():
    """Return the uid for the current Flask request or raise PermissionError."""
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        raise PermissionError("Unauthorized")
    return uid


def get_user(uid):
    """Fetch the stored profile at users/<uid> (None if absent)."""
    return db_ref.child(f"users/{uid}").get()


def update_user(uid, payload: dict):
    """Merge *payload* into the profile stored at users/<uid>."""
    db_ref.child(f"users/{uid}").update(payload)


# ============================================================
# 4) CREDITS (same vibe as Pitch Helper)
# ============================================================
START_SESSION_COST = 1
PRACTICE_ATTEMPT_COST = 1
PER_MINUTE_COST = 2


def charge(uid, amount, reason="charge"):
    """Atomically deduct *amount* credits from users/<uid>/credits.

    The deduction runs inside a Firebase transaction so two concurrent
    requests cannot both spend the same balance. A ledger entry is appended
    under credit_ledger/<uid> after the deduction commits.

    Args:
        uid: Firebase user id.
        amount: Number of credits to deduct (coerced with int()).
        reason: Free-form label stored in the ledger entry.

    Returns:
        dict: {"deducted": <int>, "remaining": <int>}.

    Raises:
        ValueError: *amount* is negative, or the balance is insufficient.
        Exception: errors from the database transaction propagate unchanged.
    """
    amount = int(amount)
    if amount < 0:
        # A negative charge would otherwise *add* credits to the account.
        raise ValueError("Charge amount must not be negative")

    def _deduct(current):
        # May be invoked more than once if the transaction is retried.
        balance = int(current or 0)
        if balance < amount:
            raise ValueError("Insufficient credits")
        return balance - amount

    new_total = db_ref.child(f"users/{uid}/credits").transaction(_deduct)

    # Optional: log credit usage (useful later)
    db_ref.child(f"credit_ledger/{uid}").push().set({
        "ts": now_iso(),
        "delta": -amount,
        "reason": reason,
        "balance": new_total,
    })
    return {"deducted": amount, "remaining": new_total}


# ============================================================
# 5) SESSION HELPERS
# ============================================================


def create_session(uid, language, scenario_id, title):
    """Create and persist a new active session; return the session dict.

    The session starts with all three meters at 50 and is written to
    sessions/<uid>/<sessionId>.
    """
    session_id = str(uuid.uuid4())
    session = {
        "sessionId": session_id,
        "userId": uid,
        "language": language,
        "scenarioId": scenario_id,
        "title": title,
        "meters": {"respect": 50, "influence": 50, "trust": 50},
        "turns": [],
        "createdAt": now_iso(),
        "endedAt": None,
        "status": "active",
        "struggleWords": {},  # rolling avg per word
    }
    db_ref.child(f"sessions/{uid}/{session_id}").set(session)
    return session


def get_session(uid, session_id):
    """Fetch sessions/<uid>/<session_id> (None if absent)."""
    return db_ref.child(f"sessions/{uid}/{session_id}").get()


def update_session(uid, session_id, payload: dict):
    """Merge *payload* into sessions/<uid>/<session_id>."""
    db_ref.child(f"sessions/{uid}/{session_id}").update(payload)


def list_sessions(uid):
    """Return the user's sessions as a list, newest first (by createdAt)."""
    data = db_ref.child(f"sessions/{uid}").get() or {}
    entries = data.values() if isinstance(data, dict) else data
    # Skip malformed (non-dict) entries so one bad record cannot break sorting.
    items = [s for s in entries if isinstance(s, dict)]
    items.sort(key=lambda s: s.get("createdAt") or "", reverse=True)
    return items


# ============================================================
# 6) GEMINI — CULTURAL TURN EVALUATOR (JSON output)
# ============================================================


def _safe_json(text: str, fallback: dict):
    """Parse a JSON object out of model output, returning *fallback* on failure.

    Handles bare JSON, JSON wrapped in Markdown code fences, and JSON with
    surrounding prose by parsing the span from the first '{' to the last '}'.
    Anything that does not decode to a dict yields *fallback*.
    """
    if not isinstance(text, str):
        return fallback
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return fallback
    try:
        parsed = json.loads(text[start:end + 1])
    except ValueError:
        return fallback
    # Callers use .get(); a list/number/string result would crash them.
    return parsed if isinstance(parsed, dict) else fallback


def _to_int(value, default=0):
    """Coerce *value* to int, returning *default* when that is impossible."""
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


def _to_bool(value):
    """Interpret JSON-ish truthy values; the string "false" is False."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes")
    return bool(value)


def evaluate_turn(language_pack, scenario, transcript_turn, user_title):
    """Ask Gemini to score one user turn against the scenario's rules.

    Args:
        language_pack: pack dict; "language" is read.
        scenario: scenario dict; "name" and "rules" are read.
        transcript_turn: what the user said.
        user_title: immersion title shown to the evaluator.

    Returns:
        dict with at least:
            "meter_delta": {"respect": int, "influence": int, "trust": int},
            "feedback": str,
            "checkpoint_required": bool.
        On any Gemini or parsing failure a neutral fallback (zero deltas,
        "Evaluation unavailable.") is returned; this function does not raise
        for those failures.
    """
    rules_json = json.dumps(scenario["rules"], ensure_ascii=False, indent=2)
    prompt = f"""
You are a cultural authority evaluator and business communication coach.

IMMERSION TITLE: {user_title}
LANGUAGE: {language_pack["language"]}
SCENARIO: {scenario["name"]}

EXPECTATIONS (rules):
{rules_json}

USER SAID:
"{transcript_turn}"

Return STRICT JSON ONLY:
{{
  "meter_delta": {{ "respect": <integer>, "influence": <integer>, "trust": <integer> }},
  "feedback": "<short coaching feedback>",
  "checkpoint_required": <true|false>
}}
"""
    fallback = {
        "meter_delta": {"respect": 0, "influence": 0, "trust": 0},
        "feedback": "Evaluation unavailable.",
        "checkpoint_required": False,
    }
    try:
        resp = client.models.generate_content(model=MODEL_NAME, contents=prompt)
        raw = _safe_json(resp.text, fallback)
    except Exception as e:
        logger.error(f"Gemini evaluate_turn failed: {e}")
        return fallback

    # Normalise model output so callers never crash on odd types.
    deltas = raw.get("meter_delta")
    if not isinstance(deltas, dict):
        deltas = {}
    feedback = raw.get("feedback", "")
    return {
        **raw,
        "meter_delta": {
            key: _to_int(deltas.get(key), 0)
            for key in ("respect", "influence", "trust")
        },
        "feedback": feedback if isinstance(feedback, str) else str(feedback),
        "checkpoint_required": _to_bool(raw.get("checkpoint_required", False)),
    }


# ============================================================
# 7) AUDIO SANITIZER (Azure requirement)
# ============================================================


def sanitize_audio(raw_path, timeout=60):
    """Convert *raw_path* to mono 16 kHz 16-bit PCM WAV using ffmpeg.

    Args:
        raw_path: path of the uploaded audio file.
        timeout: seconds to wait for ffmpeg before aborting.

    Returns:
        str: path of the converted file (raw_path + "_clean.wav").

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero.
        subprocess.TimeoutExpired: ffmpeg exceeded *timeout*.
        OSError: ffmpeg could not be started.
    """
    clean_path = raw_path + "_clean.wav"
    cmd = [
        "ffmpeg", "-y", "-v", "error",
        "-i", raw_path,
        "-ac", "1",
        "-ar", "16000",
        "-acodec", "pcm_s16le",
        clean_path,
    ]
    try:
        subprocess.run(cmd, check=True, timeout=timeout)
    except Exception:
        # The caller never learns clean_path on failure, so remove any
        # partial output here instead of leaking it.
        try:
            if os.path.exists(clean_path):
                os.remove(clean_path)
        except OSError as cleanup_err:
            logger.warning(f"Could not remove partial file {clean_path}: {cleanup_err}")
        raise
    return clean_path


def _azure_pronunciation_assess(reference_text, lang, wav_path):
    """Run Azure pronunciation assessment on a WAV file.

    Args:
        reference_text: text the speaker was expected to say.
        lang: recognition language code (e.g. "en-US").
        wav_path: path to the audio file to assess.

    Returns:
        dict with keys "success", "score", "fluency", "completeness",
        "recognized_text" and "word_details" (list of
        {"word", "score", "error"}). When speech is not recognized,
        "success" is False, the scores are 0 and "word_details" is empty.

    Raises:
        RuntimeError: Azure credentials are not configured.
    """
    if not AZURE_SPEECH_KEY or not AZURE_SPEECH_REGION:
        raise RuntimeError("Azure Speech credentials are not configured")

    speech_config = speechsdk.SpeechConfig(
        subscription=AZURE_SPEECH_KEY, region=AZURE_SPEECH_REGION
    )
    speech_config.speech_recognition_language = lang
    audio_config = speechsdk.audio.AudioConfig(filename=wav_path)
    pronunciation_config = speechsdk.PronunciationAssessmentConfig(
        reference_text=reference_text,
        grading_system=speechsdk.PronunciationAssessmentGradingSystem.HundredMark,
        granularity=speechsdk.PronunciationAssessmentGranularity.Word,
        enable_miscue=True,
    )
    recognizer = speechsdk.SpeechRecognizer(
        speech_config=speech_config, audio_config=audio_config
    )
    pronunciation_config.apply_to(recognizer)

    result = recognizer.recognize_once_async().get()
    if result.reason != speechsdk.ResultReason.RecognizedSpeech:
        return {
            "success": False,
            "score": 0,
            "fluency": 0,
            "completeness": 0,
            "recognized_text": getattr(result, "text", "") or "No match",
            "word_details": [],
        }

    pron_result = speechsdk.PronunciationAssessmentResult(result)
    detailed_words = [
        {"word": w.word, "score": w.accuracy_score, "error": w.error_type}
        for w in pron_result.words
    ]
    return {
        "success": True,
        "score": pron_result.accuracy_score,
        "fluency": pron_result.fluency_score,
        "completeness": pron_result.completeness_score,
        "recognized_text": result.text,
        "word_details": detailed_words,
    }


# ============================================================
# 8) AUTH + PROFILE ENDPOINTS (ported from your template)
# ============================================================


@app.route("/api/auth/signup", methods=["POST"])
def signup():
    """Create a Firebase Auth user plus a profile with 30 starting credits.

    Responses: 201 with the profile, 400 on missing fields or other errors,
    409 when the email is already registered.
    """
    try:
        data = request.get_json() or {}
        email = data.get("email")
        password = data.get("password")
        display_name = data.get("displayName")
        if not email or not password:
            return jsonify({"error": "Email and password are required"}), 400

        user = auth.create_user(
            email=email, password=password, display_name=display_name
        )
        user_data = {
            "email": email,
            "displayName": display_name,
            "credits": 30,
            "is_admin": False,
            "createdAt": now_iso(),
        }
        db_ref.child(f"users/{user.uid}").set(user_data)
        return jsonify({"success": True, "uid": user.uid, **user_data}), 201
    except Exception as e:
        logger.error(f"Signup failed: {e}")
        if "EMAIL_EXISTS" in str(e):
            return jsonify({"error": "An account with this email already exists."}), 409
        return jsonify({"error": str(e)}), 400


@app.route("/api/auth/social-signin", methods=["POST"])
def social_signin():
    """Return the caller's profile, creating it on first social sign-in.

    Responses: 200 existing profile, 201 newly created profile,
    401 invalid token, 500 if the profile could not be created.
    """
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Invalid or expired token"}), 401

    user_ref = db_ref.child(f"users/{uid}")
    user_data = user_ref.get()
    if user_data:
        return jsonify({"uid": uid, **user_data}), 200

    try:
        firebase_user = auth.get_user(uid)
        new_user_data = {
            "email": firebase_user.email,
            "displayName": firebase_user.display_name,
            "credits": 30,
            "is_admin": False,
            "createdAt": now_iso(),
        }
        user_ref.set(new_user_data)
        return jsonify({"success": True, "uid": uid, **new_user_data}), 201
    except Exception as e:
        logger.error(f"Failed to create profile for social user {uid}: {e}")
        return jsonify({"error": f"Failed to create user profile: {str(e)}"}), 500


@app.route("/api/user/profile", methods=["GET"])
def get_user_profile():
    """Return the authenticated user's profile (401 / 404 on failure)."""
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Invalid or expired token"}), 401

    user_data = db_ref.child(f"users/{uid}").get()
    if not user_data:
        return jsonify({"error": "User not found"}), 404
    return jsonify({"uid": uid, **user_data}), 200


@app.route("/api/user/profile", methods=["PATCH"])
def update_user_profile():
    """Update displayName and/or preferredLanguage for the caller.

    Only those two string fields are accepted; everything else in the body
    is ignored so clients cannot overwrite credits or is_admin.
    """
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Invalid or expired token"}), 401

    data = request.get_json() or {}
    # keep it simple + safe: whitelist of editable string fields
    allowed = {
        field: data[field].strip()
        for field in ("displayName", "preferredLanguage")
        if field in data and isinstance(data[field], str)
    }
    if not allowed:
        return jsonify({"error": "No valid fields to update"}), 400

    update_user(uid, allowed)
    user_data = get_user(uid) or {}
    return jsonify({"success": True, "uid": uid, **user_data}), 200


# ============================================================
# 9) CREDITS + ADMIN ENDPOINTS (same as template)
# ============================================================
@app.route("/api/user/request-credits", methods=["POST"])
def request_credits():
    """File a pending credit request for the authenticated user.

    Responses: 200 with the new requestId, 400 on missing/invalid
    requested_credits, 401 unauthenticated, 500 on unexpected errors.
    """
    uid = verify_token(request.headers.get("Authorization"))
    if not uid:
        return jsonify({"error": "Unauthorized"}), 401
    try:
        data = request.get_json() or {}
        if "requested_credits" not in data:
            return jsonify({"error": "requested_credits is required"}), 400
        # Bad client input is a 400, not a 500.
        try:
            requested = int(data["requested_credits"])
        except (TypeError, ValueError):
            return jsonify({"error": "requested_credits must be an integer"}), 400
        if requested <= 0:
            return jsonify({"error": "requested_credits must be a positive integer"}), 400

        req_ref = db_ref.child("credit_requests").push()
        req_ref.set({
            "requestId": req_ref.key,
            "userId": uid,
            "requested_credits": requested,
            "status": "pending",
            "requestedAt": now_iso(),
        })
        return jsonify({"success": True, "requestId": req_ref.key}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/admin/credit_requests", methods=["GET"])
def list_credit_requests():
    """Admin only: return every stored credit request as a list."""
    try:
        verify_admin(request.headers.get("Authorization"))
        requests_data = db_ref.child("credit_requests").get() or {}
        return jsonify(list(requests_data.values())), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 403
    except Exception as e:
        return jsonify({"error": str(e)}), 500


# The URL variable is required: Flask passes it as the view's request_id.
@app.route("/api/admin/credit_requests/<request_id>", methods=["PUT"])
def process_credit_request(request_id):
    """Admin only: approve or decline a pending credit request.

    Body: {"decision": "approved" | "declined"}. Approval adds the requested
    credits to the user atomically and writes a ledger entry.

    Responses: 200 processed, 400 bad decision, 403 not an admin,
    404 unknown request, 409 request already processed, 500 otherwise.
    """
    try:
        admin_uid = verify_admin(request.headers.get("Authorization"))
        req_ref = db_ref.child(f"credit_requests/{request_id}")
        req_data = req_ref.get()
        if not req_data:
            return jsonify({"error": "Credit request not found"}), 404

        decision = (request.get_json() or {}).get("decision")
        if decision not in ["approved", "declined"]:
            return jsonify({"error": 'Decision must be "approved" or "declined"'}), 400

        # Without this guard an approved request could be approved again,
        # crediting the user every time.
        if req_data.get("status") != "pending":
            return jsonify({"error": "Credit request has already been processed"}), 409

        if decision == "approved":
            target_uid = req_data["userId"]
            granted = int(req_data.get("requested_credits", 0))

            def _add_credits(current):
                return int(current or 0) + granted

            # Transaction avoids losing concurrent balance updates.
            new_total = db_ref.child(f"users/{target_uid}/credits").transaction(_add_credits)
            db_ref.child(f"credit_ledger/{target_uid}").push().set({
                "ts": now_iso(),
                "delta": granted,
                "reason": "admin_credit_approval",
                "balance": new_total,
                "processedBy": admin_uid,
            })

        req_ref.update({
            "status": decision,
            "processedBy": admin_uid,
            "processedAt": now_iso(),
        })
        return jsonify({"success": True, "message": f"Request {decision}."}), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 403
    except Exception as e:
        return jsonify({"error": str(e)}), 500


# The URL variable is required: Flask passes it as the view's uid.
@app.route("/api/admin/users/<uid>/credits", methods=["PUT"])
def admin_update_credits(uid):
    """Admin only: add (or, with a negative value, remove) credits for a user.

    Body: {"add_credits": <int>}.
    Responses: 200 with new_total_credits, 400 invalid body, 403 not an
    admin, 404 unknown user, 500 otherwise.
    """
    try:
        verify_admin(request.headers.get("Authorization"))
        add_credits = (request.get_json() or {}).get("add_credits")
        if add_credits is None:
            return jsonify({"error": "add_credits is required"}), 400
        try:
            delta = int(add_credits)
        except (TypeError, ValueError):
            return jsonify({"error": "add_credits must be an integer"}), 400

        user_ref = db_ref.child(f"users/{uid}")
        if not user_ref.get():
            return jsonify({"error": "User not found"}), 404

        def _adjust(current):
            return int(current or 0) + delta

        new_total = user_ref.child("credits").transaction(_adjust)
        db_ref.child(f"credit_ledger/{uid}").push().set({
            "ts": now_iso(),
            "delta": delta,
            "reason": "admin_manual_adjust",
            "balance": new_total,
        })
        return jsonify({"success": True, "new_total_credits": new_total}), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 403
    except Exception as e:
        return jsonify({"error": str(e)}), 500


# ============================================================
# 10) SESSION ENDPOINTS (your new app core)
# ============================================================


def _coerce_int(value, default=0):
    """Best-effort int conversion; returns *default* for unusable values."""
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


@app.route("/api/session/start", methods=["POST"])
def start_session():
    """Charge the start fee and create a new session.

    Body: {"language": <locale>, "scenarioId": <id>}.
    Returns the session, the dynamic variables the client forwards to the
    voice agent, and the credit summary. 401 when unauthenticated; every
    other failure (including insufficient credits) is a 400.
    """
    try:
        uid = require_user()
        data = request.get_json() or {}
        language = data.get("language")
        scenario_id = data.get("scenarioId")
        if not language or not scenario_id:
            return jsonify({"error": "language and scenarioId are required"}), 400
        if language not in LANGUAGE_PACKS:
            return jsonify({"error": f"Unsupported language: {language}"}), 400

        pack = LANGUAGE_PACKS[language]
        scenario = pack["scenarios"].get(scenario_id)
        if not scenario:
            return jsonify({"error": "Invalid scenarioId"}), 400

        title = scenario["title"]
        credit_info = charge(uid, START_SESSION_COST, reason="start_session")
        session = create_session(uid, language, scenario_id, title)

        # Client passes these to ElevenLabs agent init as dynamic vars
        dynamic_vars = {
            "title": title,
            "language": pack["language"],
            "scenarioName": scenario["name"],
            "scenarioId": scenario_id,
        }
        return jsonify({
            "session": session,
            "dynamicVariables": dynamic_vars,
            "credits": credit_info,
        }), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 401
    except Exception as e:
        logger.error(f"start_session failed: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"error": str(e)}), 400


@app.route("/api/session/turn", methods=["POST"])
def submit_turn():
    """Evaluate one transcript turn and update the session meters.

    Body: {"sessionId": <id>, "transcript": <str>}.
    Each meter is moved by the evaluator's delta and clamped to 0..100; the
    turn and its feedback are appended to the session's turn list.
    """
    try:
        uid = require_user()
        data = request.get_json() or {}
        session_id = data.get("sessionId")
        transcript = data.get("transcript")
        if not session_id or not transcript or not isinstance(transcript, str):
            return jsonify({"error": "sessionId and transcript are required"}), 400

        session = get_session(uid, session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        if session.get("status") != "active":
            return jsonify({"error": "Session is not active"}), 400

        pack = LANGUAGE_PACKS[session["language"]]
        scenario = pack["scenarios"][session["scenarioId"]]
        title = session.get("title", scenario.get("title"))
        result = evaluate_turn(pack, scenario, transcript, title)

        meters = session.get("meters") or {"respect": 50, "influence": 50, "trust": 50}
        deltas = result.get("meter_delta")
        if not isinstance(deltas, dict):
            deltas = {}
        # Model output is untrusted: coerce instead of letting int() raise.
        for name in meters:
            moved = _coerce_int(meters[name], 50) + _coerce_int(deltas.get(name), 0)
            meters[name] = max(0, min(100, moved))

        feedback = result.get("feedback", "")
        turns = session.get("turns") or []
        turns.append({
            "id": str(uuid.uuid4()),
            "at": now_iso(),
            "text": transcript,
            "feedback": feedback,
        })
        update_session(uid, session_id, {"meters": meters, "turns": turns})

        return jsonify({
            "meters": meters,
            "feedback": feedback,
            "checkpointRequired": bool(result.get("checkpoint_required", False)),
        }), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 401
    except Exception as e:
        logger.error(f"submit_turn failed: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"error": str(e)}), 400


@app.route("/api/session/end", methods=["POST"])
def end_session():
    """Close an active session and bill PER_MINUTE_COST per started minute.

    Body: {"sessionId": <id>, "durationSeconds": <non-negative number>}.
    """
    try:
        uid = require_user()
        data = request.get_json() or {}
        session_id = data.get("sessionId")
        duration = data.get("durationSeconds")
        # bool is a subclass of int, so exclude it explicitly.
        if (
            not session_id
            or isinstance(duration, bool)
            or not isinstance(duration, (int, float))
        ):
            return jsonify({"error": "sessionId and durationSeconds are required"}), 400
        # A negative duration would produce a negative cost, i.e. a refund.
        if not math.isfinite(duration) or duration < 0:
            return jsonify({"error": "durationSeconds must be a non-negative number"}), 400

        session = get_session(uid, session_id)
        if not session:
            return jsonify({"error": "Session not found"}), 404
        # Ending an already-completed session would bill the user twice.
        if session.get("status") != "active":
            return jsonify({"error": "Session is not active"}), 400

        cost = math.ceil(float(duration) / 60.0) * PER_MINUTE_COST
        credit_info = charge(uid, cost, reason="session_minutes")
        update_session(uid, session_id, {
            "status": "completed",
            "endedAt": now_iso(),
            "durationSeconds": duration,
            "minuteCost": cost,
        })
        return jsonify({
            "status": "completed",
            "cost": cost,
            "credits": credit_info,
        }), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 401
    except Exception as e:
        logger.error(f"end_session failed: {e}")
        logger.error(traceback.format_exc())
        return jsonify({"error": str(e)}), 400


@app.route("/api/sessions", methods=["GET"])
def api_list_sessions():
    """Return all of the caller's sessions (see list_sessions for ordering)."""
    try:
        uid = require_user()
        return jsonify(list_sessions(uid)), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 401
    except Exception as e:
        return jsonify({"error": str(e)}), 500


# The URL variable is required: Flask passes it as the view's session_id.
@app.route("/api/sessions/<session_id>", methods=["GET"])
def api_get_session(session_id):
    """Return one of the caller's sessions, or 404 if it does not exist."""
    try:
        uid = require_user()
        s = get_session(uid, session_id)
        if not s:
            return jsonify({"error": "Session not found"}), 404
        return jsonify(s), 200
    except PermissionError as e:
        return jsonify({"error": str(e)}), 401
    except Exception as e:
        return jsonify({"error": str(e)}), 500


# ============================================================
# 11) WEBSOCKET — PRONUNCIATION (practice + live turn)
# ============================================================


def _update_struggle_words(session, word_details):
    """
    Rolling average per word.
    session["struggleWords"] shape:
      { "word": {"avg": float, "count": int} }

    Folds each entry of *word_details* ({"word", "score"}) into the running
    average and returns the updated mapping. Entries with an empty word are
    skipped; a missing score counts as 0.
    """
    struggle = session.get("struggleWords", {}) or {}
    for detail in word_details:
        word = (detail.get("word") or "").strip()
        if not word:
            continue
        score = float(detail.get("score") or 0)
        previous = struggle.get(word, {"avg": 0.0, "count": 0})
        count = int(previous.get("count", 0))
        average = float(previous.get("avg", 0.0))
        struggle[word] = {
            "avg": (average * count + score) / (count + 1),
            "count": count + 1,
        }
    return struggle


def _remove_temp_file(path):
    """Delete *path* if set; log (never raise) when deletion fails."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"Could not remove temp file {path}: {e}")


def _decode_audio_payload(audio):
    """Return raw bytes from a base64 string or a 'data:...;base64,...' URL."""
    audio_b64 = audio.split(",", 1)[1] if "," in audio else audio
    return base64.b64decode(audio_b64)


def _assess_audio_bytes(ref_text, lang, audio_bytes):
    """Write audio to a temp file, convert it, score it, and always clean up."""
    raw_path = None
    clean_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
            # Record the name first so the file is removed even if write fails.
            raw_path = f.name
            f.write(audio_bytes)
        clean_path = sanitize_audio(raw_path)
        return _azure_pronunciation_assess(ref_text, lang, clean_path)
    finally:
        # Each file is removed independently so one failure cannot leak the other.
        _remove_temp_file(raw_path)
        _remove_temp_file(clean_path)


@socketio.on("practice_pronunciation")
def practice_pronunciation(data):
    """
    Practice loop: reference text must be supplied.
    Optional:
      - authToken: Firebase ID token (so we can charge credits)
      - chargeCredits: true/false (default false)

    Emits "pronunciation_result" with the assessment (plus "credits" when a
    charge was made) or with success=False and a reason in recognized_text.
    """
    try:
        ref_text = data.get("text")
        lang = data.get("lang", "en-US")
        audio = data.get("audio")
        if not ref_text or not audio:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing text/audio"})
            return

        # Decode before charging so a malformed payload never costs credits.
        audio_bytes = _decode_audio_payload(audio)

        # Optional credit charge for practice
        credit_info = None
        if data.get("chargeCredits", False):
            auth_token = data.get("authToken")
            if not auth_token:
                emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing authToken"})
                return
            uid = verify_token(f"Bearer {auth_token}")
            if not uid:
                emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Invalid authToken"})
                return
            try:
                credit_info = charge(uid, PRACTICE_ATTEMPT_COST, reason="practice_attempt")
            except ValueError:
                # Only a balance problem is reported as such; other failures
                # fall through to the generic handler and are logged.
                emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Insufficient credits"})
                return

        result = _assess_audio_bytes(ref_text, lang, audio_bytes)
        if credit_info:
            result["credits"] = credit_info
        emit("pronunciation_result", result)
    except Exception as e:
        logger.error(f"practice_pronunciation failed: {e}")
        emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Server Error"})


@socketio.on("live_pronunciation_turn")
def live_pronunciation_turn(data):
    """
    Live scoring for a session turn.
    Requires:
      - authToken (Firebase ID token)
      - sessionId
      - text (reference phrase OR checkpoint line)
      - lang
      - audio
    Returns word_details + updated struggle words top list.
    """
    try:
        auth_token = data.get("authToken")
        session_id = data.get("sessionId")
        ref_text = data.get("text")
        lang = data.get("lang", "en-US")
        audio = data.get("audio")
        if not auth_token or not session_id or not ref_text or not audio:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing fields"})
            return

        uid = verify_token(f"Bearer {auth_token}")
        if not uid:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Invalid authToken"})
            return

        session = get_session(uid, session_id)
        if not session:
            emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Session not found"})
            return

        result = _assess_audio_bytes(ref_text, lang, _decode_audio_payload(audio))

        # update rolling struggle words
        struggle = _update_struggle_words(session, result.get("word_details", []))
        update_session(uid, session_id, {"struggleWords": struggle})

        # top 8 worst avg
        ranked = sorted(
            ({"word": w, "avg": v["avg"], "count": v["count"]} for w, v in struggle.items()),
            key=lambda entry: entry["avg"],
        )
        result["struggle_top"] = ranked[:8]
        emit("pronunciation_result", result)
    except Exception as e:
        logger.error(f"live_pronunciation_turn failed: {e}")
        emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Server Error"})


# ============================================================
# 12) MAIN
# ============================================================
if __name__ == "__main__":
    socketio.run(app, host="0.0.0.0", port=7860)