SarahXia0405 committed on
Commit
73b3683
·
verified ·
1 Parent(s): 043c5ec

Update api/clare_core.py

Browse files
Files changed (1) hide show
  1. api/clare_core.py +132 -133
api/clare_core.py CHANGED
@@ -52,10 +52,12 @@ def _count_text_tokens(text: str, model: str = "") -> int:
52
 
53
 
54
  def _count_messages_tokens(messages: List[Dict[str, str]], model: str = "") -> int:
55
- # engineering approximation for chat messages overhead
 
 
56
  total = 0
57
  for m in messages or []:
58
- total += 4
59
  total += _count_text_tokens(str(m.get("role", "")), model=model)
60
  total += _count_text_tokens(str(m.get("content", "")), model=model)
61
  total += 2
@@ -89,6 +91,10 @@ def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
89
  return enc.decode(ids[:max_tokens])
90
 
91
 
 
 
 
 
92
  # ----------------------------
93
  # Speed knobs (HARD LIMITS)
94
  # ----------------------------
@@ -404,7 +410,8 @@ def _log_prompt_token_breakdown(
404
  "system_tokens": _count_text_tokens(system_prompt, model=model_name),
405
  "rag_tokens": _count_text_tokens(rag_context or "", model=model_name),
406
  "history_tokens": sum(
407
- _count_text_tokens(u or "", model=model_name) + _count_text_tokens(a or "", model=model_name)
 
408
  for u, a in (trimmed_history or [])
409
  ),
410
  "user_tokens": _count_text_tokens(user_message or "", model=model_name),
@@ -412,7 +419,7 @@ def _log_prompt_token_breakdown(
412
  "history_turns_kept": len(trimmed_history or []),
413
  "max_rag_tokens_in_prompt": MAX_RAG_TOKENS_IN_PROMPT,
414
  "max_output_tokens": DEFAULT_MAX_OUTPUT_TOKENS,
415
- "model": model_name or DEFAULT_MODEL,
416
  }
417
  print("[LLM_PROMPT_TOKENS] " + json.dumps(stats, ensure_ascii=False))
418
  return stats
@@ -426,118 +433,105 @@ def safe_chat_completion_profiled(
426
  op: str = "chat",
427
  temperature: float = 0.5,
428
  max_tokens: Optional[int] = None,
 
429
  ) -> Tuple[str, Dict]:
430
  """
431
- Returns:
432
- - answer text
433
- - profiling dict {ttft_ms, llm_total_ms, gen_ms, output_tokens_est, tokens_per_sec_est, streaming_used}
 
434
  """
435
- preferred_model = model_name or DEFAULT_MODEL
 
 
436
  max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)
437
 
 
438
  last_error: Optional[Exception] = None
439
 
440
  for attempt in range(2):
441
- current_model = preferred_model if attempt == 0 else DEFAULT_MODEL
442
-
443
- # 1) Try streaming for real TTFT
444
- t0 = time.perf_counter()
445
  try:
446
- resp = client.chat.completions.create(
447
- model=current_model,
 
 
 
 
448
  messages=messages,
449
  temperature=temperature,
450
  max_tokens=max_tokens,
451
- timeout=20,
452
  stream=True,
 
453
  )
454
 
455
- first_token_t = None
456
- out_parts: List[str] = []
457
- for event in resp:
458
- # OpenAI-style: event.choices[0].delta.content
 
459
  try:
460
- delta = event.choices[0].delta.content # type: ignore
461
  except Exception:
462
  delta = None
463
- if not delta:
464
- continue
465
- if first_token_t is None:
466
- first_token_t = time.perf_counter()
467
- out_parts.append(delta)
468
-
469
- t_end = time.perf_counter()
470
- answer = "".join(out_parts)
471
-
472
- ttft_ms = None if first_token_t is None else (first_token_t - t0) * 1000.0
473
- total_ms = (t_end - t0) * 1000.0
474
- gen_ms = None if first_token_t is None else (t_end - first_token_t) * 1000.0
475
- out_tokens = _count_text_tokens(answer, model=current_model)
476
- tokens_per_sec = None
477
- if gen_ms and gen_ms > 0:
478
- tokens_per_sec = out_tokens / (gen_ms / 1000.0)
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
  prof = {
 
481
  "streaming_used": True,
 
 
 
482
  "ttft_ms": ttft_ms,
483
- "llm_total_ms": total_ms,
484
  "gen_ms": gen_ms,
485
- "output_tokens_est": out_tokens,
486
- "tokens_per_sec_est": tokens_per_sec,
487
- "model": current_model,
488
- "max_tokens": max_tokens,
489
  }
