Spaces:
Sleeping
Sleeping
Update api/clare_core.py
Browse files- api/clare_core.py +226 -16
api/clare_core.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
import math
|
|
|
|
|
|
|
| 5 |
from typing import List, Dict, Tuple, Optional
|
| 6 |
|
| 7 |
from docx import Document
|
|
@@ -18,12 +20,87 @@ from .config import (
|
|
| 18 |
from langsmith import traceable
|
| 19 |
from langsmith.run_helpers import set_run_metadata
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# ----------------------------
|
| 22 |
-
# Speed knobs
|
| 23 |
# ----------------------------
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
# ---------- syllabus 解析 ----------
|
|
@@ -315,21 +392,56 @@ def find_similar_past_question(
|
|
| 315 |
return None
|
| 316 |
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
model_name: str,
|
| 321 |
messages: List[Dict[str, str]],
|
| 322 |
lang: str,
|
| 323 |
op: str = "chat",
|
| 324 |
temperature: float = 0.5,
|
| 325 |
max_tokens: Optional[int] = None,
|
| 326 |
-
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
preferred_model = model_name or DEFAULT_MODEL
|
| 328 |
-
last_error: Optional[Exception] = None
|
| 329 |
max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)
|
| 330 |
|
|
|
|
|
|
|
| 331 |
for attempt in range(2):
|
| 332 |
current_model = preferred_model if attempt == 0 else DEFAULT_MODEL
|
|
|
|
|
|
|
|
|
|
| 333 |
try:
|
| 334 |
resp = client.chat.completions.create(
|
| 335 |
model=current_model,
|
|
@@ -337,18 +449,97 @@ def safe_chat_completion(
|
|
| 337 |
temperature=temperature,
|
| 338 |
max_tokens=max_tokens,
|
| 339 |
timeout=20,
|
|
|
|
| 340 |
)
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
except Exception as e:
|
| 343 |
print(
|
| 344 |
-
f"[
|
| 345 |
f"failed with model={current_model}: {repr(e)}"
|
| 346 |
)
|
| 347 |
last_error = e
|
| 348 |
if current_model == DEFAULT_MODEL or attempt == 1:
|
| 349 |
break
|
| 350 |
|
| 351 |
-
return build_error_message(last_error or Exception("unknown error"), lang, op)
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
|
| 354 |
def build_messages(
|
|
@@ -447,18 +638,22 @@ def build_messages(
|
|
| 447 |
}
|
| 448 |
)
|
| 449 |
|
|
|
|
|
|
|
| 450 |
if rag_context:
|
| 451 |
-
|
|
|
|
|
|
|
| 452 |
messages.append(
|
| 453 |
{
|
| 454 |
"role": "system",
|
| 455 |
"content": (
|
| 456 |
-
"Relevant excerpts (use as primary grounding):\n\n" +
|
| 457 |
),
|
| 458 |
}
|
| 459 |
)
|
| 460 |
|
| 461 |
-
# Only keep the last N turns for speed
|
| 462 |
trimmed_history = history[-MAX_HISTORY_TURNS:] if history else []
|
| 463 |
for user, assistant in trimmed_history:
|
| 464 |
messages.append({"role": "user", "content": user})
|
|
@@ -466,9 +661,24 @@ def build_messages(
|
|
| 466 |
messages.append({"role": "assistant", "content": assistant})
|
| 467 |
|
| 468 |
messages.append({"role": "user", "content": user_message})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
return messages
|
| 470 |
|
| 471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
@traceable(run_type="chain", name="chat_with_clare")
|
| 473 |
def chat_with_clare(
|
| 474 |
message: str,
|
|
@@ -503,7 +713,7 @@ def chat_with_clare(
|
|
| 503 |
rag_context=rag_context,
|
| 504 |
)
|
| 505 |
|
| 506 |
-
answer =
|
| 507 |
model_name=model_name,
|
| 508 |
messages=messages,
|
| 509 |
lang=language_preference,
|
|
@@ -577,7 +787,7 @@ def summarize_conversation(
|
|
| 577 |
if language_preference == "中文":
|
| 578 |
messages.append({"role": "system", "content": "请用中文输出要点总结(bullet points)。"})
|
| 579 |
|
| 580 |
-
summary_text =
|
| 581 |
model_name=model_name,
|
| 582 |
messages=messages,
|
| 583 |
lang=language_preference,
|
|
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
import math
|
| 5 |
+
import time
|
| 6 |
+
import json
|
| 7 |
from typing import List, Dict, Tuple, Optional
|
| 8 |
|
| 9 |
from docx import Document
|
|
|
|
| 20 |
from langsmith import traceable
|
| 21 |
from langsmith.run_helpers import set_run_metadata
|
| 22 |
|
| 23 |
+
|
| 24 |
+
# ============================
|
| 25 |
+
# Token helpers (optional tiktoken)
|
| 26 |
+
# ============================
|
| 27 |
+
def _safe_import_tiktoken():
|
| 28 |
+
try:
|
| 29 |
+
import tiktoken # type: ignore
|
| 30 |
+
return tiktoken
|
| 31 |
+
except Exception:
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _approx_tokens(text: str) -> int:
|
| 36 |
+
if not text:
|
| 37 |
+
return 0
|
| 38 |
+
return max(1, int(len(text) / 4))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _count_text_tokens(text: str, model: str = "") -> int:
|
| 42 |
+
tk = _safe_import_tiktoken()
|
| 43 |
+
if tk is None:
|
| 44 |
+
return _approx_tokens(text)
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
|
| 48 |
+
except Exception:
|
| 49 |
+
enc = tk.get_encoding("cl100k_base")
|
| 50 |
+
|
| 51 |
+
return len(enc.encode(text or ""))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _count_messages_tokens(messages: List[Dict[str, str]], model: str = "") -> int:
|
| 55 |
+
# engineering approximation for chat messages overhead
|
| 56 |
+
total = 0
|
| 57 |
+
for m in messages or []:
|
| 58 |
+
total += 4
|
| 59 |
+
total += _count_text_tokens(str(m.get("role", "")), model=model)
|
| 60 |
+
total += _count_text_tokens(str(m.get("content", "")), model=model)
|
| 61 |
+
total += 2
|
| 62 |
+
return total
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
|
| 66 |
+
if not text:
|
| 67 |
+
return text
|
| 68 |
+
|
| 69 |
+
tk = _safe_import_tiktoken()
|
| 70 |
+
if tk is None:
|
| 71 |
+
total = _approx_tokens(text)
|
| 72 |
+
if total <= max_tokens:
|
| 73 |
+
return text
|
| 74 |
+
ratio = max_tokens / max(1, total)
|
| 75 |
+
cut = max(50, min(len(text), int(len(text) * ratio)))
|
| 76 |
+
s = text[:cut]
|
| 77 |
+
while _approx_tokens(s) > max_tokens and len(s) > 50:
|
| 78 |
+
s = s[: int(len(s) * 0.9)]
|
| 79 |
+
return s
|
| 80 |
+
|
| 81 |
+
try:
|
| 82 |
+
enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
|
| 83 |
+
except Exception:
|
| 84 |
+
enc = tk.get_encoding("cl100k_base")
|
| 85 |
+
|
| 86 |
+
ids = enc.encode(text or "")
|
| 87 |
+
if len(ids) <= max_tokens:
|
| 88 |
+
return text
|
| 89 |
+
return enc.decode(ids[:max_tokens])
|
| 90 |
+
|
| 91 |
+
|
| 92 |
# ----------------------------
|
| 93 |
+
# Speed knobs (HARD LIMITS)
|
| 94 |
# ----------------------------
|
| 95 |
+
# 1) history 最近 10 轮
|
| 96 |
+
MAX_HISTORY_TURNS = int(os.getenv("CLARE_MAX_HISTORY_TURNS", "10"))
|
| 97 |
+
|
| 98 |
+
# 2) rag 最多 4 条每条 500 tokens 已在 rag_engine.py 实现
|
| 99 |
+
# 这里仅控制“注入到 LLM prompt 的总 tokens”,避免 prompt 爆炸
|
| 100 |
+
MAX_RAG_TOKENS_IN_PROMPT = int(os.getenv("CLARE_MAX_RAG_TOKENS", "2000"))
|
| 101 |
+
|
| 102 |
+
# 3) max_new_tokens 默认 384
|
| 103 |
+
DEFAULT_MAX_OUTPUT_TOKENS = int(os.getenv("CLARE_MAX_OUTPUT_TOKENS", "384"))
|
| 104 |
|
| 105 |
|
| 106 |
# ---------- syllabus 解析 ----------
|
|
|
|
| 392 |
return None
|
| 393 |
|
| 394 |
|
| 395 |
+
def _log_prompt_token_breakdown(
|
| 396 |
+
messages: List[Dict[str, str]],
|
| 397 |
+
system_prompt: str,
|
| 398 |
+
rag_context: str,
|
| 399 |
+
trimmed_history: List[Tuple[str, str]],
|
| 400 |
+
user_message: str,
|
| 401 |
+
model_name: str,
|
| 402 |
+
):
|
| 403 |
+
stats = {
|
| 404 |
+
"system_tokens": _count_text_tokens(system_prompt, model=model_name),
|
| 405 |
+
"rag_tokens": _count_text_tokens(rag_context or "", model=model_name),
|
| 406 |
+
"history_tokens": sum(
|
| 407 |
+
_count_text_tokens(u or "", model=model_name) + _count_text_tokens(a or "", model=model_name)
|
| 408 |
+
for u, a in (trimmed_history or [])
|
| 409 |
+
),
|
| 410 |
+
"user_tokens": _count_text_tokens(user_message or "", model=model_name),
|
| 411 |
+
"prompt_tokens_total_est": _count_messages_tokens(messages, model=model_name),
|
| 412 |
+
"history_turns_kept": len(trimmed_history or []),
|
| 413 |
+
"max_rag_tokens_in_prompt": MAX_RAG_TOKENS_IN_PROMPT,
|
| 414 |
+
"max_output_tokens": DEFAULT_MAX_OUTPUT_TOKENS,
|
| 415 |
+
"model": model_name or DEFAULT_MODEL,
|
| 416 |
+
}
|
| 417 |
+
print("[LLM_PROMPT_TOKENS] " + json.dumps(stats, ensure_ascii=False))
|
| 418 |
+
return stats
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
@traceable(run_type="llm", name="safe_chat_completion_profiled")
|
| 422 |
+
def safe_chat_completion_profiled(
|
| 423 |
model_name: str,
|
| 424 |
messages: List[Dict[str, str]],
|
| 425 |
lang: str,
|
| 426 |
op: str = "chat",
|
| 427 |
temperature: float = 0.5,
|
| 428 |
max_tokens: Optional[int] = None,
|
| 429 |
+
) -> Tuple[str, Dict]:
|
| 430 |
+
"""
|
| 431 |
+
Returns:
|
| 432 |
+
- answer text
|
| 433 |
+
- profiling dict {ttft_ms, llm_total_ms, gen_ms, output_tokens_est, tokens_per_sec_est, streaming_used}
|
| 434 |
+
"""
|
| 435 |
preferred_model = model_name or DEFAULT_MODEL
|
|
|
|
| 436 |
max_tokens = int(max_tokens or DEFAULT_MAX_OUTPUT_TOKENS)
|
| 437 |
|
| 438 |
+
last_error: Optional[Exception] = None
|
| 439 |
+
|
| 440 |
for attempt in range(2):
|
| 441 |
current_model = preferred_model if attempt == 0 else DEFAULT_MODEL
|
| 442 |
+
|
| 443 |
+
# 1) Try streaming for real TTFT
|
| 444 |
+
t0 = time.perf_counter()
|
| 445 |
try:
|
| 446 |
resp = client.chat.completions.create(
|
| 447 |
model=current_model,
|
|
|
|
| 449 |
temperature=temperature,
|
| 450 |
max_tokens=max_tokens,
|
| 451 |
timeout=20,
|
| 452 |
+
stream=True,
|
| 453 |
)
|
| 454 |
+
|
| 455 |
+
first_token_t = None
|
| 456 |
+
out_parts: List[str] = []
|
| 457 |
+
for event in resp:
|
| 458 |
+
# OpenAI-style: event.choices[0].delta.content
|
| 459 |
+
try:
|
| 460 |
+
delta = event.choices[0].delta.content # type: ignore
|
| 461 |
+
except Exception:
|
| 462 |
+
delta = None
|
| 463 |
+
if not delta:
|
| 464 |
+
continue
|
| 465 |
+
if first_token_t is None:
|
| 466 |
+
first_token_t = time.perf_counter()
|
| 467 |
+
out_parts.append(delta)
|
| 468 |
+
|
| 469 |
+
t_end = time.perf_counter()
|
| 470 |
+
answer = "".join(out_parts)
|
| 471 |
+
|
| 472 |
+
ttft_ms = None if first_token_t is None else (first_token_t - t0) * 1000.0
|
| 473 |
+
total_ms = (t_end - t0) * 1000.0
|
| 474 |
+
gen_ms = None if first_token_t is None else (t_end - first_token_t) * 1000.0
|
| 475 |
+
out_tokens = _count_text_tokens(answer, model=current_model)
|
| 476 |
+
tokens_per_sec = None
|
| 477 |
+
if gen_ms and gen_ms > 0:
|
| 478 |
+
tokens_per_sec = out_tokens / (gen_ms / 1000.0)
|
| 479 |
+
|
| 480 |
+
prof = {
|
| 481 |
+
"streaming_used": True,
|
| 482 |
+
"ttft_ms": ttft_ms,
|
| 483 |
+
"llm_total_ms": total_ms,
|
| 484 |
+
"gen_ms": gen_ms,
|
| 485 |
+
"output_tokens_est": out_tokens,
|
| 486 |
+
"tokens_per_sec_est": tokens_per_sec,
|
| 487 |
+
"model": current_model,
|
| 488 |
+
"max_tokens": max_tokens,
|
| 489 |
+
}
|
| 490 |
+
print("[LLM_PROFILING] " + json.dumps(prof, ensure_ascii=False))
|
| 491 |
+
return answer, prof
|
| 492 |
+
|
| 493 |
+
except Exception as e:
|
| 494 |
+
last_error = e
|
| 495 |
+
# fall through to non-stream fallback below
|
| 496 |
+
|
| 497 |
+
# 2) Non-stream fallback (TTFT not available; approximate)
|
| 498 |
+
try:
|
| 499 |
+
t0 = time.perf_counter()
|
| 500 |
+
resp2 = client.chat.completions.create(
|
| 501 |
+
model=current_model,
|
| 502 |
+
messages=messages,
|
| 503 |
+
temperature=temperature,
|
| 504 |
+
max_tokens=max_tokens,
|
| 505 |
+
timeout=20,
|
| 506 |
+
)
|
| 507 |
+
t_end = time.perf_counter()
|
| 508 |
+
answer = resp2.choices[0].message.content or ""
|
| 509 |
+
|
| 510 |
+
total_ms = (t_end - t0) * 1000.0
|
| 511 |
+
out_tokens = _count_text_tokens(answer, model=current_model)
|
| 512 |
+
tokens_per_sec = None
|
| 513 |
+
if total_ms > 0:
|
| 514 |
+
tokens_per_sec = out_tokens / (total_ms / 1000.0)
|
| 515 |
+
|
| 516 |
+
prof = {
|
| 517 |
+
"streaming_used": False,
|
| 518 |
+
"ttft_ms": None, # not measurable without stream
|
| 519 |
+
"llm_total_ms": total_ms,
|
| 520 |
+
"gen_ms": None,
|
| 521 |
+
"output_tokens_est": out_tokens,
|
| 522 |
+
"tokens_per_sec_est": tokens_per_sec,
|
| 523 |
+
"model": current_model,
|
| 524 |
+
"max_tokens": max_tokens,
|
| 525 |
+
"note": "non-stream fallback; ttft_ms unavailable",
|
| 526 |
+
}
|
| 527 |
+
print("[LLM_PROFILING] " + json.dumps(prof, ensure_ascii=False))
|
| 528 |
+
return answer, prof
|
| 529 |
+
|
| 530 |
except Exception as e:
|
| 531 |
print(
|
| 532 |
+
f"[safe_chat_completion_profiled][{op}] attempt {attempt+1} "
|
| 533 |
f"failed with model={current_model}: {repr(e)}"
|
| 534 |
)
|
| 535 |
last_error = e
|
| 536 |
if current_model == DEFAULT_MODEL or attempt == 1:
|
| 537 |
break
|
| 538 |
|
| 539 |
+
return build_error_message(last_error or Exception("unknown error"), lang, op), {
|
| 540 |
+
"streaming_used": False,
|
| 541 |
+
"error": repr(last_error) if last_error else "unknown",
|
| 542 |
+
}
|
| 543 |
|
| 544 |
|
| 545 |
def build_messages(
|
|
|
|
| 638 |
}
|
| 639 |
)
|
| 640 |
|
| 641 |
+
# RAG context: enforce token cap here (in addition to rag_engine caps)
|
| 642 |
+
rag_text_for_prompt = ""
|
| 643 |
if rag_context:
|
| 644 |
+
rag_text_for_prompt = _truncate_to_tokens(
|
| 645 |
+
rag_context, max_tokens=MAX_RAG_TOKENS_IN_PROMPT, model=model_name_or_default(DEFAULT_MODEL)
|
| 646 |
+
)
|
| 647 |
messages.append(
|
| 648 |
{
|
| 649 |
"role": "system",
|
| 650 |
"content": (
|
| 651 |
+
"Relevant excerpts (use as primary grounding):\n\n" + rag_text_for_prompt
|
| 652 |
),
|
| 653 |
}
|
| 654 |
)
|
| 655 |
|
| 656 |
+
# Only keep the last N turns for speed (HARD LIMIT)
|
| 657 |
trimmed_history = history[-MAX_HISTORY_TURNS:] if history else []
|
| 658 |
for user, assistant in trimmed_history:
|
| 659 |
messages.append({"role": "user", "content": user})
|
|
|
|
| 661 |
messages.append({"role": "assistant", "content": assistant})
|
| 662 |
|
| 663 |
messages.append({"role": "user", "content": user_message})
|
| 664 |
+
|
| 665 |
+
# prompt token breakdown log
|
| 666 |
+
_log_prompt_token_breakdown(
|
| 667 |
+
messages=messages,
|
| 668 |
+
system_prompt=CLARE_SYSTEM_PROMPT,
|
| 669 |
+
rag_context=rag_text_for_prompt,
|
| 670 |
+
trimmed_history=trimmed_history,
|
| 671 |
+
user_message=user_message,
|
| 672 |
+
model_name=(DEFAULT_MODEL or ""),
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
return messages
|
| 676 |
|
| 677 |
|
| 678 |
+
def model_name_or_default(x: str) -> str:
|
| 679 |
+
return x or DEFAULT_MODEL
|
| 680 |
+
|
| 681 |
+
|
| 682 |
@traceable(run_type="chain", name="chat_with_clare")
|
| 683 |
def chat_with_clare(
|
| 684 |
message: str,
|
|
|
|
| 713 |
rag_context=rag_context,
|
| 714 |
)
|
| 715 |
|
| 716 |
+
answer, _prof = safe_chat_completion_profiled(
|
| 717 |
model_name=model_name,
|
| 718 |
messages=messages,
|
| 719 |
lang=language_preference,
|
|
|
|
| 787 |
if language_preference == "中文":
|
| 788 |
messages.append({"role": "system", "content": "请用中文输出要点总结(bullet points)。"})
|
| 789 |
|
| 790 |
+
summary_text, _prof = safe_chat_completion_profiled(
|
| 791 |
model_name=model_name,
|
| 792 |
messages=messages,
|
| 793 |
lang=language_preference,
|