SarahXia0405 committed on
Commit
34ec5a2
·
verified ·
1 Parent(s): 3268902

Update api/clare_core.py

Browse files
Files changed (1) hide show
  1. api/clare_core.py +226 -16
api/clare_core.py CHANGED
@@ -2,6 +2,8 @@
2
  import os
3
  import re
4
  import math
 
 
5
  from typing import List, Dict, Tuple, Optional
6
 
7
  from docx import Document
@@ -18,12 +20,87 @@ from .config import (
18
  from langsmith import traceable
19
  from langsmith.run_helpers import set_run_metadata
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # ----------------------------
22
- # Speed knobs
23
  # ----------------------------
24
- MAX_HISTORY_TURNS = int(os.getenv("CLARE_MAX_HISTORY_TURNS", "4")) # was 6
25
- MAX_RAG_CHARS_IN_PROMPT = int(os.getenv("CLARE_MAX_RAG_CHARS", "600")) # was 1200
26
- DEFAULT_MAX_OUTPUT_TOKENS = int(os.getenv("CLARE_MAX_OUTPUT_TOKENS", "450"))
 
 
 
 
 
 
27
 
28
 
29
  # ---------- syllabus 解析 ----------
@@ -315,21 +392,56 @@ def find_similar_past_question(
315
  return None
316
 
317
 
318
- @traceable(run_type="llm", name="safe_chat_completion")
319
- def safe_chat_completion(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  model_name: str,
321
  messages: List[Dict[str, str]],
322
  lang: str,
323
  op: str = "chat",
324
  temperature: float = 0.5,
325
  max_tokens: Optional[int] = None,
326
- ) -> str:
 
 
 
 
 
327
  preferred_model = model_name or DEFAULT_MODEL
328
- last_error: Optional[Exception] = None
329
  max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)
330
 
 
 
331
  for attempt in range(2):
332
  current_model = preferred_model if attempt == 0 else DEFAULT_MODEL
 
 
 
333
  try:
334
  resp = client.chat.completions.create(
335
  model=current_model,
@@ -337,18 +449,97 @@ def safe_chat_completion(
337
  temperature=temperature,
338
  max_tokens=max_tokens,
339
  timeout=20,
 
340
  )
341
- return resp.choices[0].message.content or ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  except Exception as e:
343
  print(
344
- f"[safe_chat_completion][{op}] attempt {attempt+1} "
345
  f"failed with model={current_model}: {repr(e)}"
346
  )
347
  last_error = e
348
  if current_model == DEFAULT_MODEL or attempt == 1:
349
  break
350
 
351
- return build_error_message(last_error or Exception("unknown error"), lang, op)
 
 
 
352
 
353
 
354
  def build_messages(
@@ -447,18 +638,22 @@ def build_messages(
447
  }
448
  )
449
 
 
 
450
  if rag_context:
451
- rc = rag_context[:MAX_RAG_CHARS_IN_PROMPT]
 
 
452
  messages.append(
453
  {
454
  "role": "system",
455
  "content": (
456
- "Relevant excerpts (use as primary grounding):\n\n" + rc
457
  ),
458
  }
459
  )
460
 
461
- # Only keep the last N turns for speed
462
  trimmed_history = history[-MAX_HISTORY_TURNS:] if history else []
463
  for user, assistant in trimmed_history:
464
  messages.append({"role": "user", "content": user})
@@ -466,9 +661,24 @@ def build_messages(
466
  messages.append({"role": "assistant", "content": assistant})
467
 
468
  messages.append({"role": "user", "content": user_message})
 
 
 
 
 
 
 
 
 
 
 
469
  return messages
470
 
471
 
 
 
 
 
472
  @traceable(run_type="chain", name="chat_with_clare")
473
  def chat_with_clare(
474
  message: str,
@@ -503,7 +713,7 @@ def chat_with_clare(
503
  rag_context=rag_context,
504
  )
505
 
506
- answer = safe_chat_completion(
507
  model_name=model_name,
508
  messages=messages,
509
  lang=language_preference,
@@ -577,7 +787,7 @@ def summarize_conversation(
577
  if language_preference == "中文":
578
  messages.append({"role": "system", "content": "请用中文输出要点总结(bullet points)。"})
579
 
580
- summary_text = safe_chat_completion(
581
  model_name=model_name,
582
  messages=messages,
583
  lang=language_preference,
 
2
  import os
3
  import re
4
  import math
5
+ import time
6
+ import json
7
  from typing import List, Dict, Tuple, Optional
8
 
9
  from docx import Document
 
20
  from langsmith import traceable
21
  from langsmith.run_helpers import set_run_metadata
22
 
23
+
24
+ # ============================
25
+ # Token helpers (optional tiktoken)
26
+ # ============================
27
+ def _safe_import_tiktoken():
28
+ try:
29
+ import tiktoken # type: ignore
30
+ return tiktoken
31
+ except Exception:
32
+ return None
33
+
34
+
35
+ def _approx_tokens(text: str) -> int:
36
+ if not text:
37
+ return 0
38
+ return max(1, int(len(text) / 4))
39
+
40
+
41
def _count_text_tokens(text: str, model: str = "") -> int:
    """Count the tokens in *text*.

    Uses tiktoken when it is importable (model-specific encoding if *model*
    resolves, otherwise ``cl100k_base``); falls back to the character-based
    estimate when tiktoken is absent.
    """
    tk = _safe_import_tiktoken()
    if tk is None:
        return _approx_tokens(text)

    encoder = None
    if model:
        try:
            encoder = tk.encoding_for_model(model)
        except Exception:
            # Unknown model name — fall back to the generic encoding below.
            encoder = None
    if encoder is None:
        encoder = tk.get_encoding("cl100k_base")

    return len(encoder.encode(text or ""))
52
+
53
+
54
def _count_messages_tokens(messages: List[Dict[str, str]], model: str = "") -> int:
    """Estimate the total prompt tokens for a chat-message list.

    Adds a fixed per-message overhead (4 + 2 tokens) on top of the role and
    content token counts — an engineering approximation, not an exact match
    for any provider's accounting.
    """
    per_message_totals = []
    for msg in messages or []:
        role_tokens = _count_text_tokens(str(msg.get("role", "")), model=model)
        content_tokens = _count_text_tokens(str(msg.get("content", "")), model=model)
        per_message_totals.append(4 + role_tokens + content_tokens + 2)
    return sum(per_message_totals)
63
+
64
+
65
def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
    """Trim *text* so its token count does not exceed *max_tokens*.

    With tiktoken available the cut is exact (encode, slice, decode).
    Without it, the text is shrunk by character ratio and then nibbled down
    10% at a time until the heuristic estimate fits; at least ~50 characters
    are always kept on that path.
    """
    if not text:
        return text

    tk = _safe_import_tiktoken()
    if tk is None:
        # Heuristic path (no tiktoken): character-ratio shrink + refinement loop.
        estimated = _approx_tokens(text)
        if estimated <= max_tokens:
            return text
        ratio = max_tokens / max(1, estimated)
        cut = max(50, min(len(text), int(len(text) * ratio)))
        trimmed = text[:cut]
        while _approx_tokens(trimmed) > max_tokens and len(trimmed) > 50:
            trimmed = trimmed[: int(len(trimmed) * 0.9)]
        return trimmed

    try:
        enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
    except Exception:
        enc = tk.get_encoding("cl100k_base")

    token_ids = enc.encode(text or "")
    if len(token_ids) <= max_tokens:
        return text
    return enc.decode(token_ids[:max_tokens])
90
+
91
+
92
# ----------------------------
# Speed knobs (HARD LIMITS)
# ----------------------------
# 1) History: keep only the most recent 10 (user, assistant) turns.
MAX_HISTORY_TURNS = int(os.getenv("CLARE_MAX_HISTORY_TURNS", "10"))

# 2) RAG: "at most 4 chunks x 500 tokens each" is enforced in rag_engine.py;
#    this knob only caps the TOTAL tokens injected into the LLM prompt so the
#    prompt cannot blow up.
MAX_RAG_TOKENS_IN_PROMPT = int(os.getenv("CLARE_MAX_RAG_TOKENS", "2000"))

# 3) Default cap on generated output tokens (max_new_tokens): 384.
DEFAULT_MAX_OUTPUT_TOKENS = int(os.getenv("CLARE_MAX_OUTPUT_TOKENS", "384"))
104
 
105
 
106
  # ---------- syllabus 解析 ----------
 
392
  return None
393
 
394
 
395
def _log_prompt_token_breakdown(
    messages: List[Dict[str, str]],
    system_prompt: str,
    rag_context: str,
    trimmed_history: List[Tuple[str, str]],
    user_message: str,
    model_name: str,
):
    """Print a ``[LLM_PROMPT_TOKENS]`` JSON line breaking down estimated prompt tokens, and return the stats dict."""
    history_tokens = 0
    for user_turn, assistant_turn in (trimmed_history or []):
        history_tokens += _count_text_tokens(user_turn or "", model=model_name)
        history_tokens += _count_text_tokens(assistant_turn or "", model=model_name)

    stats = {
        "system_tokens": _count_text_tokens(system_prompt, model=model_name),
        "rag_tokens": _count_text_tokens(rag_context or "", model=model_name),
        "history_tokens": history_tokens,
        "user_tokens": _count_text_tokens(user_message or "", model=model_name),
        "prompt_tokens_total_est": _count_messages_tokens(messages, model=model_name),
        "history_turns_kept": len(trimmed_history or []),
        "max_rag_tokens_in_prompt": MAX_RAG_TOKENS_IN_PROMPT,
        "max_output_tokens": DEFAULT_MAX_OUTPUT_TOKENS,
        "model": model_name or DEFAULT_MODEL,
    }
    print("[LLM_PROMPT_TOKENS] " + json.dumps(stats, ensure_ascii=False))
    return stats
419
+
420
+
421
@traceable(run_type="llm", name="safe_chat_completion_profiled")
def safe_chat_completion_profiled(
    model_name: str,
    messages: List[Dict[str, str]],
    lang: str,
    op: str = "chat",
    temperature: float = 0.5,
    max_tokens: Optional[int] = None,
) -> Tuple[str, Dict]:
    """Run a chat completion with latency profiling and a layered fallback.

    Strategy (per attempt, up to 2 attempts; attempt 1 uses *model_name*,
    attempt 2 retries with ``DEFAULT_MODEL``):
      1. Streaming request — measures real time-to-first-token (TTFT).
      2. Non-streaming request — used if streaming raises; TTFT unavailable.
    If every attempt fails, returns a localized error message from
    ``build_error_message`` instead of raising.

    Returns:
      - answer text
      - profiling dict {ttft_ms, llm_total_ms, gen_ms, output_tokens_est, tokens_per_sec_est, streaming_used}
    """
    preferred_model = model_name or DEFAULT_MODEL
    max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)

    last_error: Optional[Exception] = None

    for attempt in range(2):
        # Attempt 0: caller's preferred model; attempt 1: the default model.
        current_model = preferred_model if attempt == 0 else DEFAULT_MODEL

        # 1) Try streaming for real TTFT
        t0 = time.perf_counter()
        try:
            # NOTE(review): assumes `client` is an OpenAI-compatible client
            # defined at module level — confirm against the imports above.
            resp = client.chat.completions.create(
                model=current_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=20,
                stream=True,
            )

            first_token_t = None
            out_parts: List[str] = []
            for event in resp:
                # OpenAI-style: event.choices[0].delta.content
                try:
                    delta = event.choices[0].delta.content  # type: ignore
                except Exception:
                    # Some stream events carry no content delta (e.g. role or
                    # finish events); skip them.
                    delta = None
                if not delta:
                    continue
                if first_token_t is None:
                    # First content chunk observed — this marks TTFT.
                    first_token_t = time.perf_counter()
                out_parts.append(delta)

            t_end = time.perf_counter()
            answer = "".join(out_parts)

            # All timings in milliseconds; gen_ms excludes the TTFT wait.
            ttft_ms = None if first_token_t is None else (first_token_t - t0) * 1000.0
            total_ms = (t_end - t0) * 1000.0
            gen_ms = None if first_token_t is None else (t_end - first_token_t) * 1000.0
            out_tokens = _count_text_tokens(answer, model=current_model)
            tokens_per_sec = None
            if gen_ms and gen_ms > 0:
                tokens_per_sec = out_tokens / (gen_ms / 1000.0)

            prof = {
                "streaming_used": True,
                "ttft_ms": ttft_ms,
                "llm_total_ms": total_ms,
                "gen_ms": gen_ms,
                "output_tokens_est": out_tokens,
                "tokens_per_sec_est": tokens_per_sec,
                "model": current_model,
                "max_tokens": max_tokens,
            }
            print("[LLM_PROFILING] " + json.dumps(prof, ensure_ascii=False))
            return answer, prof

        except Exception as e:
            last_error = e
            # fall through to non-stream fallback below
            # NOTE(review): if the stream fails mid-generation, any partial
            # output is discarded and a full second request is issued.

        # 2) Non-stream fallback (TTFT not available; approximate)
        try:
            t0 = time.perf_counter()
            resp2 = client.chat.completions.create(
                model=current_model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=20,
            )
            t_end = time.perf_counter()
            answer = resp2.choices[0].message.content or ""

            total_ms = (t_end - t0) * 1000.0
            out_tokens = _count_text_tokens(answer, model=current_model)
            tokens_per_sec = None
            if total_ms > 0:
                tokens_per_sec = out_tokens / (total_ms / 1000.0)

            prof = {
                "streaming_used": False,
                "ttft_ms": None,  # not measurable without stream
                "llm_total_ms": total_ms,
                "gen_ms": None,
                "output_tokens_est": out_tokens,
                "tokens_per_sec_est": tokens_per_sec,
                "model": current_model,
                "max_tokens": max_tokens,
                "note": "non-stream fallback; ttft_ms unavailable",
            }
            print("[LLM_PROFILING] " + json.dumps(prof, ensure_ascii=False))
            return answer, prof

        except Exception as e:
            print(
                f"[safe_chat_completion_profiled][{op}] attempt {attempt+1} "
                f"failed with model={current_model}: {repr(e)}"
            )
            last_error = e
            # No point retrying if we already used the default model.
            if current_model == DEFAULT_MODEL or attempt == 1:
                break

    # Both attempts exhausted: return a user-facing error string plus a
    # minimal profiling dict recording the failure.
    return build_error_message(last_error or Exception("unknown error"), lang, op), {
        "streaming_used": False,
        "error": repr(last_error) if last_error else "unknown",
    }
543
 
544
 
545
  def build_messages(
 
638
  }
639
  )
640
 
641
+ # RAG context: enforce token cap here (in addition to rag_engine caps)
642
+ rag_text_for_prompt = ""
643
  if rag_context:
644
+ rag_text_for_prompt = _truncate_to_tokens(
645
+ rag_context, max_tokens=MAX_RAG_TOKENS_IN_PROMPT, model=model_name_or_default(DEFAULT_MODEL)
646
+ )
647
  messages.append(
648
  {
649
  "role": "system",
650
  "content": (
651
+ "Relevant excerpts (use as primary grounding):\n\n" + rag_text_for_prompt
652
  ),
653
  }
654
  )
655
 
656
+ # Only keep the last N turns for speed (HARD LIMIT)
657
  trimmed_history = history[-MAX_HISTORY_TURNS:] if history else []
658
  for user, assistant in trimmed_history:
659
  messages.append({"role": "user", "content": user})
 
661
  messages.append({"role": "assistant", "content": assistant})
662
 
663
  messages.append({"role": "user", "content": user_message})
664
+
665
+ # prompt token breakdown log
666
+ _log_prompt_token_breakdown(
667
+ messages=messages,
668
+ system_prompt=CLARE_SYSTEM_PROMPT,
669
+ rag_context=rag_text_for_prompt,
670
+ trimmed_history=trimmed_history,
671
+ user_message=user_message,
672
+ model_name=(DEFAULT_MODEL or ""),
673
+ )
674
+
675
  return messages
676
 
677
 
678
def model_name_or_default(x: str) -> str:
    """Return *x* when it is truthy; otherwise fall back to ``DEFAULT_MODEL``."""
    if x:
        return x
    return DEFAULT_MODEL
680
+
681
+
682
  @traceable(run_type="chain", name="chat_with_clare")
683
  def chat_with_clare(
684
  message: str,
 
713
  rag_context=rag_context,
714
  )
715
 
716
+ answer, _prof = safe_chat_completion_profiled(
717
  model_name=model_name,
718
  messages=messages,
719
  lang=language_preference,
 
787
  if language_preference == "中文":
788
  messages.append({"role": "system", "content": "请用中文输出要点总结(bullet points)。"})
789
 
790
+ summary_text, _prof = safe_chat_completion_profiled(
791
  model_name=model_name,
792
  messages=messages,
793
  lang=language_preference,