490
- print("[LLM_PROFILING] " + json.dumps(prof, ensure_ascii=False))
491
- return answer, prof
492
 
493
  except Exception as e:
494
  last_error = e
495
- # fall through to non-stream fallback below
496
-
497
- # 2) Non-stream fallback (TTFT not available; approximate)
498
- try:
499
- t0 = time.perf_counter()
500
- resp2 = client.chat.completions.create(
501
- model=current_model,
502
- messages=messages,
503
- temperature=temperature,
504
- max_tokens=max_tokens,
505
- timeout=20,
506
- )
507
- t_end = time.perf_counter()
508
- answer = resp2.choices[0].message.content or ""
509
-
510
- total_ms = (t_end - t0) * 1000.0
511
- out_tokens = _count_text_tokens(answer, model=current_model)
512
- tokens_per_sec = None
513
- if total_ms > 0:
514
- tokens_per_sec = out_tokens / (total_ms / 1000.0)
515
-
516
- prof = {
517
- "streaming_used": False,
518
- "ttft_ms": None, # not measurable without stream
519
- "llm_total_ms": total_ms,
520
- "gen_ms": None,
521
- "output_tokens_est": out_tokens,
522
- "tokens_per_sec_est": tokens_per_sec,
523
- "model": current_model,
524
- "max_tokens": max_tokens,
525
- "note": "non-stream fallback; ttft_ms unavailable",
526
- }
527
- print("[LLM_PROFILING] " + json.dumps(prof, ensure_ascii=False))
528
- return answer, prof
529
-
530
- except Exception as e:
531
  print(
532
- f"[safe_chat_completion_profiled][{op}] attempt {attempt+1} "
533
- f"failed with model={current_model}: {repr(e)}"
534
  )
535
- last_error = e
536
- if current_model == DEFAULT_MODEL or attempt == 1:
537
  break
538
 
539
- return build_error_message(last_error or Exception("unknown error"), lang, op), {
540
- "streaming_used": False,
 
 
 
 
 
 
 
541
  "error": repr(last_error) if last_error else "unknown",
542
  }
543
 
@@ -552,7 +546,10 @@ def build_messages(
552
  weaknesses: Optional[List[str]],
553
  cognitive_state: Optional[Dict[str, int]],
554
  rag_context: Optional[str] = None,
 
555
  ) -> List[Dict[str, str]]:
 
 
556
  messages: List[Dict[str, str]] = [
557
  {"role": "system", "content": CLARE_SYSTEM_PROMPT}
558
  ]
