Commit ·
c135a78
1
Parent(s): fc85087
Upload ChatIPC.cpp
Browse files- ChatIPC.cpp +4 -603
ChatIPC.cpp
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
// ChatIPC := Chat Incremental Pattern Constructor
|
| 2 |
-
|
| 3 |
#include <algorithm>
|
| 4 |
#include <atomic>
|
| 5 |
#include <cctype>
|
|
@@ -18,8 +16,6 @@
|
|
| 18 |
#include <unordered_map>
|
| 19 |
#include <unordered_set>
|
| 20 |
#include <vector>
|
| 21 |
-
#include <cmath>
|
| 22 |
-
#include <limits>
|
| 23 |
|
| 24 |
#ifdef _OPENMP
|
| 25 |
#include <omp.h>
|
|
@@ -28,16 +24,14 @@ inline int omp_get_max_threads(){ return 1; }
|
|
| 28 |
inline int omp_get_thread_num(){ return 0; }
|
| 29 |
#endif
|
| 30 |
|
| 31 |
-
extern unsigned char dictionary_json[];
|
| 32 |
extern unsigned int dictionary_json_len;
|
| 33 |
|
| 34 |
-
// --------------------------- Short utility functions ----------------------
|
| 35 |
|
| 36 |
static inline bool is_space(char c){ return std::isspace(static_cast<unsigned char>(c)) != 0; }
|
| 37 |
static inline char to_low(char c){ return static_cast<char>(std::tolower(static_cast<unsigned char>(c))); }
|
| 38 |
static inline void safe_flush(std::ostream &os){ os.flush(); }
|
| 39 |
|
| 40 |
-
// NEW: dictionary model, normalization, and English-rule helpers.
|
| 41 |
struct DictionaryEntry {
|
| 42 |
std::string pos;
|
| 43 |
std::string word;
|
|
@@ -263,7 +257,6 @@ static double english_rule_bonus(const std::string &context_tok, const std::stri
|
|
| 263 |
return bonus;
|
| 264 |
}
|
| 265 |
|
| 266 |
-
// Tokenize by whitespace
|
| 267 |
static std::vector<std::string> tokenize_whitespace(const std::string &s){
|
| 268 |
std::istringstream iss(s);
|
| 269 |
std::vector<std::string> out;
|
|
@@ -272,7 +265,6 @@ static std::vector<std::string> tokenize_whitespace(const std::string &s){
|
|
| 272 |
return out;
|
| 273 |
}
|
| 274 |
|
| 275 |
-
// Tokenize by non-alphanumeric characters (for definitions)
|
| 276 |
static std::vector<std::string> tokenize_non_alnum(const std::string &s){
|
| 277 |
std::vector<std::string> out; std::string cur;
|
| 278 |
for (char ch : s){
|
|
@@ -286,572 +278,6 @@ static std::vector<std::string> tokenize_non_alnum(const std::string &s){
|
|
| 286 |
return out;
|
| 287 |
}
|
| 288 |
|
| 289 |
-
// ---------------- Math feature: iterative only ----------------
|
| 290 |
-
|
| 291 |
-
enum class MathOp {
|
| 292 |
-
Unknown,
|
| 293 |
-
Add,
|
| 294 |
-
Sub,
|
| 295 |
-
Mul,
|
| 296 |
-
Div,
|
| 297 |
-
Lt,
|
| 298 |
-
Le,
|
| 299 |
-
Gt,
|
| 300 |
-
Ge,
|
| 301 |
-
Eq
|
| 302 |
-
};
|
| 303 |
-
|
| 304 |
-
struct MathValue {
|
| 305 |
-
bool valid = false;
|
| 306 |
-
bool is_bool = false;
|
| 307 |
-
double number = 0.0;
|
| 308 |
-
bool boolean = false;
|
| 309 |
-
};
|
| 310 |
-
|
| 311 |
-
static inline std::string lower_copy_str(const std::string &s) {
|
| 312 |
-
std::string out;
|
| 313 |
-
out.reserve(s.size());
|
| 314 |
-
for (char c : s) out.push_back(to_low(c));
|
| 315 |
-
return out;
|
| 316 |
-
}
|
| 317 |
-
|
| 318 |
-
static inline std::string trim_math_surface(const std::string &s) {
|
| 319 |
-
size_t b = 0, e = s.size();
|
| 320 |
-
while (b < e && std::isspace(static_cast<unsigned char>(s[b]))) ++b;
|
| 321 |
-
while (e > b && std::isspace(static_cast<unsigned char>(s[e - 1]))) --e;
|
| 322 |
-
|
| 323 |
-
while (b < e) {
|
| 324 |
-
unsigned char uc = static_cast<unsigned char>(s[b]);
|
| 325 |
-
char c = s[b];
|
| 326 |
-
if (!std::ispunct(uc) || c == '<' || c == '>' || c == '=' || c == '+' ||
|
| 327 |
-
c == '-' || c == '*' || c == '/' || c == '(' || c == ')')
|
| 328 |
-
break;
|
| 329 |
-
++b;
|
| 330 |
-
}
|
| 331 |
-
while (e > b) {
|
| 332 |
-
unsigned char uc = static_cast<unsigned char>(s[e - 1]);
|
| 333 |
-
char c = s[e - 1];
|
| 334 |
-
if (!std::ispunct(uc) || c == '<' || c == '>' || c == '=' || c == '+' ||
|
| 335 |
-
c == '-' || c == '*' || c == '/' || c == '(' || c == ')')
|
| 336 |
-
break;
|
| 337 |
-
--e;
|
| 338 |
-
}
|
| 339 |
-
|
| 340 |
-
std::string out;
|
| 341 |
-
out.reserve(e - b);
|
| 342 |
-
for (size_t i = b; i < e; ++i) out.push_back(to_low(s[i]));
|
| 343 |
-
return out;
|
| 344 |
-
}
|
| 345 |
-
|
| 346 |
-
static inline bool is_number_surface(const std::string &s, double &value) {
|
| 347 |
-
if (s.empty()) return false;
|
| 348 |
-
bool has_digit = false;
|
| 349 |
-
bool has_dot = false;
|
| 350 |
-
|
| 351 |
-
for (char c : s) {
|
| 352 |
-
if (std::isdigit(static_cast<unsigned char>(c))) {
|
| 353 |
-
has_digit = true;
|
| 354 |
-
continue;
|
| 355 |
-
}
|
| 356 |
-
if (c == '.' && !has_dot) {
|
| 357 |
-
has_dot = true;
|
| 358 |
-
continue;
|
| 359 |
-
}
|
| 360 |
-
if ((c == '+' || c == '-') && &c == &s.front()) continue;
|
| 361 |
-
return false;
|
| 362 |
-
}
|
| 363 |
-
if (!has_digit) return false;
|
| 364 |
-
|
| 365 |
-
try {
|
| 366 |
-
size_t idx = 0;
|
| 367 |
-
value = std::stod(s, &idx);
|
| 368 |
-
return idx == s.size();
|
| 369 |
-
} catch (...) {
|
| 370 |
-
return false;
|
| 371 |
-
}
|
| 372 |
-
}
|
| 373 |
-
|
| 374 |
-
static inline bool is_sentence_break_token(const std::string &raw) {
|
| 375 |
-
if (raw.empty()) return false;
|
| 376 |
-
const char c = raw.back();
|
| 377 |
-
return c == '.' || c == '?' || c == '!' || c == ';' || c == ':';
|
| 378 |
-
}
|
| 379 |
-
|
| 380 |
-
static inline bool is_logic_token(const std::string &t) {
|
| 381 |
-
return t == "and" || t == "or" || t == "not";
|
| 382 |
-
}
|
| 383 |
-
|
| 384 |
-
static inline bool is_comparison_surface(const std::string &t) {
|
| 385 |
-
return t == "<" || t == "<=" || t == ">" || t == ">=" || t == "=" || t == "==";
|
| 386 |
-
}
|
| 387 |
-
|
| 388 |
-
static inline bool is_reversal_marker(const std::string &t) {
|
| 389 |
-
return t == "from" || t == "than" || t == "by";
|
| 390 |
-
}
|
| 391 |
-
|
| 392 |
-
static inline bool is_add_seed(const std::string &t) {
|
| 393 |
-
return t == "add" || t == "plus" || t == "sum";
|
| 394 |
-
}
|
| 395 |
-
|
| 396 |
-
static inline bool is_sub_seed(const std::string &t) {
|
| 397 |
-
return t == "subtract" || t == "minus" || t == "difference";
|
| 398 |
-
}
|
| 399 |
-
|
| 400 |
-
static inline bool is_mul_seed(const std::string &t) {
|
| 401 |
-
return t == "multiply" || t == "times" || t == "product";
|
| 402 |
-
}
|
| 403 |
-
|
| 404 |
-
static inline bool is_div_seed(const std::string &t) {
|
| 405 |
-
return t == "divide" || t == "quotient" || t == "over";
|
| 406 |
-
}
|
| 407 |
-
|
| 408 |
-
static inline bool is_lt_seed(const std::string &t) {
|
| 409 |
-
return t == "less" || t == "smaller" || t == "lower";
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
static inline bool is_gt_seed(const std::string &t) {
|
| 413 |
-
return t == "greater" || t == "larger" || t == "higher" || t == "more";
|
| 414 |
-
}
|
| 415 |
-
|
| 416 |
-
static inline bool is_le_seed(const std::string &t) {
|
| 417 |
-
return t == "atmost" || t == "at" || t == "no";
|
| 418 |
-
}
|
| 419 |
-
|
| 420 |
-
static inline bool is_ge_seed(const std::string &t) {
|
| 421 |
-
return t == "atleast";
|
| 422 |
-
}
|
| 423 |
-
|
| 424 |
-
static inline bool is_eq_seed(const std::string &t) {
|
| 425 |
-
return t == "equal" || t == "equals" || t == "same";
|
| 426 |
-
}
|
| 427 |
-
|
| 428 |
-
// Minimal bootstrap seeds only; dictionary closure does the expansion.
|
| 429 |
-
static const std::vector<std::string> &op_seed_lexicon(MathOp op) {
|
| 430 |
-
static const std::vector<std::string> add = {"add", "plus", "sum"};
|
| 431 |
-
static const std::vector<std::string> sub = {"subtract", "minus", "difference"};
|
| 432 |
-
static const std::vector<std::string> mul = {"multiply", "times", "product"};
|
| 433 |
-
static const std::vector<std::string> div = {"divide", "quotient", "over"};
|
| 434 |
-
static const std::vector<std::string> lt = {"less", "smaller", "lower"};
|
| 435 |
-
static const std::vector<std::string> le = {"at", "most", "no"};
|
| 436 |
-
static const std::vector<std::string> gt = {"greater", "larger", "higher", "more"};
|
| 437 |
-
static const std::vector<std::string> ge = {"at", "least", "no"};
|
| 438 |
-
static const std::vector<std::string> eq = {"equal", "equals", "same"};
|
| 439 |
-
|
| 440 |
-
switch (op) {
|
| 441 |
-
case MathOp::Add: return add;
|
| 442 |
-
case MathOp::Sub: return sub;
|
| 443 |
-
case MathOp::Mul: return mul;
|
| 444 |
-
case MathOp::Div: return div;
|
| 445 |
-
case MathOp::Lt: return lt;
|
| 446 |
-
case MathOp::Le: return le;
|
| 447 |
-
case MathOp::Gt: return gt;
|
| 448 |
-
case MathOp::Ge: return ge;
|
| 449 |
-
case MathOp::Eq: return eq;
|
| 450 |
-
default: return eq;
|
| 451 |
-
}
|
| 452 |
-
}
|
| 453 |
-
|
| 454 |
-
static std::unordered_set<std::string>
|
| 455 |
-
closure_set_from_tokens(const std::vector<std::string> &toks, int depth) {
|
| 456 |
-
std::unordered_set<std::string> acc;
|
| 457 |
-
std::unordered_set<std::string> frontier;
|
| 458 |
-
|
| 459 |
-
acc.reserve(toks.size() * 4 + 16);
|
| 460 |
-
frontier.reserve(toks.size() * 2 + 8);
|
| 461 |
-
|
| 462 |
-
for (const auto &tok : toks) {
|
| 463 |
-
const std::string k = normalize_dictionary_key(tok);
|
| 464 |
-
if (!k.empty() && acc.insert(k).second) frontier.insert(k);
|
| 465 |
-
}
|
| 466 |
-
|
| 467 |
-
for (int d = 0; d < depth && !frontier.empty(); ++d) {
|
| 468 |
-
std::unordered_set<std::string> next;
|
| 469 |
-
next.reserve(frontier.size() * 4 + 8);
|
| 470 |
-
|
| 471 |
-
for (const auto &k : frontier) {
|
| 472 |
-
auto it = global_def_tokens_cache.find(k);
|
| 473 |
-
if (it == global_def_tokens_cache.end()) continue;
|
| 474 |
-
|
| 475 |
-
for (const auto &w : it->second) {
|
| 476 |
-
const std::string kw = normalize_dictionary_key(w);
|
| 477 |
-
if (!kw.empty() && acc.insert(kw).second) next.insert(kw);
|
| 478 |
-
}
|
| 479 |
-
}
|
| 480 |
-
frontier.swap(next);
|
| 481 |
-
}
|
| 482 |
-
return acc;
|
| 483 |
-
}
|
| 484 |
-
|
| 485 |
-
static double jaccard_sets(const std::unordered_set<std::string> &a,
|
| 486 |
-
const std::unordered_set<std::string> &b) {
|
| 487 |
-
if (a.empty() && b.empty()) return 0.0;
|
| 488 |
-
|
| 489 |
-
const std::unordered_set<std::string> *sm = &a, *lg = &b;
|
| 490 |
-
if (a.size() > b.size()) std::swap(sm, lg);
|
| 491 |
-
|
| 492 |
-
size_t inter = 0;
|
| 493 |
-
for (const auto &x : *sm) {
|
| 494 |
-
if (lg->find(x) != lg->end()) ++inter;
|
| 495 |
-
}
|
| 496 |
-
const size_t uni = a.size() + b.size() - inter;
|
| 497 |
-
return uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
|
| 498 |
-
}
|
| 499 |
-
|
| 500 |
-
static double score_operator(const std::vector<std::string> &segment,
|
| 501 |
-
MathOp op,
|
| 502 |
-
int def_depth) {
|
| 503 |
-
const auto prompt_closure = closure_set_from_tokens(segment, def_depth);
|
| 504 |
-
const auto seed_closure = closure_set_from_tokens(op_seed_lexicon(op), def_depth);
|
| 505 |
-
|
| 506 |
-
double s = jaccard_sets(prompt_closure, seed_closure);
|
| 507 |
-
|
| 508 |
-
for (const auto &raw : segment) {
|
| 509 |
-
const std::string t = normalize_dictionary_key(raw);
|
| 510 |
-
|
| 511 |
-
if (op == MathOp::Add && is_add_seed(t)) s += 0.12;
|
| 512 |
-
if (op == MathOp::Sub && is_sub_seed(t)) s += 0.12;
|
| 513 |
-
if (op == MathOp::Mul && is_mul_seed(t)) s += 0.12;
|
| 514 |
-
if (op == MathOp::Div && is_div_seed(t)) s += 0.12;
|
| 515 |
-
if (op == MathOp::Lt && is_lt_seed(t)) s += 0.12;
|
| 516 |
-
if (op == MathOp::Le && is_le_seed(t)) s += 0.12;
|
| 517 |
-
if (op == MathOp::Gt && is_gt_seed(t)) s += 0.12;
|
| 518 |
-
if (op == MathOp::Ge && is_ge_seed(t)) s += 0.12;
|
| 519 |
-
if (op == MathOp::Eq && is_eq_seed(t)) s += 0.12;
|
| 520 |
-
}
|
| 521 |
-
return s;
|
| 522 |
-
}
|
| 523 |
-
|
| 524 |
-
static MathOp infer_dominant_operator(const std::vector<std::string> &segment, int def_depth) {
|
| 525 |
-
const double sa = score_operator(segment, MathOp::Add, def_depth);
|
| 526 |
-
const double ss = score_operator(segment, MathOp::Sub, def_depth);
|
| 527 |
-
const double sm = score_operator(segment, MathOp::Mul, def_depth);
|
| 528 |
-
const double sd = score_operator(segment, MathOp::Div, def_depth);
|
| 529 |
-
const double sl = score_operator(segment, MathOp::Lt, def_depth);
|
| 530 |
-
const double sle = score_operator(segment, MathOp::Le, def_depth);
|
| 531 |
-
const double sg = score_operator(segment, MathOp::Gt, def_depth);
|
| 532 |
-
const double sge = score_operator(segment, MathOp::Ge, def_depth);
|
| 533 |
-
const double se = score_operator(segment, MathOp::Eq, def_depth);
|
| 534 |
-
|
| 535 |
-
MathOp best = MathOp::Unknown;
|
| 536 |
-
double bestv = -1.0;
|
| 537 |
-
|
| 538 |
-
const std::pair<MathOp, double> arr[] = {
|
| 539 |
-
{MathOp::Add, sa}, {MathOp::Sub, ss}, {MathOp::Mul, sm}, {MathOp::Div, sd},
|
| 540 |
-
{MathOp::Lt, sl}, {MathOp::Le, sle}, {MathOp::Gt, sg}, {MathOp::Ge, sge},
|
| 541 |
-
{MathOp::Eq, se}
|
| 542 |
-
};
|
| 543 |
-
|
| 544 |
-
for (const auto &p : arr) {
|
| 545 |
-
if (p.second > bestv) {
|
| 546 |
-
bestv = p.second;
|
| 547 |
-
best = p.first;
|
| 548 |
-
}
|
| 549 |
-
}
|
| 550 |
-
return best;
|
| 551 |
-
}
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
static bool has_any_logic_token(const std::vector<std::string> &segment) {
|
| 555 |
-
for (const auto &raw : segment) {
|
| 556 |
-
if (is_logic_token(normalize_dictionary_key(raw))) return true;
|
| 557 |
-
}
|
| 558 |
-
return false;
|
| 559 |
-
}
|
| 560 |
-
|
| 561 |
-
static double fold_numbers_iterative(const std::vector<double> &nums, MathOp op, bool reverse_sub) {
|
| 562 |
-
if (nums.empty()) return 0.0;
|
| 563 |
-
if (nums.size() == 1) return nums[0];
|
| 564 |
-
|
| 565 |
-
if (op == MathOp::Add) {
|
| 566 |
-
double acc = 0.0;
|
| 567 |
-
for (double x : nums) acc += x;
|
| 568 |
-
return acc;
|
| 569 |
-
}
|
| 570 |
-
|
| 571 |
-
if (op == MathOp::Mul) {
|
| 572 |
-
double acc = 1.0;
|
| 573 |
-
for (double x : nums) acc *= x;
|
| 574 |
-
return acc;
|
| 575 |
-
}
|
| 576 |
-
|
| 577 |
-
if (op == MathOp::Div) {
|
| 578 |
-
double acc = nums[0];
|
| 579 |
-
for (size_t i = 1; i < nums.size(); ++i) {
|
| 580 |
-
if (nums[i] == 0.0) return std::numeric_limits<double>::quiet_NaN();
|
| 581 |
-
acc /= nums[i];
|
| 582 |
-
}
|
| 583 |
-
return acc;
|
| 584 |
-
}
|
| 585 |
-
|
| 586 |
-
if (op == MathOp::Sub) {
|
| 587 |
-
if (reverse_sub && nums.size() >= 2) {
|
| 588 |
-
double acc = nums.back();
|
| 589 |
-
for (size_t i = nums.size() - 1; i-- > 0;) {
|
| 590 |
-
acc -= nums[i];
|
| 591 |
-
}
|
| 592 |
-
return acc;
|
| 593 |
-
} else {
|
| 594 |
-
double acc = nums[0];
|
| 595 |
-
for (size_t i = 1; i < nums.size(); ++i) acc -= nums[i];
|
| 596 |
-
return acc;
|
| 597 |
-
}
|
| 598 |
-
}
|
| 599 |
-
|
| 600 |
-
// Default arithmetic fallback: sum.
|
| 601 |
-
double acc = 0.0;
|
| 602 |
-
for (double x : nums) acc += x;
|
| 603 |
-
return acc;
|
| 604 |
-
}
|
| 605 |
-
|
| 606 |
-
static std::optional<double>
|
| 607 |
-
evaluate_numeric_segment(const std::vector<std::string> &segment, int def_depth) {
|
| 608 |
-
std::vector<double> nums;
|
| 609 |
-
nums.reserve(segment.size());
|
| 610 |
-
|
| 611 |
-
for (const auto &raw : segment) {
|
| 612 |
-
const std::string surf = trim_math_surface(raw);
|
| 613 |
-
double v = 0.0;
|
| 614 |
-
if (is_number_surface(surf, v)) nums.push_back(v);
|
| 615 |
-
}
|
| 616 |
-
|
| 617 |
-
if (nums.empty()) return std::nullopt;
|
| 618 |
-
if (nums.size() == 1) return nums[0];
|
| 619 |
-
|
| 620 |
-
const MathOp op = infer_dominant_operator(segment, def_depth);
|
| 621 |
-
|
| 622 |
-
bool reverse_sub = false;
|
| 623 |
-
for (const auto &raw : segment) {
|
| 624 |
-
const std::string t = normalize_dictionary_key(raw);
|
| 625 |
-
if (is_reversal_marker(t)) {
|
| 626 |
-
reverse_sub = true;
|
| 627 |
-
break;
|
| 628 |
-
}
|
| 629 |
-
}
|
| 630 |
-
|
| 631 |
-
if (op == MathOp::Unknown) {
|
| 632 |
-
// Best-effort fallback: if there are multiple numerals and no clear cue,
|
| 633 |
-
// addition is the conservative default for many count/total prompts.
|
| 634 |
-
return fold_numbers_iterative(nums, MathOp::Add, false);
|
| 635 |
-
}
|
| 636 |
-
|
| 637 |
-
if (op == MathOp::Sub) return fold_numbers_iterative(nums, MathOp::Sub, reverse_sub);
|
| 638 |
-
if (op == MathOp::Add) return fold_numbers_iterative(nums, MathOp::Add, false);
|
| 639 |
-
if (op == MathOp::Mul) return fold_numbers_iterative(nums, MathOp::Mul, false);
|
| 640 |
-
if (op == MathOp::Div) return fold_numbers_iterative(nums, MathOp::Div, false);
|
| 641 |
-
|
| 642 |
-
// Numeric segment for comparison-like clauses: use the arithmetic value.
|
| 643 |
-
return fold_numbers_iterative(nums, MathOp::Add, false);
|
| 644 |
-
}
|
| 645 |
-
|
| 646 |
-
static std::optional<bool>
|
| 647 |
-
evaluate_comparison_segment(const std::vector<std::string> &segment, int def_depth) {
|
| 648 |
-
size_t cue_pos = std::numeric_limits<size_t>::max();
|
| 649 |
-
MathOp cmp = MathOp::Unknown;
|
| 650 |
-
double best = -1.0;
|
| 651 |
-
|
| 652 |
-
for (size_t i = 0; i < segment.size(); ++i) {
|
| 653 |
-
const std::string t = trim_math_surface(segment[i]);
|
| 654 |
-
|
| 655 |
-
MathOp cand = MathOp::Unknown;
|
| 656 |
-
if (t == "<") cand = MathOp::Lt;
|
| 657 |
-
else if (t == "<=") cand = MathOp::Le;
|
| 658 |
-
else if (t == ">") cand = MathOp::Gt;
|
| 659 |
-
else if (t == ">=") cand = MathOp::Ge;
|
| 660 |
-
else if (t == "=" || t == "==") cand = MathOp::Eq;
|
| 661 |
-
else if (normalize_dictionary_key(segment[i]) == "less") cand = MathOp::Lt;
|
| 662 |
-
else if (normalize_dictionary_key(segment[i]) == "greater") cand = MathOp::Gt;
|
| 663 |
-
else if (normalize_dictionary_key(segment[i]) == "equal" || normalize_dictionary_key(segment[i]) == "equals")
|
| 664 |
-
cand = MathOp::Eq;
|
| 665 |
-
|
| 666 |
-
if (cand != MathOp::Unknown) {
|
| 667 |
-
const double s = score_operator(segment, cand, def_depth);
|
| 668 |
-
if (s > best) {
|
| 669 |
-
best = s;
|
| 670 |
-
cmp = cand;
|
| 671 |
-
cue_pos = i;
|
| 672 |
-
}
|
| 673 |
-
}
|
| 674 |
-
}
|
| 675 |
-
|
| 676 |
-
if (cmp == MathOp::Unknown || cue_pos == std::numeric_limits<size_t>::max()) return std::nullopt;
|
| 677 |
-
|
| 678 |
-
std::vector<std::string> left(segment.begin(), segment.begin() + static_cast<std::ptrdiff_t>(cue_pos));
|
| 679 |
-
std::vector<std::string> right(segment.begin() + static_cast<std::ptrdiff_t>(cue_pos) + 1, segment.end());
|
| 680 |
-
|
| 681 |
-
auto lv = evaluate_numeric_segment(left, def_depth);
|
| 682 |
-
auto rv = evaluate_numeric_segment(right, def_depth);
|
| 683 |
-
if (!lv.has_value() || !rv.has_value()) return std::nullopt;
|
| 684 |
-
|
| 685 |
-
switch (cmp) {
|
| 686 |
-
case MathOp::Lt: return *lv < *rv;
|
| 687 |
-
case MathOp::Le: return *lv <= *rv;
|
| 688 |
-
case MathOp::Gt: return *lv > *rv;
|
| 689 |
-
case MathOp::Ge: return *lv >= *rv;
|
| 690 |
-
case MathOp::Eq: return std::fabs(*lv - *rv) < 1e-12;
|
| 691 |
-
default: return std::nullopt;
|
| 692 |
-
}
|
| 693 |
-
}
|
| 694 |
-
|
| 695 |
-
static std::optional<bool>
|
| 696 |
-
evaluate_logic_chain(const std::vector<std::vector<std::string>> &segments, int def_depth) {
|
| 697 |
-
if (segments.empty()) return std::nullopt;
|
| 698 |
-
|
| 699 |
-
std::vector<bool> vals;
|
| 700 |
-
vals.reserve(segments.size());
|
| 701 |
-
std::vector<std::string> ops;
|
| 702 |
-
ops.reserve(segments.size());
|
| 703 |
-
|
| 704 |
-
for (const auto &seg : segments) {
|
| 705 |
-
bool negate = false;
|
| 706 |
-
size_t first_sig = 0;
|
| 707 |
-
while (first_sig < seg.size() &&
|
| 708 |
-
normalize_dictionary_key(seg[first_sig]).empty()) {
|
| 709 |
-
++first_sig;
|
| 710 |
-
}
|
| 711 |
-
if (first_sig < seg.size() && normalize_dictionary_key(seg[first_sig]) == "not") {
|
| 712 |
-
negate = true;
|
| 713 |
-
}
|
| 714 |
-
|
| 715 |
-
std::optional<bool> v = evaluate_comparison_segment(seg, def_depth);
|
| 716 |
-
if (!v.has_value()) {
|
| 717 |
-
auto n = evaluate_numeric_segment(seg, def_depth);
|
| 718 |
-
if (!n.has_value()) return std::nullopt;
|
| 719 |
-
v = (*n != 0.0);
|
| 720 |
-
}
|
| 721 |
-
|
| 722 |
-
bool b = *v;
|
| 723 |
-
if (negate) b = !b;
|
| 724 |
-
vals.push_back(b);
|
| 725 |
-
}
|
| 726 |
-
|
| 727 |
-
// If the original prompt had "and/or" tokens, combine them left-to-right.
|
| 728 |
-
// Otherwise, a single boolean clause is returned.
|
| 729 |
-
if (segments.size() == 1) return vals[0];
|
| 730 |
-
|
| 731 |
-
bool acc = vals[0];
|
| 732 |
-
size_t seg_idx = 1;
|
| 733 |
-
for (size_t i = 0; i < segments.size() - 1; ++i) {
|
| 734 |
-
// A simple left-to-right boolean fold. If the user wrote mixed logic,
|
| 735 |
-
// this is conservative and iterative.
|
| 736 |
-
const std::vector<std::string> &seg = segments[i];
|
| 737 |
-
bool saw_or = false;
|
| 738 |
-
bool saw_and = false;
|
| 739 |
-
for (const auto &raw : seg) {
|
| 740 |
-
const std::string t = normalize_dictionary_key(raw);
|
| 741 |
-
if (t == "or") saw_or = true;
|
| 742 |
-
if (t == "and") saw_and = true;
|
| 743 |
-
}
|
| 744 |
-
if (saw_or && !saw_and) acc = acc || vals[seg_idx++];
|
| 745 |
-
else acc = acc && vals[seg_idx++];
|
| 746 |
-
}
|
| 747 |
-
|
| 748 |
-
return acc;
|
| 749 |
-
}
|
| 750 |
-
|
| 751 |
-
static std::vector<std::vector<std::string>>
|
| 752 |
-
split_into_math_segments(const std::vector<std::string> &prompt_toks) {
|
| 753 |
-
std::vector<std::vector<std::string>> segs;
|
| 754 |
-
std::vector<std::string> cur;
|
| 755 |
-
segs.reserve(8);
|
| 756 |
-
|
| 757 |
-
for (const auto &tok : prompt_toks) {
|
| 758 |
-
const std::string t = normalize_dictionary_key(tok);
|
| 759 |
-
|
| 760 |
-
if (is_sentence_break_token(tok)) {
|
| 761 |
-
if (!cur.empty()) {
|
| 762 |
-
segs.push_back(std::move(cur));
|
| 763 |
-
cur.clear();
|
| 764 |
-
}
|
| 765 |
-
continue;
|
| 766 |
-
}
|
| 767 |
-
|
| 768 |
-
if (t == "then" || t == "after" || t == "before") {
|
| 769 |
-
if (!cur.empty()) {
|
| 770 |
-
segs.push_back(std::move(cur));
|
| 771 |
-
cur.clear();
|
| 772 |
-
}
|
| 773 |
-
continue;
|
| 774 |
-
}
|
| 775 |
-
|
| 776 |
-
cur.push_back(tok);
|
| 777 |
-
}
|
| 778 |
-
|
| 779 |
-
if (!cur.empty()) segs.push_back(std::move(cur));
|
| 780 |
-
return segs;
|
| 781 |
-
}
|
| 782 |
-
|
| 783 |
-
static std::optional<std::string>
|
| 784 |
-
try_math_branch(const std::vector<std::string> &prompt_toks, int def_depth) {
|
| 785 |
-
if (prompt_toks.empty()) return std::nullopt;
|
| 786 |
-
|
| 787 |
-
auto segments = split_into_math_segments(prompt_toks);
|
| 788 |
-
if (segments.empty()) return std::nullopt;
|
| 789 |
-
|
| 790 |
-
// If any segment contains logic/comparison language, try boolean evaluation.
|
| 791 |
-
bool has_logic = false;
|
| 792 |
-
bool has_cmp = false;
|
| 793 |
-
for (const auto &seg : segments) {
|
| 794 |
-
if (has_any_logic_token(seg)) has_logic = true;
|
| 795 |
-
if (evaluate_comparison_segment(seg, def_depth).has_value()) has_cmp = true;
|
| 796 |
-
}
|
| 797 |
-
|
| 798 |
-
if (has_logic || has_cmp) {
|
| 799 |
-
auto b = evaluate_logic_chain(segments, def_depth);
|
| 800 |
-
if (b.has_value()) {
|
| 801 |
-
const std::string out = *b ? "True" : "False";
|
| 802 |
-
std::cout << out << ' ' << std::flush;
|
| 803 |
-
return out;
|
| 804 |
-
}
|
| 805 |
-
}
|
| 806 |
-
|
| 807 |
-
// Otherwise choose the strongest arithmetic-looking segment.
|
| 808 |
-
double best_score = -1.0;
|
| 809 |
-
const std::vector<std::string> *best_seg = nullptr;
|
| 810 |
-
|
| 811 |
-
for (const auto &seg : segments) {
|
| 812 |
-
const MathOp op = infer_dominant_operator(seg, def_depth);
|
| 813 |
-
double s = score_operator(seg, op, def_depth);
|
| 814 |
-
|
| 815 |
-
// Numeric content increases confidence.
|
| 816 |
-
size_t num_count = 0;
|
| 817 |
-
for (const auto &raw : seg) {
|
| 818 |
-
double v = 0.0;
|
| 819 |
-
if (is_number_surface(trim_math_surface(raw), v)) ++num_count;
|
| 820 |
-
}
|
| 821 |
-
s += 0.03 * static_cast<double>(num_count);
|
| 822 |
-
|
| 823 |
-
if (s > best_score) {
|
| 824 |
-
best_score = s;
|
| 825 |
-
best_seg = &seg;
|
| 826 |
-
}
|
| 827 |
-
}
|
| 828 |
-
|
| 829 |
-
if (!best_seg || best_score < 0.12) return std::nullopt;
|
| 830 |
-
|
| 831 |
-
auto comp = evaluate_comparison_segment(*best_seg, def_depth);
|
| 832 |
-
if (comp.has_value()) return *comp ? std::string("True") : std::string("False");
|
| 833 |
-
|
| 834 |
-
auto num = evaluate_numeric_segment(*best_seg, def_depth);
|
| 835 |
-
if (num.has_value()) {
|
| 836 |
-
std::ostringstream oss;
|
| 837 |
-
const double rounded = std::round(*num);
|
| 838 |
-
if (std::fabs(*num - rounded) < 1e-12) oss << static_cast<long long>(rounded);
|
| 839 |
-
else oss << *num;
|
| 840 |
-
|
| 841 |
-
const std::string out = oss.str();
|
| 842 |
-
std::cout << out << ' ' << std::flush;
|
| 843 |
-
return out;
|
| 844 |
-
}
|
| 845 |
-
|
| 846 |
-
std::ostringstream oss;
|
| 847 |
-
const double rounded = std::round(*num);
|
| 848 |
-
if (std::fabs(*num - rounded) < 1e-12) oss << static_cast<long long>(rounded);
|
| 849 |
-
else oss << *num;
|
| 850 |
-
return oss.str();
|
| 851 |
-
}
|
| 852 |
-
|
| 853 |
-
// --------------------------- String interning (short methods) --------------
|
| 854 |
-
|
| 855 |
using StrPtr = const std::string*;
|
| 856 |
using TokenId = std::uint32_t;
|
| 857 |
static constexpr TokenId TOKEN_ID_INVALID = 0xFFFFFFFFu;
|
|
@@ -925,7 +351,6 @@ static inline std::size_t bitset_intersection_count(const std::uint64_t *a, cons
|
|
| 925 |
return total;
|
| 926 |
}
|
| 927 |
|
| 928 |
-
// ---------- Global parsed dictionary (populated once in main) ----------
|
| 929 |
static void build_def_tokens_cache(){
|
| 930 |
global_def_tokens_cache.clear();
|
| 931 |
global_pos_cache.clear();
|
|
@@ -960,7 +385,6 @@ static void build_def_tokens_cache(){
|
|
| 960 |
}
|
| 961 |
}
|
| 962 |
|
| 963 |
-
// --------------------------- Knowledge base (short methods) --------------
|
| 964 |
struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash<StrPtr>()(p); } };
|
| 965 |
struct PtrEq { bool operator()(StrPtr a, StrPtr b) const noexcept { return a == b; } };
|
| 966 |
|
|
@@ -972,7 +396,6 @@ struct KnowledgeBase {
|
|
| 972 |
std::unordered_map<std::string, StrPtr> next_key_index;
|
| 973 |
mutable std::mutex m;
|
| 974 |
|
| 975 |
-
// def-index: for each interned word pointer -> list of interned tokens (definition expansion)
|
| 976 |
std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> def_index;
|
| 977 |
mutable std::mutex def_m;
|
| 978 |
int def_depth = 0;
|
|
@@ -985,7 +408,6 @@ struct KnowledgeBase {
|
|
| 985 |
vec.push_back(v);
|
| 986 |
}
|
| 987 |
|
| 988 |
-
// set def depth; if changed, drop previously computed def expansions
|
| 989 |
void set_def_depth(int D){
|
| 990 |
std::lock_guard<std::mutex> lk(def_m);
|
| 991 |
if (D != def_depth){
|
|
@@ -1045,11 +467,10 @@ struct KnowledgeBase {
|
|
| 1045 |
def_index.emplace(wp, std::move(out));
|
| 1046 |
}
|
| 1047 |
}
|
| 1048 |
-
|
| 1049 |
void add_pair(const std::string &k, const std::string &v){
|
| 1050 |
StrPtr kp = interner.intern(k);
|
| 1051 |
StrPtr vp = interner.intern(v);
|
| 1052 |
-
// ensure definition expansion for both words as soon as they are seen
|
| 1053 |
ensure_def_for_interned(kp);
|
| 1054 |
ensure_def_for_interned(vp);
|
| 1055 |
add_pair_interned(kp, vp);
|
|
@@ -1081,8 +502,6 @@ intern_tokens(KnowledgeBase &kb, const std::vector<std::string> &tokens)
|
|
| 1081 |
return out;
|
| 1082 |
}
|
| 1083 |
|
| 1084 |
-
// --------------------------- Small JSON parse helpers ----------------------
|
| 1085 |
-
|
| 1086 |
static inline bool json_valid_index(size_t i, size_t n){ return i < n; }
|
| 1087 |
|
| 1088 |
static std::string parse_quoted_string(const std::string &text, size_t &i){
|
|
@@ -1107,7 +526,6 @@ static void skip_spaces(const std::string &s, size_t &i){
|
|
| 1107 |
while (json_valid_index(i, s.size()) && is_space(s[i])) ++i;
|
| 1108 |
}
|
| 1109 |
|
| 1110 |
-
// Very small JSON-like parser tailored to dictionary_json structure
|
| 1111 |
static void skip_json_value(const std::string &s, size_t &i);
|
| 1112 |
|
| 1113 |
static std::vector<std::string> parse_json_string_array(const std::string &text, size_t &i){
|
|
@@ -1381,10 +799,6 @@ static std::vector<std::string> construct_response(KnowledgeBase &kb,
|
|
| 1381 |
std::vector<std::string> resp;
|
| 1382 |
if (prompt_toks.empty() || maxlen == 0) return resp;
|
| 1383 |
|
| 1384 |
-
if (auto math_out = try_math_branch(prompt_toks, static_cast<int>(kb.def_depth)); math_out.has_value()) {
|
| 1385 |
-
return std::vector<std::string>{*math_out};
|
| 1386 |
-
}
|
| 1387 |
-
|
| 1388 |
auto prompt_ptrs = intern_tokens(kb, prompt_toks);
|
| 1389 |
std::vector<StrPtr> resp_ptrs;
|
| 1390 |
std::unordered_map<std::string,int> recent_counts;
|
|
@@ -1467,8 +881,6 @@ static std::vector<std::string> construct_response(KnowledgeBase &kb,
|
|
| 1467 |
return resp;
|
| 1468 |
}
|
| 1469 |
|
| 1470 |
-
// --------------------------- Learning from files (short) -------------------
|
| 1471 |
-
|
| 1472 |
static void learn_from_file(KnowledgeBase &kb, const std::string &fname){
|
| 1473 |
std::ifstream ifs(fname);
|
| 1474 |
if (!ifs) return;
|
|
@@ -1486,8 +898,7 @@ static void learn_files_parallel(KnowledgeBase &kb, const std::vector<std::strin
|
|
| 1486 |
for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(files.size());++i) learn_from_file(kb, files[(size_t)i]);
|
| 1487 |
}
|
| 1488 |
|
| 1489 |
-
|
| 1490 |
-
static constexpr std::uint64_t KB_MAGIC = 0x434850434B535641ULL; // "CHPCKSVA"
|
| 1491 |
static constexpr std::uint64_t KB_VERSION = 1ULL;
|
| 1492 |
|
| 1493 |
static void write_u64(std::ostream &os, std::uint64_t v){
|
|
@@ -1679,7 +1090,6 @@ static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_
|
|
| 1679 |
}
|
| 1680 |
}
|
| 1681 |
|
| 1682 |
-
// If the caller asks for a different dict depth, recompute with the current embedded dictionary.
|
| 1683 |
if (cli_dict_depth != static_cast<int>(file_def_depth)){
|
| 1684 |
kb.set_def_depth(cli_dict_depth);
|
| 1685 |
|
|
@@ -1710,8 +1120,6 @@ static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_
|
|
| 1710 |
}
|
| 1711 |
}
|
| 1712 |
|
| 1713 |
-
// --------------------------- CLI + Interactive loop (shorters) -----------
|
| 1714 |
-
|
| 1715 |
static void print_usage(const char *p){
|
| 1716 |
std::cout << "Usage: " << p << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D] [--learn f1 f2 ...] [--repeat-penalty P] [--help]\n";
|
| 1717 |
std::cout << " --maxlen N Maximum number of tokens constructed in a response.\n";
|
|
@@ -1755,13 +1163,10 @@ int main(int argc, char **argv){
|
|
| 1755 |
|
| 1756 |
KnowledgeBase kb;
|
| 1757 |
|
| 1758 |
-
// parse the embedded dictionary once for use by per-word expansion
|
| 1759 |
global_dictionary_entries = parse_dictionary_json();
|
| 1760 |
build_def_tokens_cache();
|
| 1761 |
-
// set KB def depth (clears any previous expansion)
|
| 1762 |
kb.set_def_depth(dict_depth);
|
| 1763 |
|
| 1764 |
-
|
| 1765 |
if (!load_kb.empty()){
|
| 1766 |
try { std::cerr << "Loading KB: " << load_kb << "\n";
|
| 1767 |
load_kb_binary(kb, load_kb, dict_depth); std::cerr << "Loaded KB: " << load_kb << "\n"; }
|
|
@@ -1781,11 +1186,7 @@ int main(int argc, char **argv){
|
|
| 1781 |
for (size_t i=1;i<prompt_toks.size();++i) kb.add_pair(prompt_toks[i-1], prompt_toks[i]);
|
| 1782 |
auto resp = construct_response(kb, prompt_toks, maxlen, repeat_penalty);
|
| 1783 |
std::cout << "\n";
|
| 1784 |
-
if (!resp.empty()){
|
| 1785 |
-
std::vector<std::string> combined = prompt_toks;
|
| 1786 |
-
combined.insert(combined.end(), resp.begin(), resp.end());
|
| 1787 |
-
for (size_t i=1;i<combined.size();++i) kb.add_pair(combined[i-1], combined[i]);
|
| 1788 |
-
}
|
| 1789 |
if (!savefile.empty()){
|
| 1790 |
try { std::cerr << "Saving KB: " << savefile << "\n";
|
| 1791 |
save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
|
|
|
|
|
|
|
|
|
|
| 1 |
#include <algorithm>
|
| 2 |
#include <atomic>
|
| 3 |
#include <cctype>
|
|
|
|
| 16 |
#include <unordered_map>
|
| 17 |
#include <unordered_set>
|
| 18 |
#include <vector>
|
|
|
|
|
|
|
| 19 |
|
| 20 |
#ifdef _OPENMP
|
| 21 |
#include <omp.h>
|
|
|
|
| 24 |
inline int omp_get_thread_num(){ return 0; }
|
| 25 |
#endif
|
| 26 |
|
| 27 |
+
extern unsigned char dictionary_json[];
|
| 28 |
extern unsigned int dictionary_json_len;
|
| 29 |
|
|
|
|
| 30 |
|
| 31 |
static inline bool is_space(char c){ return std::isspace(static_cast<unsigned char>(c)) != 0; }
|
| 32 |
static inline char to_low(char c){ return static_cast<char>(std::tolower(static_cast<unsigned char>(c))); }
|
| 33 |
static inline void safe_flush(std::ostream &os){ os.flush(); }
|
| 34 |
|
|
|
|
| 35 |
struct DictionaryEntry {
|
| 36 |
std::string pos;
|
| 37 |
std::string word;
|
|
|
|
| 257 |
return bonus;
|
| 258 |
}
|
| 259 |
|
|
|
|
| 260 |
static std::vector<std::string> tokenize_whitespace(const std::string &s){
|
| 261 |
std::istringstream iss(s);
|
| 262 |
std::vector<std::string> out;
|
|
|
|
| 265 |
return out;
|
| 266 |
}
|
| 267 |
|
|
|
|
| 268 |
static std::vector<std::string> tokenize_non_alnum(const std::string &s){
|
| 269 |
std::vector<std::string> out; std::string cur;
|
| 270 |
for (char ch : s){
|
|
|
|
| 278 |
return out;
|
| 279 |
}
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
using StrPtr = const std::string*;
|
| 282 |
using TokenId = std::uint32_t;
|
| 283 |
static constexpr TokenId TOKEN_ID_INVALID = 0xFFFFFFFFu;
|
|
|
|
| 351 |
return total;
|
| 352 |
}
|
| 353 |
|
|
|
|
| 354 |
static void build_def_tokens_cache(){
|
| 355 |
global_def_tokens_cache.clear();
|
| 356 |
global_pos_cache.clear();
|
|
|
|
| 385 |
}
|
| 386 |
}
|
| 387 |
|
|
|
|
| 388 |
struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash<StrPtr>()(p); } };
|
| 389 |
struct PtrEq { bool operator()(StrPtr a, StrPtr b) const noexcept { return a == b; } };
|
| 390 |
|
|
|
|
| 396 |
std::unordered_map<std::string, StrPtr> next_key_index;
|
| 397 |
mutable std::mutex m;
|
| 398 |
|
|
|
|
| 399 |
std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> def_index;
|
| 400 |
mutable std::mutex def_m;
|
| 401 |
int def_depth = 0;
|
|
|
|
| 408 |
vec.push_back(v);
|
| 409 |
}
|
| 410 |
|
|
|
|
| 411 |
void set_def_depth(int D){
|
| 412 |
std::lock_guard<std::mutex> lk(def_m);
|
| 413 |
if (D != def_depth){
|
|
|
|
| 467 |
def_index.emplace(wp, std::move(out));
|
| 468 |
}
|
| 469 |
}
|
| 470 |
+
|
| 471 |
void add_pair(const std::string &k, const std::string &v){
|
| 472 |
StrPtr kp = interner.intern(k);
|
| 473 |
StrPtr vp = interner.intern(v);
|
|
|
|
| 474 |
ensure_def_for_interned(kp);
|
| 475 |
ensure_def_for_interned(vp);
|
| 476 |
add_pair_interned(kp, vp);
|
|
|
|
| 502 |
return out;
|
| 503 |
}
|
| 504 |
|
|
|
|
|
|
|
| 505 |
static inline bool json_valid_index(size_t i, size_t n){ return i < n; }
|
| 506 |
|
| 507 |
static std::string parse_quoted_string(const std::string &text, size_t &i){
|
|
|
|
| 526 |
while (json_valid_index(i, s.size()) && is_space(s[i])) ++i;
|
| 527 |
}
|
| 528 |
|
|
|
|
| 529 |
static void skip_json_value(const std::string &s, size_t &i);
|
| 530 |
|
| 531 |
static std::vector<std::string> parse_json_string_array(const std::string &text, size_t &i){
|
|
|
|
| 799 |
std::vector<std::string> resp;
|
| 800 |
if (prompt_toks.empty() || maxlen == 0) return resp;
|
| 801 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 802 |
auto prompt_ptrs = intern_tokens(kb, prompt_toks);
|
| 803 |
std::vector<StrPtr> resp_ptrs;
|
| 804 |
std::unordered_map<std::string,int> recent_counts;
|
|
|
|
| 881 |
return resp;
|
| 882 |
}
|
| 883 |
|
|
|
|
|
|
|
| 884 |
static void learn_from_file(KnowledgeBase &kb, const std::string &fname){
|
| 885 |
std::ifstream ifs(fname);
|
| 886 |
if (!ifs) return;
|
|
|
|
| 898 |
for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(files.size());++i) learn_from_file(kb, files[(size_t)i]);
|
| 899 |
}
|
| 900 |
|
| 901 |
+
static constexpr std::uint64_t KB_MAGIC = 0x434850434B535641ULL;
|
|
|
|
| 902 |
static constexpr std::uint64_t KB_VERSION = 1ULL;
|
| 903 |
|
| 904 |
static void write_u64(std::ostream &os, std::uint64_t v){
|
|
|
|
| 1090 |
}
|
| 1091 |
}
|
| 1092 |
|
|
|
|
| 1093 |
if (cli_dict_depth != static_cast<int>(file_def_depth)){
|
| 1094 |
kb.set_def_depth(cli_dict_depth);
|
| 1095 |
|
|
|
|
| 1120 |
}
|
| 1121 |
}
|
| 1122 |
|
|
|
|
|
|
|
| 1123 |
static void print_usage(const char *p){
|
| 1124 |
std::cout << "Usage: " << p << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D] [--learn f1 f2 ...] [--repeat-penalty P] [--help]\n";
|
| 1125 |
std::cout << " --maxlen N Maximum number of tokens constructed in a response.\n";
|
|
|
|
| 1163 |
|
| 1164 |
KnowledgeBase kb;
|
| 1165 |
|
|
|
|
| 1166 |
global_dictionary_entries = parse_dictionary_json();
|
| 1167 |
build_def_tokens_cache();
|
|
|
|
| 1168 |
kb.set_def_depth(dict_depth);
|
| 1169 |
|
|
|
|
| 1170 |
if (!load_kb.empty()){
|
| 1171 |
try { std::cerr << "Loading KB: " << load_kb << "\n";
|
| 1172 |
load_kb_binary(kb, load_kb, dict_depth); std::cerr << "Loaded KB: " << load_kb << "\n"; }
|
|
|
|
| 1186 |
for (size_t i=1;i<prompt_toks.size();++i) kb.add_pair(prompt_toks[i-1], prompt_toks[i]);
|
| 1187 |
auto resp = construct_response(kb, prompt_toks, maxlen, repeat_penalty);
|
| 1188 |
std::cout << "\n";
|
| 1189 |
+
if (!resp.empty()){for (size_t i=1;i<resp.size();++i) kb.add_pair(resp[i-1], resp[i]);}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1190 |
if (!savefile.empty()){
|
| 1191 |
try { std::cerr << "Saving KB: " << savefile << "\n";
|
| 1192 |
save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
|