SarahXia0405 committed on
Commit
1cf51d5
·
verified ·
1 Parent(s): c2aa07b

Update api/clare_core.py

Browse files
Files changed (1) hide show
  1. api/clare_core.py +29 -274
api/clare_core.py CHANGED
@@ -2,8 +2,6 @@
2
  import os
3
  import re
4
  import math
5
- import time
6
- import json
7
  from typing import List, Dict, Tuple, Optional
8
 
9
  from docx import Document
@@ -21,91 +19,11 @@ from langsmith import traceable
21
  from langsmith.run_helpers import set_run_metadata
22
 
23
 
24
- # ============================
25
- # Token helpers (optional tiktoken)
26
- # ============================
27
- def _safe_import_tiktoken():
28
- try:
29
- import tiktoken # type: ignore
30
- return tiktoken
31
- except Exception:
32
- return None
33
-
34
-
35
- def _approx_tokens(text: str) -> int:
36
- if not text:
37
- return 0
38
- return max(1, int(len(text) / 4))
39
-
40
-
41
- def _count_text_tokens(text: str, model: str = "") -> int:
42
- tk = _safe_import_tiktoken()
43
- if tk is None:
44
- return _approx_tokens(text)
45
-
46
- try:
47
- enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
48
- except Exception:
49
- enc = tk.get_encoding("cl100k_base")
50
-
51
- return len(enc.encode(text or ""))
52
-
53
-
54
- def _count_messages_tokens(messages: List[Dict[str, str]], model: str = "") -> int:
55
- """
56
- Approximation for chat messages overhead.
57
- """
58
- total = 0
59
- for m in messages or []:
60
- total += 4 # role/content wrappers
61
- total += _count_text_tokens(str(m.get("role", "")), model=model)
62
- total += _count_text_tokens(str(m.get("content", "")), model=model)
63
- total += 2
64
- return total
65
-
66
-
67
- def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
68
- if not text:
69
- return text
70
-
71
- tk = _safe_import_tiktoken()
72
- if tk is None:
73
- total = _approx_tokens(text)
74
- if total <= max_tokens:
75
- return text
76
- ratio = max_tokens / max(1, total)
77
- cut = max(50, min(len(text), int(len(text) * ratio)))
78
- s = text[:cut]
79
- while _approx_tokens(s) > max_tokens and len(s) > 50:
80
- s = s[: int(len(s) * 0.9)]
81
- return s
82
-
83
- try:
84
- enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
85
- except Exception:
86
- enc = tk.get_encoding("cl100k_base")
87
-
88
- ids = enc.encode(text or "")
89
- if len(ids) <= max_tokens:
90
- return text
91
- return enc.decode(ids[:max_tokens])
92
-
93
-
94
- def model_name_or_default(x: str) -> str:
95
- return (x or "").strip() or DEFAULT_MODEL
96
-
97
-
98
  # ----------------------------
99
- # Speed knobs (HARD LIMITS)
100
  # ----------------------------
101
- # 1) history 最近 10 轮
102
  MAX_HISTORY_TURNS = int(os.getenv("CLARE_MAX_HISTORY_TURNS", "10"))
103
-
104
- # 2) rag 最多 4 条每条 500 tokens 已在 rag_engine.py 实现
105
- # 这里仅控制“注入到 LLM prompt 的总 tokens”,避免 prompt 爆炸
106
- MAX_RAG_TOKENS_IN_PROMPT = int(os.getenv("CLARE_MAX_RAG_TOKENS", "2000"))
107
-
108
- # 3) max_new_tokens 默认 384
109
  DEFAULT_MAX_OUTPUT_TOKENS = int(os.getenv("CLARE_MAX_OUTPUT_TOKENS", "384"))
110
 
111
 