@@ -593,9 +590,7 @@ def build_messages(
593
  messages.append(
594
  {
595
  "role": "system",
596
- "content": (
597
- "Student struggles (recent). Be extra clear on these: " + weak_text
598
- ),
599
  }
600
  )
601
 
@@ -606,18 +601,14 @@ def build_messages(
606
  messages.append(
607
  {
608
  "role": "system",
609
- "content": (
610
- "Student under HIGH cognitive load. Use simpler language and shorter steps."
611
- ),
612
  }
613
  )
614
  elif mastery >= 2 and mastery >= confusion + 1:
615
  messages.append(
616
  {
617
  "role": "system",
618
- "content": (
619
- "Student comfortable. You may go slightly deeper and add a follow-up question."
620
- ),
621
  }
622
  )
623
 
@@ -631,25 +622,20 @@ def build_messages(
631
  weaknesses=weaknesses,
632
  cognitive_state=cognitive_state,
633
  )
634
- messages.append(
635
- {
636
- "role": "system",
637
- "content": "Session memory: " + session_memory_text,
638
- }
639
- )
640
 
641
- # RAG context: enforce token cap here (in addition to rag_engine caps)
642
  rag_text_for_prompt = ""
643
  if rag_context:
644
  rag_text_for_prompt = _truncate_to_tokens(
645
- rag_context, max_tokens=MAX_RAG_TOKENS_IN_PROMPT, model=model_name_or_default(DEFAULT_MODEL)
 
 
646
  )
647
  messages.append(
648
  {
649
  "role": "system",
650
- "content": (
651
- "Relevant excerpts (use as primary grounding):\n\n" + rag_text_for_prompt
652
- ),
653
  }
654
  )
655
 
@@ -669,16 +655,12 @@ def build_messages(
669
  rag_context=rag_text_for_prompt,
670
  trimmed_history=trimmed_history,
671
  user_message=user_message,
672
- model_name=(DEFAULT_MODEL or ""),
673
  )
674
 
675
  return messages
676
 
677
 
678
- def model_name_or_default(x: str) -> str:
679
- return x or DEFAULT_MODEL
680
-
681
-
682
  @traceable(run_type="chain", name="chat_with_clare")
683
  def chat_with_clare(
684
  message: str,
@@ -691,7 +673,13 @@ def chat_with_clare(
691
  weaknesses: Optional[List[str]],
692
  cognitive_state: Optional[Dict[str, int]],
693
  rag_context: Optional[str] = None,
694
- ) -> Tuple[str, List[Tuple[str, str]]]:
 
 
 
 
 
 
695
  try:
696
  set_run_metadata(
697
  learning_mode=learning_mode,
@@ -711,29 +699,41 @@ def chat_with_clare(
711
  weaknesses=weaknesses,
712
  cognitive_state=cognitive_state,
713
  rag_context=rag_context,
 
714
  )
715
- answer, prof = safe_chat_completion_profiled(...)
 
 
 
 
 
 
 
 
 
 
 
716
  history = history + [(message, answer)]
717
-
718
  llm_stats = {
719
- "llm_profile": {
720
- "model": prof.get("model"),
721
- "streaming_used": prof.get("streaming_used"),
722
- "max_tokens": prof.get("max_tokens"),
723
- "output_tokens_est": prof.get("output_tokens_est"),
724
- "tokens_per_sec_est": prof.get("tokens_per_sec_est"),
725
- },
726
- "marks_ms": {
727
- "llm_first_token": prof.get("ttft_ms"),
728
- "llm_done": prof.get("llm_total_ms"),
729
- },
730
- "segments_ms": {
731
- "llm_ttft_ms": prof.get("ttft_ms"),
732
- "llm_gen_ms": prof.get("gen_ms"),
733
- "llm_done": prof.get("llm_total_ms"),
734
- },
735
  }
736
-
737
  return answer, history, llm_stats
738
 
739
 
@@ -785,9 +785,7 @@ def summarize_conversation(
785
  {"role": "system", "content": CLARE_SYSTEM_PROMPT},
786
  {
787
  "role": "system",
788
- "content": (
789
- "Produce a concept-only summary. Use bullet points. No off-topic text."
790
- ),
791
  },
792
  {"role": "system", "content": f"Course topics: {topics_text}"},
793
  {"role": "system", "content": f"Student difficulties: {weakness_text}"},
@@ -805,5 +803,6 @@ def summarize_conversation(
805
  op="summary",
806
  temperature=0.4,
807
  max_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
 
808
  )
809
  return summary_text
 
52
 
53
 
54
  def _count_messages_tokens(messages: List[Dict[str, str]], model: str = "") -> int:
55
+ """
56
+ Approximation for chat messages overhead.
57
+ """
58
  total = 0
59
  for m in messages or []:
60
+ total += 4 # role/content wrappers
61
  total += _count_text_tokens(str(m.get("role", "")), model=model)
62
  total += _count_text_tokens(str(m.get("content", "")), model=model)
63
  total += 2
 
91
  return enc.decode(ids[:max_tokens])
92
 
93
 
94
def model_name_or_default(x: str) -> str:
    """Return *x* trimmed of surrounding whitespace, falling back to DEFAULT_MODEL when empty/blank."""
    candidate = (x or "").strip()
    if candidate:
        return candidate
    return DEFAULT_MODEL
96
+
97
+
98
  # ----------------------------
99
  # Speed knobs (HARD LIMITS)
100
  # ----------------------------
 
410
  "system_tokens": _count_text_tokens(system_prompt, model=model_name),
411
  "rag_tokens": _count_text_tokens(rag_context or "", model=model_name),
412
  "history_tokens": sum(
413
+ _count_text_tokens(u or "", model=model_name)
414
+ + _count_text_tokens(a or "", model=model_name)
415
  for u, a in (trimmed_history or [])
416
  ),
417
  "user_tokens": _count_text_tokens(user_message or "", model=model_name),
 
419
  "history_turns_kept": len(trimmed_history or []),
420
  "max_rag_tokens_in_prompt": MAX_RAG_TOKENS_IN_PROMPT,
421
  "max_output_tokens": DEFAULT_MAX_OUTPUT_TOKENS,
422
+ "model": model_name,
423
  }
424
  print("[LLM_PROMPT_TOKENS] " + json.dumps(stats, ensure_ascii=False))
425
  return stats
 
433
  op: str = "chat",
434
  temperature: float = 0.5,
435
  max_tokens: Optional[int] = None,
436
+ timeout: int = 20,
437
  ) -> Tuple[str, Dict]:
438
  """
439
+ Streaming-based call to measure TTFT and tokens/sec (estimated).
440
+ Returns: (text, prof)
441
+ prof includes:
442
+ model, llm_total_ms, ttft_ms, gen_ms, output_tokens_est, tokens_per_sec_est, streaming_used, max_tokens
443
  """
444
+ t0 = time.perf_counter()
445
+
446
+ preferred_model = model_name_or_default(model_name)
447
  max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)
448
 
449
+ used_model = preferred_model
450
  last_error: Optional[Exception] = None
451
 
452
  for attempt in range(2):
453
+ used_model = preferred_model if attempt == 0 else DEFAULT_MODEL
 
 
 
454
  try:
455
+ first_token_ms: Optional[float] = None
456
+ text_parts: List[str] = []
457
+ output_chars = 0
458
+
459
+ stream = client.chat.completions.create(
460
+ model=used_model,
461
  messages=messages,
462
  temperature=temperature,
463
  max_tokens=max_tokens,
 
464
  stream=True,
465
+ timeout=timeout,
466
  )
467
 
468
+ for chunk in stream:
469
+ if first_token_ms is None:
470
+ first_token_ms = (time.perf_counter() - t0) * 1000.0
471
+
472
+ delta = None
473
  try:
474
+ delta = chunk.choices[0].delta
475
  except Exception:
476
  delta = None
477
+
478
+ piece = ""
479
+ if delta is not None:
480
+ piece = getattr(delta, "content", "") or ""
481
+ else:
482
+ try:
483
+ piece = chunk.choices[0].message.content or ""
484
+ except Exception:
485
+ piece = ""
486
+
487
+ if piece:
488
+ text_parts.append(piece)
489
+ output_chars += len(piece)
490
+
491
+ full_text = "".join(text_parts)
492
+ llm_total_ms = (time.perf_counter() - t0) * 1000.0
493
+ ttft_ms = float(first_token_ms or llm_total_ms)
494
+ gen_ms = max(0.0, llm_total_ms - ttft_ms)
495
+
496
+ # output tokens est (rough)
497
+ if re.search(r"[\u4e00-\u9fff]", full_text or ""):
498
+ output_tokens_est = int(output_chars / 2.0) if output_chars else 0
499
+ else:
500
+ output_tokens_est = int(output_chars / 4.0) if output_chars else 0
501
+
502
+ tokens_per_sec_est = (
503
+ (output_tokens_est / (gen_ms / 1000.0)) if gen_ms > 1 else None
504
+ )
505
 
506
  prof = {
507
+ "model": used_model,
508
  "streaming_used": True,
509
+ "max_tokens": max_tokens,
510
+ "output_tokens_est": output_tokens_est,
511
+ "tokens_per_sec_est": tokens_per_sec_est,
512
  "ttft_ms": ttft_ms,
 
513
  "gen_ms": gen_ms,
514
+ "llm_total_ms": llm_total_ms,
 
 
 
515
  }
516
+ return full_text, prof
 
517
 
518
  except Exception as e:
519
  last_error = e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
  print(
521
+ f"[safe_chat_completion_profiled][{op}] attempt {attempt+1} failed: {repr(e)}"
 
522
  )
523
+ if attempt == 1:
 
524
  break
525
 
526
+ return build_error_message(last_error or Exception("unknown"), lang, op), {
527
+ "model": used_model,
528
+ "streaming_used": True,
529
+ "max_tokens": max_tokens,
530
+ "output_tokens_est": 0,
531
+ "tokens_per_sec_est": None,
532
+ "ttft_ms": None,
533
+ "gen_ms": None,
534
+ "llm_total_ms": (time.perf_counter() - t0) * 1000.0,
535
  "error": repr(last_error) if last_error else "unknown",
536
  }
537
 
 
546
  weaknesses: Optional[List[str]],
547
  cognitive_state: Optional[Dict[str, int]],
548
  rag_context: Optional[str] = None,
549
+ model_name: str = "",
550
  ) -> List[Dict[str, str]]:
551
+ model_for_count = model_name_or_default(model_name)
552
+
553
  messages: List[Dict[str, str]] = [
554
  {"role": "system", "content": CLARE_SYSTEM_PROMPT}
555
  ]
 
590
  messages.append(
591
  {
592
  "role": "system",
593
+ "content": "Student struggles (recent). Be extra clear on these: " + weak_text,
 
 
594
  }
595
  )
596
 
 
601
  messages.append(
602
  {
603
  "role": "system",
604
+ "content": "Student under HIGH cognitive load. Use simpler language and shorter steps.",
 
 
605
  }
606
  )
607
  elif mastery >= 2 and mastery >= confusion + 1:
608
  messages.append(
609
  {
610
  "role": "system",
611
+ "content": "Student comfortable. You may go slightly deeper and add a follow-up question.",
 
 
612
  }
613
  )
614
 
 
622
  weaknesses=weaknesses,
623
  cognitive_state=cognitive_state,
624
  )
