rairo commited on
Commit
2efb4c9
·
verified ·
1 Parent(s): 34b3eb4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +625 -125
main.py CHANGED
@@ -1,5 +1,5 @@
1
  # ============================================================
2
- # main.py — AI Partner Cultural Simulator Backend
3
  # REST + WebSocket | Gemini + Firebase + Azure Pronunciation
4
  # HuggingFace-compatible (Port 7860)
5
  # ============================================================
@@ -39,90 +39,147 @@ CORS(app)
39
  socketio = SocketIO(app, cors_allowed_origins="*", async_mode="eventlet")
40
 
41
  # ============================================================
42
- # ENV & CLIENT INITIALIZATION
43
  # ============================================================
44
 
45
- # --- Firebase ---
46
- firebase_json = os.environ.get("FIREBASE")
47
- firebase_db_url = os.environ.get("Firebase_DB")
48
-
49
- cred = credentials.Certificate(json.loads(firebase_json))
50
- firebase_admin.initialize_app(cred, {"databaseURL": firebase_db_url})
51
- db_ref = db.reference()
52
-
53
- # --- Gemini ---
54
- GEMINI_API_KEY = os.environ.get("Gemini")
55
- MODEL_NAME = "gemini-2.0-flash"
56
- gemini_client = genai.Client(api_key=GEMINI_API_KEY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # --- Azure Speech ---
59
  AZURE_SPEECH_KEY = os.environ.get("AZURE_SPEECH_KEY")
60
  AZURE_SPEECH_REGION = os.environ.get("AZURE_SPEECH_REGION")
61
 
62
  # ============================================================
63
- # LANGUAGE PACKS
64
  # ============================================================
65
 
66
  from korean import KOREAN_PACK
67
  from english import ENGLISH_PACK
68
- # from japanese import JAPANESE_PACK
69
- # from german import GERMAN_PACK
70
 
71
  LANGUAGE_PACKS = {
72
  "ko-KR": KOREAN_PACK,
73
  "en-US": ENGLISH_PACK,
74
- # "ja-JP": JAPANESE_PACK,
75
- # "de-DE": GERMAN_PACK,
76
  }
77
 
78
  # ============================================================
79
- # AUTH HELPERS (reuse-friendly)
80
  # ============================================================
81
 
 
 
 
82
  def verify_token(auth_header):
83
  if not auth_header or not auth_header.startswith("Bearer "):
84
  return None
 
85
  try:
86
- token = auth_header.split("Bearer ")[1]
87
  return auth.verify_id_token(token)["uid"]
88
- except Exception:
 
89
  return None
90
 
 
 
 
 
 
 
 
 
 
91
  def require_user():
92
  uid = verify_token(request.headers.get("Authorization"))
93
  if not uid:
94
  raise PermissionError("Unauthorized")
95
  return uid
96
 
 
 
 
 
 
 
97
  # ============================================================
98
- # CREDIT RULES (simple + explicit)
99
  # ============================================================
100
 
101
  START_SESSION_COST = 1
102
  PRACTICE_ATTEMPT_COST = 1
103
  PER_MINUTE_COST = 2
104
 
105
- def charge(uid, amount):
106
  user_ref = db_ref.child(f"users/{uid}")
107
- user = user_ref.get()
108
- if user.get("credits", 0) < amount:
 
 
109
  raise ValueError("Insufficient credits")
110
- user_ref.update({"credits": user["credits"] - amount})
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # ============================================================
113
- # SESSION HELPERS
114
  # ============================================================
115
 
116
  def create_session(uid, language, scenario_id, title):
117
  session_id = str(uuid.uuid4())
118
  session = {
119
  "sessionId": session_id,
 
120
  "language": language,
121
  "scenarioId": scenario_id,
122
  "title": title,
123
  "meters": {"respect": 50, "influence": 50, "trust": 50},
124
  "turns": [],
125
- "createdAt": datetime.utcnow().isoformat() + "Z"
 
 
 
126
  }
127
  db_ref.child(f"sessions/{uid}/{session_id}").set(session)
128
  return session
@@ -130,52 +187,66 @@ def create_session(uid, language, scenario_id, title):
130
  def get_session(uid, session_id):
131
  return db_ref.child(f"sessions/{uid}/{session_id}").get()
132
 
133
- def update_session(uid, session_id, data):
134
- db_ref.child(f"sessions/{uid}/{session_id}").update(data)
 
 
 
 
 
 
 
135
 
136
  # ============================================================
137
- # GEMINI — CULTURAL EVALUATOR
138
  # ============================================================
139
 
140
- def evaluate_turn(language_pack, scenario, transcript_turn):
 
 
 
 
 
 
 
141
  prompt = f"""
142
- You are a cultural authority evaluator.
143
 
 
144
  LANGUAGE: {language_pack["language"]}
145
  SCENARIO: {scenario["name"]}
146
- EXPECTATIONS:
147
- {json.dumps(scenario["rules"], indent=2)}
 
148
 
149
  USER SAID:
150
  "{transcript_turn}"
151
 
152
- Return STRICT JSON:
153
  {{
154
  "meter_delta": {{
155
- "respect": <int>,
156
- "influence": <int>,
157
- "trust": <int>
158
  }},
159
- "feedback": "<short cultural feedback>",
160
  "checkpoint_required": <true|false>
161
  }}
162
  """
 
 
 
 
 
163
  try:
164
- response = gemini_client.models.generate_content(
165
- model=MODEL_NAME,
166
- contents=prompt
167
- )
168
- return json.loads(response.text.strip().lstrip("```json").rstrip("```"))
169
  except Exception as e:
170
- logger.error(f"Gemini eval failed: {e}")
171
- return {
172
- "meter_delta": {"respect": 0, "influence": 0, "trust": 0},
173
- "feedback": "Evaluation unavailable.",
174
- "checkpoint_required": False
175
- }
176
 
177
  # ============================================================
178
- # AUDIO SANITIZER (Azure requirement)
179
  # ============================================================
180
 
181
  def sanitize_audio(raw_path):
@@ -191,104 +262,547 @@ def sanitize_audio(raw_path):
191
  subprocess.run(cmd, check=True)
192
  return clean_path
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # ============================================================
195
- # REST ENDPOINTS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  # ============================================================
197
 
198
  @app.route("/api/session/start", methods=["POST"])
199
  def start_session():
200
  try:
201
  uid = require_user()
202
- data = request.get_json()
203
 
204
- language = data["language"]
205
- scenario_id = data["scenarioId"]
 
 
 
 
 
206
 
207
  pack = LANGUAGE_PACKS[language]
208
- scenario = pack["scenarios"][scenario_id]
 
 
 
209
  title = scenario["title"]
210
 
211
- charge(uid, START_SESSION_COST)
212
 
213
  session = create_session(uid, language, scenario_id, title)
214
 
 
 
 
 
 
 
 
 
215
  return jsonify({
216
  "session": session,
217
- "dynamicVariables": {
218
- "title": title,
219
- "language": pack["language"],
220
- "scenarioName": scenario["name"]
221
- }
222
- })
223
 
 
 
224
  except Exception as e:
 
 
225
  return jsonify({"error": str(e)}), 400
226
 
227
  @app.route("/api/session/turn", methods=["POST"])
228
  def submit_turn():
229
  try:
230
  uid = require_user()
231
- data = request.get_json()
232
 
233
- session_id = data["sessionId"]
234
- transcript = data["transcript"]
 
 
235
 
236
  session = get_session(uid, session_id)
 
 
 
 
 
237
  pack = LANGUAGE_PACKS[session["language"]]
238
  scenario = pack["scenarios"][session["scenarioId"]]
 
239
 
240
- result = evaluate_turn(pack, scenario, transcript)
 
 
 
241
 
242
- meters = session["meters"]
243
  for k in meters:
244
- meters[k] = max(0, min(100, meters[k] + result["meter_delta"][k]))
245
 
246
- session["turns"].append({
 
 
 
247
  "text": transcript,
248
- "feedback": result["feedback"]
249
  })
250
 
251
  update_session(uid, session_id, {
252
  "meters": meters,
253
- "turns": session["turns"]
254
  })
255
 
256
  return jsonify({
257
  "meters": meters,
258
- "feedback": result["feedback"],
259
- "checkpointRequired": result["checkpoint_required"]
260
- })
261
 
 
 
262
  except Exception as e:
 
 
263
  return jsonify({"error": str(e)}), 400
264
 
265
  @app.route("/api/session/end", methods=["POST"])
266
  def end_session():
267
  try:
268
  uid = require_user()
269
- data = request.get_json()
270
- session_id = data["sessionId"]
271
- duration = data["durationSeconds"]
 
 
 
 
 
 
 
 
272
 
273
- cost = math.ceil(duration / 60) * PER_MINUTE_COST
274
- charge(uid, cost)
275
 
276
- return jsonify({"status": "completed", "cost": cost})
 
 
 
 
 
277
 
 
 
 
 
 
 
 
 
278
  except Exception as e:
 
 
279
  return jsonify({"error": str(e)}), 400
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  # ============================================================
282
- # WEBSOCKET — AZURE PRONUNCIATION
283
  # ============================================================
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  @socketio.on("practice_pronunciation")
286
  def practice_pronunciation(data):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  try:
288
- ref_text = data["text"]
 
 
289
  lang = data.get("lang", "en-US")
 
290
 
291
- audio_b64 = data["audio"].split(",")[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  audio_bytes = base64.b64decode(audio_b64)
293
 
294
  with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
@@ -297,49 +811,35 @@ def practice_pronunciation(data):
297
 
298
  clean_path = sanitize_audio(raw_path)
299
 
300
- speech_config = speechsdk.SpeechConfig(
301
- subscription=AZURE_SPEECH_KEY,
302
- region=AZURE_SPEECH_REGION
303
- )
304
- speech_config.speech_recognition_language = lang
305
-
306
- audio_config = speechsdk.audio.AudioConfig(filename=clean_path)
307
- recognizer = speechsdk.SpeechRecognizer(
308
- speech_config=speech_config,
309
- audio_config=audio_config
310
- )
311
-
312
- pa_config = speechsdk.PronunciationAssessmentConfig(
313
- reference_text=ref_text,
314
- grading_system=speechsdk.PronunciationAssessmentGradingSystem.HundredMark,
315
- granularity=speechsdk.PronunciationAssessmentGranularity.Word,
316
- enable_miscue=True
317
- )
318
- pa_config.apply_to(recognizer)
319
-
320
- result = recognizer.recognize_once_async().get()
321
- pa_result = speechsdk.PronunciationAssessmentResult(result)
322
-
323
- words = [
324
- {
325
- "word": w.word,
326
- "score": w.accuracy_score,
327
- "error": w.error_type
328
- }
329
- for w in pa_result.words
330
- ]
331
-
332
- emit("pronunciation_result", {
333
- "accuracy": pa_result.accuracy_score,
334
- "fluency": pa_result.fluency_score,
335
- "words": words
336
- })
337
 
338
  except Exception as e:
339
- emit("pronunciation_result", {"error": "Pronunciation failed"})
 
 
 
 
 
 
 
 
 
340
 
341
  # ============================================================
342
- # MAIN
343
  # ============================================================
344
 
345
  if __name__ == "__main__":
 
1
  # ============================================================
2
+ # main.py — AI Partner Cultural Simulator Backend (FULL)
3
  # REST + WebSocket | Gemini + Firebase + Azure Pronunciation
4
  # HuggingFace-compatible (Port 7860)
5
  # ============================================================
 
39
  socketio = SocketIO(app, cors_allowed_origins="*", async_mode="eventlet")
40
 
41
  # ============================================================
42
+ # 1) ENV & CLIENT INITIALIZATION (same naming as your template)
43
  # ============================================================
44
 
45
+ # --- Firebase Initialization ---
46
+ try:
47
+ credentials_json_string = os.environ.get("FIREBASE")
48
+ if not credentials_json_string:
49
+ raise ValueError("The 'FIREBASE' environment variable is not set.")
50
+ credentials_json = json.loads(credentials_json_string)
51
+
52
+ firebase_db_url = os.environ.get("Firebase_DB")
53
+ if not firebase_db_url:
54
+ raise ValueError("The 'Firebase_DB' environment variable must be set.")
55
+
56
+ cred = credentials.Certificate(credentials_json)
57
+ firebase_admin.initialize_app(cred, {"databaseURL": firebase_db_url})
58
+ db_ref = db.reference()
59
+ logger.info("Firebase Admin SDK initialized successfully.")
60
+ except Exception as e:
61
+ logger.critical(f"FATAL: Error initializing Firebase: {e}")
62
+ logger.critical(traceback.format_exc())
63
+ raise
64
+
65
+ # --- Gemini Initialization ---
66
+ try:
67
+ gemini_api_key = os.environ.get("Gemini")
68
+ if not gemini_api_key:
69
+ raise ValueError("The 'Gemini' environment variable is not set.")
70
+ client = genai.Client(api_key=gemini_api_key)
71
+ MODEL_NAME = "gemini-2.0-flash"
72
+ logger.info(f"Gemini client initialized for model: {MODEL_NAME}")
73
+ except Exception as e:
74
+ logger.critical(f"FATAL: Error initializing Gemini: {e}")
75
+ logger.critical(traceback.format_exc())
76
+ raise
77
 
78
  # --- Azure Speech ---
79
  AZURE_SPEECH_KEY = os.environ.get("AZURE_SPEECH_KEY")
80
  AZURE_SPEECH_REGION = os.environ.get("AZURE_SPEECH_REGION")
81
 
82
  # ============================================================
83
+ # 2) LANGUAGE PACKS
84
  # ============================================================
85
 
86
  from korean import KOREAN_PACK
87
  from english import ENGLISH_PACK
88
+ from japanese import JAPANESE_PACK
89
+ from german import GERMAN_PACK
90
 
91
  LANGUAGE_PACKS = {
92
  "ko-KR": KOREAN_PACK,
93
  "en-US": ENGLISH_PACK,
94
+ "ja-JP": JAPANESE_PACK,
95
+ "de-DE": GERMAN_PACK,
96
  }
97
 
98
  # ============================================================
99
+ # 3) CORE HELPERS (same style as your existing app)
100
  # ============================================================
101
 
102
+ def now_iso():
103
+ return datetime.utcnow().isoformat() + "Z"
104
+
105
  def verify_token(auth_header):
106
  if not auth_header or not auth_header.startswith("Bearer "):
107
  return None
108
+ token = auth_header.split("Bearer ")[1]
109
  try:
 
110
  return auth.verify_id_token(token)["uid"]
111
+ except Exception as e:
112
+ logger.warning(f"Token verification failed: {e}")
113
  return None
114
 
115
+ def verify_admin(auth_header):
116
+ uid = verify_token(auth_header)
117
+ if not uid:
118
+ raise PermissionError("Invalid or missing user token")
119
+ user_data = db_ref.child(f"users/{uid}").get()
120
+ if not user_data or not user_data.get("is_admin", False):
121
+ raise PermissionError("Admin access required")
122
+ return uid
123
+
124
  def require_user():
125
  uid = verify_token(request.headers.get("Authorization"))
126
  if not uid:
127
  raise PermissionError("Unauthorized")
128
  return uid
129
 
130
+ def get_user(uid):
131
+ return db_ref.child(f"users/{uid}").get()
132
+
133
+ def update_user(uid, payload: dict):
134
+ db_ref.child(f"users/{uid}").update(payload)
135
+
136
  # ============================================================
137
+ # 4) CREDITS (same vibe as Pitch Helper)
138
  # ============================================================
139
 
140
  START_SESSION_COST = 1
141
  PRACTICE_ATTEMPT_COST = 1
142
  PER_MINUTE_COST = 2
143
 
144
+ def charge(uid, amount, reason="charge"):
145
  user_ref = db_ref.child(f"users/{uid}")
146
+ user = user_ref.get() or {}
147
+ current = int(user.get("credits", 0))
148
+
149
+ if current < amount:
150
  raise ValueError("Insufficient credits")
151
+
152
+ new_total = max(0, current - int(amount))
153
+ user_ref.update({"credits": new_total})
154
+
155
+ # Optional: log credit usage (useful later)
156
+ db_ref.child(f"credit_ledger/{uid}").push().set({
157
+ "ts": now_iso(),
158
+ "delta": -int(amount),
159
+ "reason": reason,
160
+ "balance": new_total
161
+ })
162
+
163
+ return {"deducted": int(amount), "remaining": new_total}
164
 
165
  # ============================================================
166
+ # 5) SESSION HELPERS
167
  # ============================================================
168
 
169
  def create_session(uid, language, scenario_id, title):
170
  session_id = str(uuid.uuid4())
171
  session = {
172
  "sessionId": session_id,
173
+ "userId": uid,
174
  "language": language,
175
  "scenarioId": scenario_id,
176
  "title": title,
177
  "meters": {"respect": 50, "influence": 50, "trust": 50},
178
  "turns": [],
179
+ "createdAt": now_iso(),
180
+ "endedAt": None,
181
+ "status": "active",
182
+ "struggleWords": {} # rolling avg per word
183
  }
184
  db_ref.child(f"sessions/{uid}/{session_id}").set(session)
185
  return session
 
187
  def get_session(uid, session_id):
188
  return db_ref.child(f"sessions/{uid}/{session_id}").get()
189
 
190
+ def update_session(uid, session_id, payload: dict):
191
+ db_ref.child(f"sessions/{uid}/{session_id}").update(payload)
192
+
193
+ def list_sessions(uid):
194
+ data = db_ref.child(f"sessions/{uid}").get() or {}
195
+ # return as list sorted by createdAt desc
196
+ items = list(data.values())
197
+ items.sort(key=lambda x: x.get("createdAt", ""), reverse=True)
198
+ return items
199
 
200
  # ============================================================
201
+ # 6) GEMINI — CULTURAL TURN EVALUATOR (JSON output)
202
  # ============================================================
203
 
204
+ def _safe_json(text: str, fallback: dict):
205
+ try:
206
+ cleaned = (text or "").strip().lstrip("```json").rstrip("```").strip()
207
+ return json.loads(cleaned)
208
+ except Exception:
209
+ return fallback
210
+
211
+ def evaluate_turn(language_pack, scenario, transcript_turn, user_title):
212
  prompt = f"""
213
+ You are a cultural authority evaluator and business communication coach.
214
 
215
+ IMMERSION TITLE: {user_title}
216
  LANGUAGE: {language_pack["language"]}
217
  SCENARIO: {scenario["name"]}
218
+
219
+ EXPECTATIONS (rules):
220
+ {json.dumps(scenario["rules"], ensure_ascii=False, indent=2)}
221
 
222
  USER SAID:
223
  "{transcript_turn}"
224
 
225
+ Return STRICT JSON ONLY:
226
  {{
227
  "meter_delta": {{
228
+ "respect": <int between -10 and 10>,
229
+ "influence": <int between -10 and 10>,
230
+ "trust": <int between -10 and 10>
231
  }},
232
+ "feedback": "<one or two sentences, culturally grounded>",
233
  "checkpoint_required": <true|false>
234
  }}
235
  """
236
+ fallback = {
237
+ "meter_delta": {"respect": 0, "influence": 0, "trust": 0},
238
+ "feedback": "Evaluation unavailable.",
239
+ "checkpoint_required": False
240
+ }
241
  try:
242
+ resp = client.models.generate_content(model=MODEL_NAME, contents=prompt)
243
+ return _safe_json(resp.text, fallback)
 
 
 
244
  except Exception as e:
245
+ logger.error(f"Gemini evaluate_turn failed: {e}")
246
+ return fallback
 
 
 
 
247
 
248
  # ============================================================
249
+ # 7) AUDIO SANITIZER (Azure requirement)
250
  # ============================================================
251
 
252
  def sanitize_audio(raw_path):
 
262
  subprocess.run(cmd, check=True)
263
  return clean_path
264
 
265
+ def _azure_pronunciation_assess(reference_text, lang, wav_path):
266
+ speech_config = speechsdk.SpeechConfig(subscription=AZURE_SPEECH_KEY, region=AZURE_SPEECH_REGION)
267
+ speech_config.speech_recognition_language = lang
268
+ audio_config = speechsdk.audio.AudioConfig(filename=wav_path)
269
+
270
+ pronunciation_config = speechsdk.PronunciationAssessmentConfig(
271
+ reference_text=reference_text,
272
+ grading_system=speechsdk.PronunciationAssessmentGradingSystem.HundredMark,
273
+ granularity=speechsdk.PronunciationAssessmentGranularity.Word,
274
+ enable_miscue=True
275
+ )
276
+
277
+ recognizer = speechsdk.SpeechRecognizer(speech_config=speech_config, audio_config=audio_config)
278
+ pronunciation_config.apply_to(recognizer)
279
+
280
+ result = recognizer.recognize_once_async().get()
281
+
282
+ if result.reason != speechsdk.ResultReason.RecognizedSpeech:
283
+ return {
284
+ "success": False,
285
+ "score": 0,
286
+ "fluency": 0,
287
+ "completeness": 0,
288
+ "recognized_text": getattr(result, "text", "") or "No match",
289
+ "word_details": []
290
+ }
291
+
292
+ pron_result = speechsdk.PronunciationAssessmentResult(result)
293
+
294
+ detailed_words = []
295
+ for w in pron_result.words:
296
+ detailed_words.append({
297
+ "word": w.word,
298
+ "score": w.accuracy_score,
299
+ "error": w.error_type
300
+ })
301
+
302
+ return {
303
+ "success": True,
304
+ "score": pron_result.accuracy_score,
305
+ "fluency": pron_result.fluency_score,
306
+ "completeness": pron_result.completeness_score,
307
+ "recognized_text": result.text,
308
+ "word_details": detailed_words
309
+ }
310
+
311
  # ============================================================
312
+ # 8) AUTH + PROFILE ENDPOINTS (ported from your template)
313
+ # ============================================================
314
+
315
+ @app.route("/api/auth/signup", methods=["POST"])
316
+ def signup():
317
+ try:
318
+ data = request.get_json() or {}
319
+ email = data.get("email")
320
+ password = data.get("password")
321
+ display_name = data.get("displayName")
322
+
323
+ if not email or not password:
324
+ return jsonify({"error": "Email and password are required"}), 400
325
+
326
+ user = auth.create_user(email=email, password=password, display_name=display_name)
327
+
328
+ user_data = {
329
+ "email": email,
330
+ "displayName": display_name,
331
+ "credits": 30,
332
+ "is_admin": False,
333
+ "createdAt": now_iso()
334
+ }
335
+ db_ref.child(f"users/{user.uid}").set(user_data)
336
+
337
+ return jsonify({"success": True, "uid": user.uid, **user_data}), 201
338
+
339
+ except Exception as e:
340
+ logger.error(f"Signup failed: {e}")
341
+ if "EMAIL_EXISTS" in str(e):
342
+ return jsonify({"error": "An account with this email already exists."}), 409
343
+ return jsonify({"error": str(e)}), 400
344
+
345
+ @app.route("/api/auth/social-signin", methods=["POST"])
346
+ def social_signin():
347
+ uid = verify_token(request.headers.get("Authorization"))
348
+ if not uid:
349
+ return jsonify({"error": "Invalid or expired token"}), 401
350
+
351
+ user_ref = db_ref.child(f"users/{uid}")
352
+ user_data = user_ref.get()
353
+
354
+ if user_data:
355
+ return jsonify({"uid": uid, **user_data}), 200
356
+
357
+ try:
358
+ firebase_user = auth.get_user(uid)
359
+ new_user_data = {
360
+ "email": firebase_user.email,
361
+ "displayName": firebase_user.display_name,
362
+ "credits": 30,
363
+ "is_admin": False,
364
+ "createdAt": now_iso()
365
+ }
366
+ user_ref.set(new_user_data)
367
+ return jsonify({"success": True, "uid": uid, **new_user_data}), 201
368
+ except Exception as e:
369
+ logger.error(f"Failed to create profile for social user {uid}: {e}")
370
+ return jsonify({"error": f"Failed to create user profile: {str(e)}"}), 500
371
+
372
+ @app.route("/api/user/profile", methods=["GET"])
373
+ def get_user_profile():
374
+ uid = verify_token(request.headers.get("Authorization"))
375
+ if not uid:
376
+ return jsonify({"error": "Invalid or expired token"}), 401
377
+
378
+ user_data = db_ref.child(f"users/{uid}").get()
379
+ if not user_data:
380
+ return jsonify({"error": "User not found"}), 404
381
+
382
+ return jsonify({"uid": uid, **user_data}), 200
383
+
384
+ @app.route("/api/user/profile", methods=["PATCH"])
385
+ def update_user_profile():
386
+ uid = verify_token(request.headers.get("Authorization"))
387
+ if not uid:
388
+ return jsonify({"error": "Invalid or expired token"}), 401
389
+
390
+ data = request.get_json() or {}
391
+ allowed = {}
392
+ # keep it simple + safe
393
+ if "displayName" in data and isinstance(data["displayName"], str):
394
+ allowed["displayName"] = data["displayName"].strip()
395
+ if "preferredLanguage" in data and isinstance(data["preferredLanguage"], str):
396
+ allowed["preferredLanguage"] = data["preferredLanguage"].strip()
397
+
398
+ if not allowed:
399
+ return jsonify({"error": "No valid fields to update"}), 400
400
+
401
+ update_user(uid, allowed)
402
+ user_data = get_user(uid) or {}
403
+ return jsonify({"success": True, "uid": uid, **user_data}), 200
404
+
405
+ # ============================================================
406
+ # 9) CREDITS + ADMIN ENDPOINTS (same as template)
407
+ # ============================================================
408
+
409
+ @app.route("/api/user/request-credits", methods=["POST"])
410
+ def request_credits():
411
+ uid = verify_token(request.headers.get("Authorization"))
412
+ if not uid:
413
+ return jsonify({"error": "Unauthorized"}), 401
414
+
415
+ try:
416
+ data = request.get_json() or {}
417
+ if "requested_credits" not in data:
418
+ return jsonify({"error": "requested_credits is required"}), 400
419
+
420
+ req_ref = db_ref.child("credit_requests").push()
421
+ req_ref.set({
422
+ "requestId": req_ref.key,
423
+ "userId": uid,
424
+ "requested_credits": int(data["requested_credits"]),
425
+ "status": "pending",
426
+ "requestedAt": now_iso()
427
+ })
428
+ return jsonify({"success": True, "requestId": req_ref.key}), 200
429
+
430
+ except Exception as e:
431
+ return jsonify({"error": str(e)}), 500
432
+
433
+ @app.route("/api/admin/credit_requests", methods=["GET"])
434
+ def list_credit_requests():
435
+ try:
436
+ verify_admin(request.headers.get("Authorization"))
437
+ requests_data = db_ref.child("credit_requests").get() or {}
438
+ return jsonify(list(requests_data.values())), 200
439
+ except PermissionError as e:
440
+ return jsonify({"error": str(e)}), 403
441
+ except Exception as e:
442
+ return jsonify({"error": str(e)}), 500
443
+
444
+ @app.route("/api/admin/credit_requests/<string:request_id>", methods=["PUT"])
445
+ def process_credit_request(request_id):
446
+ try:
447
+ admin_uid = verify_admin(request.headers.get("Authorization"))
448
+ req_ref = db_ref.child(f"credit_requests/{request_id}")
449
+ req_data = req_ref.get()
450
+
451
+ if not req_data:
452
+ return jsonify({"error": "Credit request not found"}), 404
453
+
454
+ decision = (request.get_json() or {}).get("decision")
455
+ if decision not in ["approved", "declined"]:
456
+ return jsonify({"error": 'Decision must be "approved" or "declined"'}), 400
457
+
458
+ if decision == "approved":
459
+ user_ref = db_ref.child(f"users/{req_data['userId']}")
460
+ user_data = user_ref.get() or {}
461
+ new_total = int(user_data.get("credits", 0)) + int(req_data.get("requested_credits", 0))
462
+ user_ref.update({"credits": new_total})
463
+
464
+ db_ref.child(f"credit_ledger/{req_data['userId']}").push().set({
465
+ "ts": now_iso(),
466
+ "delta": int(req_data.get("requested_credits", 0)),
467
+ "reason": "admin_credit_approval",
468
+ "balance": new_total,
469
+ "processedBy": admin_uid
470
+ })
471
+
472
+ req_ref.update({
473
+ "status": decision,
474
+ "processedBy": admin_uid,
475
+ "processedAt": now_iso()
476
+ })
477
+
478
+ return jsonify({"success": True, "message": f"Request {decision}."}), 200
479
+
480
+ except PermissionError as e:
481
+ return jsonify({"error": str(e)}), 403
482
+ except Exception as e:
483
+ return jsonify({"error": str(e)}), 500
484
+
485
+ @app.route("/api/admin/users/<string:uid>/credits", methods=["PUT"])
486
+ def admin_update_credits(uid):
487
+ try:
488
+ verify_admin(request.headers.get("Authorization"))
489
+ add_credits = (request.get_json() or {}).get("add_credits")
490
+ if add_credits is None:
491
+ return jsonify({"error": "add_credits is required"}), 400
492
+
493
+ user_ref = db_ref.child(f"users/{uid}")
494
+ user_data = user_ref.get()
495
+ if not user_data:
496
+ return jsonify({"error": "User not found"}), 404
497
+
498
+ new_total = int(user_data.get("credits", 0)) + int(add_credits)
499
+ user_ref.update({"credits": new_total})
500
+
501
+ db_ref.child(f"credit_ledger/{uid}").push().set({
502
+ "ts": now_iso(),
503
+ "delta": int(add_credits),
504
+ "reason": "admin_manual_adjust",
505
+ "balance": new_total
506
+ })
507
+
508
+ return jsonify({"success": True, "new_total_credits": new_total}), 200
509
+
510
+ except PermissionError as e:
511
+ return jsonify({"error": str(e)}), 403
512
+ except Exception as e:
513
+ return jsonify({"error": str(e)}), 500
514
+
515
+ # ============================================================
516
+ # 10) SESSION ENDPOINTS (your new app core)
517
  # ============================================================
518
 
519
  @app.route("/api/session/start", methods=["POST"])
520
  def start_session():
521
  try:
522
  uid = require_user()
523
+ data = request.get_json() or {}
524
 
525
+ language = data.get("language")
526
+ scenario_id = data.get("scenarioId")
527
+ if not language or not scenario_id:
528
+ return jsonify({"error": "language and scenarioId are required"}), 400
529
+
530
+ if language not in LANGUAGE_PACKS:
531
+ return jsonify({"error": f"Unsupported language: {language}"}), 400
532
 
533
  pack = LANGUAGE_PACKS[language]
534
+ scenario = pack["scenarios"].get(scenario_id)
535
+ if not scenario:
536
+ return jsonify({"error": "Invalid scenarioId"}), 400
537
+
538
  title = scenario["title"]
539
 
540
+ credit_info = charge(uid, START_SESSION_COST, reason="start_session")
541
 
542
  session = create_session(uid, language, scenario_id, title)
543
 
544
+ # Client passes these to ElevenLabs agent init as dynamic vars
545
+ dynamic_vars = {
546
+ "title": title,
547
+ "language": pack["language"],
548
+ "scenarioName": scenario["name"],
549
+ "scenarioId": scenario_id
550
+ }
551
+
552
  return jsonify({
553
  "session": session,
554
+ "dynamicVariables": dynamic_vars,
555
+ "credits": credit_info
556
+ }), 200
 
 
 
557
 
558
+ except PermissionError as e:
559
+ return jsonify({"error": str(e)}), 401
560
  except Exception as e:
561
+ logger.error(f"start_session failed: {e}")
562
+ logger.error(traceback.format_exc())
563
  return jsonify({"error": str(e)}), 400
564
 
565
  @app.route("/api/session/turn", methods=["POST"])
566
  def submit_turn():
567
  try:
568
  uid = require_user()
569
+ data = request.get_json() or {}
570
 
571
+ session_id = data.get("sessionId")
572
+ transcript = data.get("transcript")
573
+ if not session_id or not transcript:
574
+ return jsonify({"error": "sessionId and transcript are required"}), 400
575
 
576
  session = get_session(uid, session_id)
577
+ if not session:
578
+ return jsonify({"error": "Session not found"}), 404
579
+ if session.get("status") != "active":
580
+ return jsonify({"error": "Session is not active"}), 400
581
+
582
  pack = LANGUAGE_PACKS[session["language"]]
583
  scenario = pack["scenarios"][session["scenarioId"]]
584
+ title = session.get("title", scenario.get("title"))
585
 
586
+ result = evaluate_turn(pack, scenario, transcript, title)
587
+
588
+ meters = session.get("meters", {"respect": 50, "influence": 50, "trust": 50})
589
+ deltas = result.get("meter_delta", {"respect": 0, "influence": 0, "trust": 0})
590
 
 
591
  for k in meters:
592
+ meters[k] = max(0, min(100, int(meters[k]) + int(deltas.get(k, 0))))
593
 
594
+ turns = session.get("turns", [])
595
+ turns.append({
596
+ "id": str(uuid.uuid4()),
597
+ "at": now_iso(),
598
  "text": transcript,
599
+ "feedback": result.get("feedback", "")
600
  })
601
 
602
  update_session(uid, session_id, {
603
  "meters": meters,
604
+ "turns": turns
605
  })
606
 
607
  return jsonify({
608
  "meters": meters,
609
+ "feedback": result.get("feedback", ""),
610
+ "checkpointRequired": bool(result.get("checkpoint_required", False))
611
+ }), 200
612
 
613
+ except PermissionError as e:
614
+ return jsonify({"error": str(e)}), 401
615
  except Exception as e:
616
+ logger.error(f"submit_turn failed: {e}")
617
+ logger.error(traceback.format_exc())
618
  return jsonify({"error": str(e)}), 400
619
 
620
  @app.route("/api/session/end", methods=["POST"])
621
  def end_session():
622
  try:
623
  uid = require_user()
624
+ data = request.get_json() or {}
625
+
626
+ session_id = data.get("sessionId")
627
+ duration = data.get("durationSeconds")
628
+
629
+ if not session_id or not isinstance(duration, (int, float)):
630
+ return jsonify({"error": "sessionId and durationSeconds are required"}), 400
631
+
632
+ session = get_session(uid, session_id)
633
+ if not session:
634
+ return jsonify({"error": "Session not found"}), 404
635
 
636
+ cost = math.ceil(float(duration) / 60.0) * PER_MINUTE_COST
637
+ credit_info = charge(uid, cost, reason="session_minutes")
638
 
639
+ update_session(uid, session_id, {
640
+ "status": "completed",
641
+ "endedAt": now_iso(),
642
+ "durationSeconds": duration,
643
+ "minuteCost": cost
644
+ })
645
 
646
+ return jsonify({
647
+ "status": "completed",
648
+ "cost": cost,
649
+ "credits": credit_info
650
+ }), 200
651
+
652
+ except PermissionError as e:
653
+ return jsonify({"error": str(e)}), 401
654
  except Exception as e:
655
+ logger.error(f"end_session failed: {e}")
656
+ logger.error(traceback.format_exc())
657
  return jsonify({"error": str(e)}), 400
658
 
659
+ @app.route("/api/sessions", methods=["GET"])
660
+ def api_list_sessions():
661
+ try:
662
+ uid = require_user()
663
+ return jsonify(list_sessions(uid)), 200
664
+ except PermissionError as e:
665
+ return jsonify({"error": str(e)}), 401
666
+ except Exception as e:
667
+ return jsonify({"error": str(e)}), 500
668
+
669
+ @app.route("/api/sessions/<string:session_id>", methods=["GET"])
670
+ def api_get_session(session_id):
671
+ try:
672
+ uid = require_user()
673
+ s = get_session(uid, session_id)
674
+ if not s:
675
+ return jsonify({"error": "Session not found"}), 404
676
+ return jsonify(s), 200
677
+ except PermissionError as e:
678
+ return jsonify({"error": str(e)}), 401
679
+ except Exception as e:
680
+ return jsonify({"error": str(e)}), 500
681
+
682
  # ============================================================
683
+ # 11) WEBSOCKET — PRONUNCIATION (practice + live turn)
684
  # ============================================================
685
 
686
+ def _update_struggle_words(session, word_details):
687
+ """
688
+ Rolling average per word.
689
+ session["struggleWords"] shape:
690
+ { "word": {"avg": float, "count": int} }
691
+ """
692
+ struggle = session.get("struggleWords", {}) or {}
693
+ for wd in word_details:
694
+ w = (wd.get("word") or "").strip()
695
+ if not w:
696
+ continue
697
+ score = float(wd.get("score") or 0)
698
+ entry = struggle.get(w, {"avg": 0.0, "count": 0})
699
+ n = int(entry.get("count", 0))
700
+ avg = float(entry.get("avg", 0.0))
701
+ new_avg = (avg * n + score) / (n + 1)
702
+ struggle[w] = {"avg": new_avg, "count": n + 1}
703
+ return struggle
704
+
705
  @socketio.on("practice_pronunciation")
706
  def practice_pronunciation(data):
707
+ """
708
+ Practice loop: reference text must be supplied.
709
+ Optional:
710
+ - authToken: Firebase ID token (so we can charge credits)
711
+ - chargeCredits: true/false (default false)
712
+ """
713
+ raw_path = None
714
+ clean_path = None
715
+ try:
716
+ ref_text = data.get("text")
717
+ lang = data.get("lang", "en-US")
718
+ audio = data.get("audio")
719
+
720
+ if not ref_text or not audio:
721
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing text/audio"})
722
+ return
723
+
724
+ # Optional credit charge for practice
725
+ if data.get("chargeCredits", False):
726
+ auth_token = data.get("authToken")
727
+ if not auth_token:
728
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing authToken"})
729
+ return
730
+ uid = verify_token(f"Bearer {auth_token}")
731
+ if not uid:
732
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Invalid authToken"})
733
+ return
734
+ try:
735
+ credit_info = charge(uid, PRACTICE_ATTEMPT_COST, reason="practice_attempt")
736
+ except Exception:
737
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Insufficient credits"})
738
+ return
739
+ else:
740
+ credit_info = None
741
+
742
+ # Decode base64 audio
743
+ audio_b64 = audio.split(",")[1] if "," in audio else audio
744
+ audio_bytes = base64.b64decode(audio_b64)
745
+
746
+ with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
747
+ f.write(audio_bytes)
748
+ raw_path = f.name
749
+
750
+ clean_path = sanitize_audio(raw_path)
751
+
752
+ result = _azure_pronunciation_assess(ref_text, lang, clean_path)
753
+ if credit_info:
754
+ result["credits"] = credit_info
755
+
756
+ emit("pronunciation_result", result)
757
+
758
+ except Exception as e:
759
+ logger.error(f"practice_pronunciation failed: {e}")
760
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Server Error"})
761
+ finally:
762
+ try:
763
+ if raw_path and os.path.exists(raw_path):
764
+ os.remove(raw_path)
765
+ if clean_path and os.path.exists(clean_path):
766
+ os.remove(clean_path)
767
+ except Exception:
768
+ pass
769
+
770
+ @socketio.on("live_pronunciation_turn")
771
+ def live_pronunciation_turn(data):
772
+ """
773
+ Live scoring for a session turn.
774
+ Requires:
775
+ - authToken (Firebase ID token)
776
+ - sessionId
777
+ - text (reference phrase OR checkpoint line)
778
+ - lang
779
+ - audio
780
+ Returns word_details + updated struggle words top list.
781
+ """
782
+ raw_path = None
783
+ clean_path = None
784
  try:
785
+ auth_token = data.get("authToken")
786
+ session_id = data.get("sessionId")
787
+ ref_text = data.get("text")
788
  lang = data.get("lang", "en-US")
789
+ audio = data.get("audio")
790
 
791
+ if not auth_token or not session_id or not ref_text or not audio:
792
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Missing fields"})
793
+ return
794
+
795
+ uid = verify_token(f"Bearer {auth_token}")
796
+ if not uid:
797
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Invalid authToken"})
798
+ return
799
+
800
+ session = get_session(uid, session_id)
801
+ if not session:
802
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Session not found"})
803
+ return
804
+
805
+ audio_b64 = audio.split(",")[1] if "," in audio else audio
806
  audio_bytes = base64.b64decode(audio_b64)
807
 
808
  with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
 
811
 
812
  clean_path = sanitize_audio(raw_path)
813
 
814
+ result = _azure_pronunciation_assess(ref_text, lang, clean_path)
815
+
816
+ # update rolling struggle words
817
+ struggle = _update_struggle_words(session, result.get("word_details", []))
818
+ update_session(uid, session_id, {"struggleWords": struggle})
819
+
820
+ # top 8 worst avg
821
+ top = sorted(
822
+ [{"word": w, "avg": v["avg"], "count": v["count"]} for w, v in struggle.items()],
823
+ key=lambda x: x["avg"]
824
+ )[:8]
825
+
826
+ result["struggle_top"] = top
827
+ emit("pronunciation_result", result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828
 
829
  except Exception as e:
830
+ logger.error(f"live_pronunciation_turn failed: {e}")
831
+ emit("pronunciation_result", {"success": False, "score": 0, "recognized_text": "Server Error"})
832
+ finally:
833
+ try:
834
+ if raw_path and os.path.exists(raw_path):
835
+ os.remove(raw_path)
836
+ if clean_path and os.path.exists(clean_path):
837
+ os.remove(clean_path)
838
+ except Exception:
839
+ pass
840
 
841
  # ============================================================
842
+ # 12) MAIN
843
  # ============================================================
844
 
845
  if __name__ == "__main__":