@@ -234,9 +152,7 @@ def build_session_memory_summary(
234
  parts.append("Cognitive state: " + describe_cognitive_state(cognitive_state))
235
 
236
  if not parts:
237
- return (
238
- "No prior session memory. Start with a short explanation and ask a quick check-up question."
239
- )
240
 
241
  return " | ".join(parts)
242
 
@@ -250,11 +166,7 @@ def detect_language(message: str, preference: str) -> str:
250
  return "English"
251
 
252
 
253
- def build_error_message(
254
- e: Exception,
255
- lang: str,
256
- op: str = "chat",
257
- ) -> str:
258
  if lang == "中文":
259
  prefix = {
260
  "chat": "抱歉,刚刚在和模型对话时出现了一点问题。",
@@ -398,142 +310,39 @@ def find_similar_past_question(
398
  return None
399
 
400
 
401
- def _log_prompt_token_breakdown(
402
- messages: List[Dict[str, str]],
403
- system_prompt: str,
404
- rag_context: str,
405
- trimmed_history: List[Tuple[str, str]],
406
- user_message: str,
407
- model_name: str,
408
- ):
409
- stats = {
410
- "system_tokens": _count_text_tokens(system_prompt, model=model_name),
411
- "rag_tokens": _count_text_tokens(rag_context or "", model=model_name),
412
- "history_tokens": sum(
413
- _count_text_tokens(u or "", model=model_name)
414
- + _count_text_tokens(a or "", model=model_name)
415
- for u, a in (trimmed_history or [])
416
- ),
417
- "user_tokens": _count_text_tokens(user_message or "", model=model_name),
418
- "prompt_tokens_total_est": _count_messages_tokens(messages, model=model_name),
419
- "history_turns_kept": len(trimmed_history or []),
420
- "max_rag_tokens_in_prompt": MAX_RAG_TOKENS_IN_PROMPT,
421
- "max_output_tokens": DEFAULT_MAX_OUTPUT_TOKENS,
422
- "model": model_name,
423
- }
424
- print("[LLM_PROMPT_TOKENS] " + json.dumps(stats, ensure_ascii=False))
425
- return stats
426
-
427
-
428
- @traceable(run_type="llm", name="safe_chat_completion_profiled")
429
- def safe_chat_completion_profiled(
430
  model_name: str,
431
  messages: List[Dict[str, str]],
432
  lang: str,
433
  op: str = "chat",
434
  temperature: float = 0.5,
435
  max_tokens: Optional[int] = None,
436
- timeout: int = 20,
437
- ) -> Tuple[str, Dict]:
438
- """
439
- Streaming-based call to measure TTFT and tokens/sec (estimated).
440
- Returns: (text, prof)
441
- prof includes:
442
- model, llm_total_ms, ttft_ms, gen_ms, output_tokens_est, tokens_per_sec_est, streaming_used, max_tokens
443
- """
444
- t0 = time.perf_counter()
445
-
446
  preferred_model = model_name_or_default(model_name)
447
- max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)
448
-
449
- used_model = preferred_model
450
  last_error: Optional[Exception] = None
 
451
 
452
  for attempt in range(2):
453
- used_model = preferred_model if attempt == 0 else DEFAULT_MODEL
454
  try:
455
- first_token_ms: Optional[float] = None
456
- text_parts: List[str] = []
457
- output_chars = 0
458
-
459
- stream = client.chat.completions.create(
460
- model=used_model,
461
  messages=messages,
462
  temperature=temperature,
463
  max_tokens=max_tokens,
464
- stream=True,
465
- timeout=timeout,
466
  )
467
-
468
- for chunk in stream:
469
- if first_token_ms is None:
470
- first_token_ms = (time.perf_counter() - t0) * 1000.0
471
-
472
- delta = None
473
- try:
474
- delta = chunk.choices[0].delta
475
- except Exception:
476
- delta = None
477
-
478
- piece = ""
479
- if delta is not None:
480
- piece = getattr(delta, "content", "") or ""
481
- else:
482
- try:
483
- piece = chunk.choices[0].message.content or ""
484
- except Exception:
485
- piece = ""
486
-
487
- if piece:
488
- text_parts.append(piece)
489
- output_chars += len(piece)
490
-
491
- full_text = "".join(text_parts)
492
- llm_total_ms = (time.perf_counter() - t0) * 1000.0
493
- ttft_ms = float(first_token_ms or llm_total_ms)
494
- gen_ms = max(0.0, llm_total_ms - ttft_ms)
495
-
496
- # output tokens est (rough)
497
- if re.search(r"[\u4e00-\u9fff]", full_text or ""):
498
- output_tokens_est = int(output_chars / 2.0) if output_chars else 0
499
- else:
500
- output_tokens_est = int(output_chars / 4.0) if output_chars else 0
501
-
502
- tokens_per_sec_est = (
503
- (output_tokens_est / (gen_ms / 1000.0)) if gen_ms > 1 else None
504
- )
505
-
506
- prof = {
507
- "model": used_model,
508
- "streaming_used": True,
509
- "max_tokens": max_tokens,
510
- "output_tokens_est": output_tokens_est,
511
- "tokens_per_sec_est": tokens_per_sec_est,
512
- "ttft_ms": ttft_ms,
513
- "gen_ms": gen_ms,
514
- "llm_total_ms": llm_total_ms,
515
- }
516
- return full_text, prof
517
-
518
  except Exception as e:
519
- last_error = e
520
  print(
521
- f"[safe_chat_completion_profiled][{op}] attempt {attempt+1} failed: {repr(e)}"
522
  )
523
- if attempt == 1:
 
524
  break
525
 
526
- return build_error_message(last_error or Exception("unknown"), lang, op), {
527
- "model": used_model,
528
- "streaming_used": True,
529
- "max_tokens": max_tokens,
530
- "output_tokens_est": 0,
531
- "tokens_per_sec_est": None,
532
- "ttft_ms": None,
533
- "gen_ms": None,
534
- "llm_total_ms": (time.perf_counter() - t0) * 1000.0,
535
- "error": repr(last_error) if last_error else "unknown",
536
- }
537
 
538
 
539
  def build_messages(
@@ -546,13 +355,8 @@ def build_messages(
546
  weaknesses: Optional[List[str]],
547
  cognitive_state: Optional[Dict[str, int]],
548
  rag_context: Optional[str] = None,
549
- model_name: str = "",
550
  ) -> List[Dict[str, str]]:
551
- model_for_count = model_name_or_default(model_name)
552
-
553
- messages: List[Dict[str, str]] = [
554
- {"role": "system", "content": CLARE_SYSTEM_PROMPT}
555
- ]
556
 
557
  if learning_mode in LEARNING_MODE_INSTRUCTIONS:
558
  mode_instruction = LEARNING_MODE_INSTRUCTIONS[learning_mode]
@@ -579,9 +383,7 @@ def build_messages(
579
  messages.append(
580
  {
581
  "role": "system",
582
- "content": (
583
- f"The student also uploaded a {doc_type} document as supporting material."
584
- ),
585
  }
586
  )
587
 
@@ -624,22 +426,15 @@ def build_messages(
624
  )
625
  messages.append({"role": "system", "content": "Session memory: " + session_memory_text})
626
 
627
- # RAG context: enforce token cap here
628
- rag_text_for_prompt = ""
629
  if rag_context:
630
- rag_text_for_prompt = _truncate_to_tokens(
631
- rag_context,
632
- max_tokens=MAX_RAG_TOKENS_IN_PROMPT,
633
- model=model_for_count,
634
- )
635
  messages.append(
636
  {
637
  "role": "system",
638
- "content": "Relevant excerpts (use as primary grounding):\n\n" + rag_text_for_prompt,
639
  }
640
  )
641
 
642
- # Only keep the last N turns for speed (HARD LIMIT)
643
  trimmed_history = history[-MAX_HISTORY_TURNS:] if history else []
644
  for user, assistant in trimmed_history:
645
  messages.append({"role": "user", "content": user})
@@ -647,18 +442,11 @@ def build_messages(
647
  messages.append({"role": "assistant", "content": assistant})
648
 
649
  messages.append({"role": "user", "content": user_message})
 
650
 
651
- # prompt token breakdown log
652
- _log_prompt_token_breakdown(
653
- messages=messages,
654
- system_prompt=CLARE_SYSTEM_PROMPT,
655
- rag_context=rag_text_for_prompt,
656
- trimmed_history=trimmed_history,
657
- user_message=user_message,
658
- model_name=model_for_count,
659
- )
660
 
661
- return messages
 
662
 
663
 
664
  @traceable(run_type="chain", name="chat_with_clare")
@@ -673,13 +461,7 @@ def chat_with_clare(
673
  weaknesses: Optional[List[str]],
674
  cognitive_state: Optional[Dict[str, int]],
675
  rag_context: Optional[str] = None,
676
- ) -> Tuple[str, List[Tuple[str, str]], Dict]:
677
- """
678
- Returns:
679
- answer: str
680
- history: List[(user, assistant)]
681
- llm_stats: Dict (TTFT + tokens/sec est + prompt token breakdown printed in logs)
682
- """
683
  try:
684
  set_run_metadata(
685
  learning_mode=learning_mode,
@@ -699,42 +481,19 @@ def chat_with_clare(
699
  weaknesses=weaknesses,
700
  cognitive_state=cognitive_state,
701
  rag_context=rag_context,
702
- model_name=model_name,
703
  )
704
 
705
- # IMPORTANT: pass messages + lang (fixes your HTTP 500)
706
- answer, prof = safe_chat_completion_profiled(
707
  model_name=model_name,
708
  messages=messages,
709
  lang=language_preference,
710
  op="chat",
711
  temperature=0.5,
712
  max_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
713
- timeout=20,
714
  )
715
 
716
  history = history + [(message, answer)]
717
-
718
- llm_stats = {
719
- "llm_profile": {
720
- "model": prof.get("model"),
721
- "streaming_used": prof.get("streaming_used"),
722
- "max_tokens": prof.get("max_tokens"),
723
- "output_tokens_est": prof.get("output_tokens_est"),
724
- "tokens_per_sec_est": prof.get("tokens_per_sec_est"),
725
- },
726
- "marks_ms": {
727
- "llm_first_token": prof.get("ttft_ms"),
728
- "llm_done": prof.get("llm_total_ms"),
729
- },
730
- "segments_ms": {
731
- "llm_ttft_ms": prof.get("ttft_ms"),
732
- "llm_gen_ms": prof.get("gen_ms"),
733
- "llm_done": prof.get("llm_total_ms"),
734
- },
735
- }
736
-
737
- return answer, history, llm_stats
738
 
739
 
740
  def export_conversation(
@@ -783,10 +542,7 @@ def summarize_conversation(
783
 
784
  messages = [
785
  {"role": "system", "content": CLARE_SYSTEM_PROMPT},
786
- {
787
- "role": "system",
788
- "content": "Produce a concept-only summary. Use bullet points. No off-topic text.",
789
- },
790
  {"role": "system", "content": f"Course topics: {topics_text}"},
791
  {"role": "system", "content": f"Student difficulties: {weakness_text}"},
792
  {"role": "system", "content": f"Cognitive state: {cog_text}"},
@@ -796,13 +552,12 @@ def summarize_conversation(
796
  if language_preference == "中文":
797
  messages.append({"role": "system", "content": "请用中文输出要点总结(bullet points)。"})
798
 
799
- summary_text, _prof = safe_chat_completion_profiled(
800
  model_name=model_name,
801
  messages=messages,
802
  lang=language_preference,
803
  op="summary",
804
  temperature=0.4,
805
  max_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
806
- timeout=20,
807
  )
808
  return summary_text
 
2
  import os
3
  import re
4
  import math
 
 
5
  from typing import List, Dict, Tuple, Optional
6
 
7
  from docx import Document
 
19
  from langsmith.run_helpers import set_run_metadata
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # ----------------------------
23
+ # Speed knobs (simple + stable)
24
  # ----------------------------
 
25
  MAX_HISTORY_TURNS = int(os.getenv("CLARE_MAX_HISTORY_TURNS", "10"))
26
+ MAX_RAG_CHARS_IN_PROMPT = int(os.getenv("CLARE_MAX_RAG_CHARS", "2000"))
 
 
 
 
 
27
  DEFAULT_MAX_OUTPUT_TOKENS = int(os.getenv("CLARE_MAX_OUTPUT_TOKENS", "384"))
28
 
29
 
 
152
  parts.append("Cognitive state: " + describe_cognitive_state(cognitive_state))
153
 
154
  if not parts:
155
+ return "No prior session memory. Start with a short explanation and ask a quick check-up question."
 
 
156
 
157
  return " | ".join(parts)
158
 
 
166
  return "English"
167
 
168
 
169
+ def build_error_message(e: Exception, lang: str, op: str = "chat") -> str:
 
 
 
 
170
  if lang == "中文":
171
  prefix = {
172
  "chat": "抱歉,刚刚在和模型对话时出现了一点问题。",
 
310
  return None
311
 
312
 
313
+ @traceable(run_type="llm", name="safe_chat_completion")
314
+ def safe_chat_completion(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  model_name: str,
316
  messages: List[Dict[str, str]],
317
  lang: str,
318
  op: str = "chat",
319
  temperature: float = 0.5,
320
  max_tokens: Optional[int] = None,
321
+ ) -> str:
 
 
 
 
 
 
 
 
 
322
  preferred_model = model_name_or_default(model_name)
 
 
 
323
  last_error: Optional[Exception] = None
324
+ max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)
325
 
326
  for attempt in range(2):
327
+ current_model = preferred_model if attempt == 0 else DEFAULT_MODEL
328
  try:
329
+ resp = client.chat.completions.create(
330
+ model=current_model,
 
 
 
 
331
  messages=messages,
332
  temperature=temperature,
333
  max_tokens=max_tokens,
334
+ timeout=20,
 
335
  )
336
+ return resp.choices[0].message.content or ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  except Exception as e:
 
338
  print(
339
+ f"[safe_chat_completion][{op}] attempt {attempt+1} failed with model={current_model}: {repr(e)}"
340
  )
341
+ last_error = e
342
+ if current_model == DEFAULT_MODEL or attempt == 1:
343
  break
344
 
345
+ return build_error_message(last_error or Exception("unknown error"), lang, op)
 
 
 
 
 
 
 
 
 
 
346
 
347
 
348
  def build_messages(
 
355
  weaknesses: Optional[List[str]],
356
  cognitive_state: Optional[Dict[str, int]],
357
  rag_context: Optional[str] = None,
 
358
  ) -> List[Dict[str, str]]:
359
+ messages: List[Dict[str, str]] = [{"role": "system", "content": CLARE_SYSTEM_PROMPT}]
 
 
 
 
360
 
361
  if learning_mode in LEARNING_MODE_INSTRUCTIONS:
362
  mode_instruction = LEARNING_MODE_INSTRUCTIONS[learning_mode]
 
383
  messages.append(
384
  {
385
  "role": "system",
386
+ "content": f"The student also uploaded a {doc_type} document as supporting material.",
 
 
387
  }
388
  )
389
 
 
426
  )
427
  messages.append({"role": "system", "content": "Session memory: " + session_memory_text})
428
 
 
 
429
  if rag_context:
430
+ rc = (rag_context or "")[:MAX_RAG_CHARS_IN_PROMPT]
 
 
 
 
431
  messages.append(
432
  {
433
  "role": "system",
434
+ "content": "Relevant excerpts (use as primary grounding):\n\n" + rc,
435
  }
436
  )
437
 
 
438
  trimmed_history = history[-MAX_HISTORY_TURNS:] if history else []
439
  for user, assistant in trimmed_history:
440
  messages.append({"role": "user", "content": user})
 
442
  messages.append({"role": "assistant", "content": assistant})
443
 
444
  messages.append({"role": "user", "content": user_message})
445
+ return messages
446
 
 
 
 
 
 
 
 
 
 
447
 
448
+ def model_name_or_default(x: str) -> str:
449
+ return (x or "").strip() or DEFAULT_MODEL
450
 
451
 
452
  @traceable(run_type="chain", name="chat_with_clare")
 
461
  weaknesses: Optional[List[str]],
462
  cognitive_state: Optional[Dict[str, int]],
463
  rag_context: Optional[str] = None,
464
+ ) -> Tuple[str, List[Tuple[str, str]]]:
 
 
 
 
 
 
465
  try:
466
  set_run_metadata(
467
  learning_mode=learning_mode,
 
481
  weaknesses=weaknesses,
482
  cognitive_state=cognitive_state,
483
  rag_context=rag_context,
 
484
  )
485
 
486
+ answer = safe_chat_completion(
 
487
  model_name=model_name,
488
  messages=messages,
489
  lang=language_preference,
490
  op="chat",
491
  temperature=0.5,
492
  max_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
 
493
  )
494
 
495
  history = history + [(message, answer)]
496
+ return answer, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
 
499
  def export_conversation(
 
542
 
543
  messages = [
544
  {"role": "system", "content": CLARE_SYSTEM_PROMPT},
545
+ {"role": "system", "content": "Produce a concept-only summary. Use bullet points. No off-topic text."},
 
 
 
546
  {"role": "system", "content": f"Course topics: {topics_text}"},
547
  {"role": "system", "content": f"Student difficulties: {weakness_text}"},
548
  {"role": "system", "content": f"Cognitive state: {cog_text}"},
 
552
  if language_preference == "中文":
553
  messages.append({"role": "system", "content": "请用中文输出要点总结(bullet points)。"})
554
 
555
+ summary_text = safe_chat_completion(
556
  model_name=model_name,
557
  messages=messages,
558
  lang=language_preference,
559
  op="summary",
560
  temperature=0.4,
561
  max_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
 
562
  )
563
  return summary_text