625
+ messages.append({"role": "system", "content": "Session memory: " + session_memory_text})
 
 
 
 
 
626
 
627
+ # RAG context: enforce token cap here
628
  rag_text_for_prompt = ""
629
  if rag_context:
630
  rag_text_for_prompt = _truncate_to_tokens(
631
+ rag_context,
632
+ max_tokens=MAX_RAG_TOKENS_IN_PROMPT,
633
+ model=model_for_count,
634
  )
635
  messages.append(
636
  {
637
  "role": "system",
638
+ "content": "Relevant excerpts (use as primary grounding):\n\n" + rag_text_for_prompt,
 
 
639
  }
640
  )
641
 
 
655
  rag_context=rag_text_for_prompt,
656
  trimmed_history=trimmed_history,
657
  user_message=user_message,
658
+ model_name=model_for_count,
659
  )
660
 
661
  return messages
662
 
663
 
 
 
 
 
664
  @traceable(run_type="chain", name="chat_with_clare")
665
  def chat_with_clare(
666
  message: str,
 
673
  weaknesses: Optional[List[str]],
674
  cognitive_state: Optional[Dict[str, int]],
675
  rag_context: Optional[str] = None,
676
+ ) -> Tuple[str, List[Tuple[str, str]], Dict]:
677
+ """
678
+ Returns:
679
+ answer: str
680
+ history: List[(user, assistant)]
681
+ llm_stats: Dict (TTFT + tokens/sec est + prompt token breakdown printed in logs)
682
+ """
683
  try:
684
  set_run_metadata(
685
  learning_mode=learning_mode,
 
699
  weaknesses=weaknesses,
700
  cognitive_state=cognitive_state,
701
  rag_context=rag_context,
702
+ model_name=model_name,
703
  )
704
+
705
+ # IMPORTANT: pass messages + lang (fixes your HTTP 500)
706
+ answer, prof = safe_chat_completion_profiled(
707
+ model_name=model_name,
708
+ messages=messages,
709
+ lang=language_preference,
710
+ op="chat",
711
+ temperature=0.5,
712
+ max_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
713
+ timeout=20,
714
+ )
715
+
716
  history = history + [(message, answer)]
717
+
718
  llm_stats = {
719
+ "llm_profile": {
720
+ "model": prof.get("model"),
721
+ "streaming_used": prof.get("streaming_used"),
722
+ "max_tokens": prof.get("max_tokens"),
723
+ "output_tokens_est": prof.get("output_tokens_est"),
724
+ "tokens_per_sec_est": prof.get("tokens_per_sec_est"),
725
+ },
726
+ "marks_ms": {
727
+ "llm_first_token": prof.get("ttft_ms"),
728
+ "llm_done": prof.get("llm_total_ms"),
729
+ },
730
+ "segments_ms": {
731
+ "llm_ttft_ms": prof.get("ttft_ms"),
732
+ "llm_gen_ms": prof.get("gen_ms"),
733
+ "llm_done": prof.get("llm_total_ms"),
734
+ },
735
  }
736
+
737
  return answer, history, llm_stats
738
 
739
 
 
785
  {"role": "system", "content": CLARE_SYSTEM_PROMPT},
786
  {
787
  "role": "system",
788
+ "content": "Produce a concept-only summary. Use bullet points. No off-topic text.",
 
 
789
  },
790
  {"role": "system", "content": f"Course topics: {topics_text}"},
791
  {"role": "system", "content": f"Student difficulties: {weakness_text}"},
 
803
  op="summary",
804
  temperature=0.4,
805
  max_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
806
+ timeout=20,
807
  )
808
  return summary_text