Commit ·
fc85087
1
Parent(s): c84e03e
Upload ChatIPC.cpp
Browse files- ChatIPC.cpp +570 -0
ChatIPC.cpp
CHANGED
|
@@ -18,6 +18,8 @@
|
|
| 18 |
#include <unordered_map>
|
| 19 |
#include <unordered_set>
|
| 20 |
#include <vector>
|
|
|
|
|
|
|
| 21 |
|
| 22 |
#ifdef _OPENMP
|
| 23 |
#include <omp.h>
|
|
@@ -284,6 +286,570 @@ static std::vector<std::string> tokenize_non_alnum(const std::string &s){
|
|
| 284 |
return out;
|
| 285 |
}
|
| 286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
// --------------------------- String interning (short methods) --------------
|
| 288 |
|
| 289 |
using StrPtr = const std::string*;
|
|
@@ -815,6 +1381,10 @@ static std::vector<std::string> construct_response(KnowledgeBase &kb,
|
|
| 815 |
std::vector<std::string> resp;
|
| 816 |
if (prompt_toks.empty() || maxlen == 0) return resp;
|
| 817 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 818 |
auto prompt_ptrs = intern_tokens(kb, prompt_toks);
|
| 819 |
std::vector<StrPtr> resp_ptrs;
|
| 820 |
std::unordered_map<std::string,int> recent_counts;
|
|
|
|
| 18 |
#include <unordered_map>
|
| 19 |
#include <unordered_set>
|
| 20 |
#include <vector>
|
| 21 |
+
#include <cmath>
|
| 22 |
+
#include <limits>
|
| 23 |
|
| 24 |
#ifdef _OPENMP
|
| 25 |
#include <omp.h>
|
|
|
|
| 286 |
return out;
|
| 287 |
}
|
| 288 |
|
| 289 |
+
// ---------------- Math feature: iterative only ----------------
|
| 290 |
+
|
| 291 |
+
enum class MathOp {
|
| 292 |
+
Unknown,
|
| 293 |
+
Add,
|
| 294 |
+
Sub,
|
| 295 |
+
Mul,
|
| 296 |
+
Div,
|
| 297 |
+
Lt,
|
| 298 |
+
Le,
|
| 299 |
+
Gt,
|
| 300 |
+
Ge,
|
| 301 |
+
Eq
|
| 302 |
+
};
|
| 303 |
+
|
| 304 |
+
struct MathValue {
|
| 305 |
+
bool valid = false;
|
| 306 |
+
bool is_bool = false;
|
| 307 |
+
double number = 0.0;
|
| 308 |
+
bool boolean = false;
|
| 309 |
+
};
|
| 310 |
+
|
| 311 |
+
static inline std::string lower_copy_str(const std::string &s) {
|
| 312 |
+
std::string out;
|
| 313 |
+
out.reserve(s.size());
|
| 314 |
+
for (char c : s) out.push_back(to_low(c));
|
| 315 |
+
return out;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
static inline std::string trim_math_surface(const std::string &s) {
|
| 319 |
+
size_t b = 0, e = s.size();
|
| 320 |
+
while (b < e && std::isspace(static_cast<unsigned char>(s[b]))) ++b;
|
| 321 |
+
while (e > b && std::isspace(static_cast<unsigned char>(s[e - 1]))) --e;
|
| 322 |
+
|
| 323 |
+
while (b < e) {
|
| 324 |
+
unsigned char uc = static_cast<unsigned char>(s[b]);
|
| 325 |
+
char c = s[b];
|
| 326 |
+
if (!std::ispunct(uc) || c == '<' || c == '>' || c == '=' || c == '+' ||
|
| 327 |
+
c == '-' || c == '*' || c == '/' || c == '(' || c == ')')
|
| 328 |
+
break;
|
| 329 |
+
++b;
|
| 330 |
+
}
|
| 331 |
+
while (e > b) {
|
| 332 |
+
unsigned char uc = static_cast<unsigned char>(s[e - 1]);
|
| 333 |
+
char c = s[e - 1];
|
| 334 |
+
if (!std::ispunct(uc) || c == '<' || c == '>' || c == '=' || c == '+' ||
|
| 335 |
+
c == '-' || c == '*' || c == '/' || c == '(' || c == ')')
|
| 336 |
+
break;
|
| 337 |
+
--e;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
std::string out;
|
| 341 |
+
out.reserve(e - b);
|
| 342 |
+
for (size_t i = b; i < e; ++i) out.push_back(to_low(s[i]));
|
| 343 |
+
return out;
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
static inline bool is_number_surface(const std::string &s, double &value) {
|
| 347 |
+
if (s.empty()) return false;
|
| 348 |
+
bool has_digit = false;
|
| 349 |
+
bool has_dot = false;
|
| 350 |
+
|
| 351 |
+
for (char c : s) {
|
| 352 |
+
if (std::isdigit(static_cast<unsigned char>(c))) {
|
| 353 |
+
has_digit = true;
|
| 354 |
+
continue;
|
| 355 |
+
}
|
| 356 |
+
if (c == '.' && !has_dot) {
|
| 357 |
+
has_dot = true;
|
| 358 |
+
continue;
|
| 359 |
+
}
|
| 360 |
+
if ((c == '+' || c == '-') && &c == &s.front()) continue;
|
| 361 |
+
return false;
|
| 362 |
+
}
|
| 363 |
+
if (!has_digit) return false;
|
| 364 |
+
|
| 365 |
+
try {
|
| 366 |
+
size_t idx = 0;
|
| 367 |
+
value = std::stod(s, &idx);
|
| 368 |
+
return idx == s.size();
|
| 369 |
+
} catch (...) {
|
| 370 |
+
return false;
|
| 371 |
+
}
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
static inline bool is_sentence_break_token(const std::string &raw) {
|
| 375 |
+
if (raw.empty()) return false;
|
| 376 |
+
const char c = raw.back();
|
| 377 |
+
return c == '.' || c == '?' || c == '!' || c == ';' || c == ':';
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
static inline bool is_logic_token(const std::string &t) {
|
| 381 |
+
return t == "and" || t == "or" || t == "not";
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
static inline bool is_comparison_surface(const std::string &t) {
|
| 385 |
+
return t == "<" || t == "<=" || t == ">" || t == ">=" || t == "=" || t == "==";
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
static inline bool is_reversal_marker(const std::string &t) {
|
| 389 |
+
return t == "from" || t == "than" || t == "by";
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
static inline bool is_add_seed(const std::string &t) {
|
| 393 |
+
return t == "add" || t == "plus" || t == "sum";
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
static inline bool is_sub_seed(const std::string &t) {
|
| 397 |
+
return t == "subtract" || t == "minus" || t == "difference";
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
static inline bool is_mul_seed(const std::string &t) {
|
| 401 |
+
return t == "multiply" || t == "times" || t == "product";
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
static inline bool is_div_seed(const std::string &t) {
|
| 405 |
+
return t == "divide" || t == "quotient" || t == "over";
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
static inline bool is_lt_seed(const std::string &t) {
|
| 409 |
+
return t == "less" || t == "smaller" || t == "lower";
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
static inline bool is_gt_seed(const std::string &t) {
|
| 413 |
+
return t == "greater" || t == "larger" || t == "higher" || t == "more";
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
static inline bool is_le_seed(const std::string &t) {
|
| 417 |
+
return t == "atmost" || t == "at" || t == "no";
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
static inline bool is_ge_seed(const std::string &t) {
|
| 421 |
+
return t == "atleast";
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
static inline bool is_eq_seed(const std::string &t) {
|
| 425 |
+
return t == "equal" || t == "equals" || t == "same";
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
// Minimal bootstrap seeds only; dictionary closure does the expansion.
|
| 429 |
+
static const std::vector<std::string> &op_seed_lexicon(MathOp op) {
|
| 430 |
+
static const std::vector<std::string> add = {"add", "plus", "sum"};
|
| 431 |
+
static const std::vector<std::string> sub = {"subtract", "minus", "difference"};
|
| 432 |
+
static const std::vector<std::string> mul = {"multiply", "times", "product"};
|
| 433 |
+
static const std::vector<std::string> div = {"divide", "quotient", "over"};
|
| 434 |
+
static const std::vector<std::string> lt = {"less", "smaller", "lower"};
|
| 435 |
+
static const std::vector<std::string> le = {"at", "most", "no"};
|
| 436 |
+
static const std::vector<std::string> gt = {"greater", "larger", "higher", "more"};
|
| 437 |
+
static const std::vector<std::string> ge = {"at", "least", "no"};
|
| 438 |
+
static const std::vector<std::string> eq = {"equal", "equals", "same"};
|
| 439 |
+
|
| 440 |
+
switch (op) {
|
| 441 |
+
case MathOp::Add: return add;
|
| 442 |
+
case MathOp::Sub: return sub;
|
| 443 |
+
case MathOp::Mul: return mul;
|
| 444 |
+
case MathOp::Div: return div;
|
| 445 |
+
case MathOp::Lt: return lt;
|
| 446 |
+
case MathOp::Le: return le;
|
| 447 |
+
case MathOp::Gt: return gt;
|
| 448 |
+
case MathOp::Ge: return ge;
|
| 449 |
+
case MathOp::Eq: return eq;
|
| 450 |
+
default: return eq;
|
| 451 |
+
}
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
static std::unordered_set<std::string>
|
| 455 |
+
closure_set_from_tokens(const std::vector<std::string> &toks, int depth) {
|
| 456 |
+
std::unordered_set<std::string> acc;
|
| 457 |
+
std::unordered_set<std::string> frontier;
|
| 458 |
+
|
| 459 |
+
acc.reserve(toks.size() * 4 + 16);
|
| 460 |
+
frontier.reserve(toks.size() * 2 + 8);
|
| 461 |
+
|
| 462 |
+
for (const auto &tok : toks) {
|
| 463 |
+
const std::string k = normalize_dictionary_key(tok);
|
| 464 |
+
if (!k.empty() && acc.insert(k).second) frontier.insert(k);
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
for (int d = 0; d < depth && !frontier.empty(); ++d) {
|
| 468 |
+
std::unordered_set<std::string> next;
|
| 469 |
+
next.reserve(frontier.size() * 4 + 8);
|
| 470 |
+
|
| 471 |
+
for (const auto &k : frontier) {
|
| 472 |
+
auto it = global_def_tokens_cache.find(k);
|
| 473 |
+
if (it == global_def_tokens_cache.end()) continue;
|
| 474 |
+
|
| 475 |
+
for (const auto &w : it->second) {
|
| 476 |
+
const std::string kw = normalize_dictionary_key(w);
|
| 477 |
+
if (!kw.empty() && acc.insert(kw).second) next.insert(kw);
|
| 478 |
+
}
|
| 479 |
+
}
|
| 480 |
+
frontier.swap(next);
|
| 481 |
+
}
|
| 482 |
+
return acc;
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
static double jaccard_sets(const std::unordered_set<std::string> &a,
|
| 486 |
+
const std::unordered_set<std::string> &b) {
|
| 487 |
+
if (a.empty() && b.empty()) return 0.0;
|
| 488 |
+
|
| 489 |
+
const std::unordered_set<std::string> *sm = &a, *lg = &b;
|
| 490 |
+
if (a.size() > b.size()) std::swap(sm, lg);
|
| 491 |
+
|
| 492 |
+
size_t inter = 0;
|
| 493 |
+
for (const auto &x : *sm) {
|
| 494 |
+
if (lg->find(x) != lg->end()) ++inter;
|
| 495 |
+
}
|
| 496 |
+
const size_t uni = a.size() + b.size() - inter;
|
| 497 |
+
return uni ? static_cast<double>(inter) / static_cast<double>(uni) : 0.0;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
static double score_operator(const std::vector<std::string> &segment,
|
| 501 |
+
MathOp op,
|
| 502 |
+
int def_depth) {
|
| 503 |
+
const auto prompt_closure = closure_set_from_tokens(segment, def_depth);
|
| 504 |
+
const auto seed_closure = closure_set_from_tokens(op_seed_lexicon(op), def_depth);
|
| 505 |
+
|
| 506 |
+
double s = jaccard_sets(prompt_closure, seed_closure);
|
| 507 |
+
|
| 508 |
+
for (const auto &raw : segment) {
|
| 509 |
+
const std::string t = normalize_dictionary_key(raw);
|
| 510 |
+
|
| 511 |
+
if (op == MathOp::Add && is_add_seed(t)) s += 0.12;
|
| 512 |
+
if (op == MathOp::Sub && is_sub_seed(t)) s += 0.12;
|
| 513 |
+
if (op == MathOp::Mul && is_mul_seed(t)) s += 0.12;
|
| 514 |
+
if (op == MathOp::Div && is_div_seed(t)) s += 0.12;
|
| 515 |
+
if (op == MathOp::Lt && is_lt_seed(t)) s += 0.12;
|
| 516 |
+
if (op == MathOp::Le && is_le_seed(t)) s += 0.12;
|
| 517 |
+
if (op == MathOp::Gt && is_gt_seed(t)) s += 0.12;
|
| 518 |
+
if (op == MathOp::Ge && is_ge_seed(t)) s += 0.12;
|
| 519 |
+
if (op == MathOp::Eq && is_eq_seed(t)) s += 0.12;
|
| 520 |
+
}
|
| 521 |
+
return s;
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
static MathOp infer_dominant_operator(const std::vector<std::string> &segment, int def_depth) {
|
| 525 |
+
const double sa = score_operator(segment, MathOp::Add, def_depth);
|
| 526 |
+
const double ss = score_operator(segment, MathOp::Sub, def_depth);
|
| 527 |
+
const double sm = score_operator(segment, MathOp::Mul, def_depth);
|
| 528 |
+
const double sd = score_operator(segment, MathOp::Div, def_depth);
|
| 529 |
+
const double sl = score_operator(segment, MathOp::Lt, def_depth);
|
| 530 |
+
const double sle = score_operator(segment, MathOp::Le, def_depth);
|
| 531 |
+
const double sg = score_operator(segment, MathOp::Gt, def_depth);
|
| 532 |
+
const double sge = score_operator(segment, MathOp::Ge, def_depth);
|
| 533 |
+
const double se = score_operator(segment, MathOp::Eq, def_depth);
|
| 534 |
+
|
| 535 |
+
MathOp best = MathOp::Unknown;
|
| 536 |
+
double bestv = -1.0;
|
| 537 |
+
|
| 538 |
+
const std::pair<MathOp, double> arr[] = {
|
| 539 |
+
{MathOp::Add, sa}, {MathOp::Sub, ss}, {MathOp::Mul, sm}, {MathOp::Div, sd},
|
| 540 |
+
{MathOp::Lt, sl}, {MathOp::Le, sle}, {MathOp::Gt, sg}, {MathOp::Ge, sge},
|
| 541 |
+
{MathOp::Eq, se}
|
| 542 |
+
};
|
| 543 |
+
|
| 544 |
+
for (const auto &p : arr) {
|
| 545 |
+
if (p.second > bestv) {
|
| 546 |
+
bestv = p.second;
|
| 547 |
+
best = p.first;
|
| 548 |
+
}
|
| 549 |
+
}
|
| 550 |
+
return best;
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
static bool has_any_logic_token(const std::vector<std::string> &segment) {
|
| 555 |
+
for (const auto &raw : segment) {
|
| 556 |
+
if (is_logic_token(normalize_dictionary_key(raw))) return true;
|
| 557 |
+
}
|
| 558 |
+
return false;
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
static double fold_numbers_iterative(const std::vector<double> &nums, MathOp op, bool reverse_sub) {
|
| 562 |
+
if (nums.empty()) return 0.0;
|
| 563 |
+
if (nums.size() == 1) return nums[0];
|
| 564 |
+
|
| 565 |
+
if (op == MathOp::Add) {
|
| 566 |
+
double acc = 0.0;
|
| 567 |
+
for (double x : nums) acc += x;
|
| 568 |
+
return acc;
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
if (op == MathOp::Mul) {
|
| 572 |
+
double acc = 1.0;
|
| 573 |
+
for (double x : nums) acc *= x;
|
| 574 |
+
return acc;
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
if (op == MathOp::Div) {
|
| 578 |
+
double acc = nums[0];
|
| 579 |
+
for (size_t i = 1; i < nums.size(); ++i) {
|
| 580 |
+
if (nums[i] == 0.0) return std::numeric_limits<double>::quiet_NaN();
|
| 581 |
+
acc /= nums[i];
|
| 582 |
+
}
|
| 583 |
+
return acc;
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
if (op == MathOp::Sub) {
|
| 587 |
+
if (reverse_sub && nums.size() >= 2) {
|
| 588 |
+
double acc = nums.back();
|
| 589 |
+
for (size_t i = nums.size() - 1; i-- > 0;) {
|
| 590 |
+
acc -= nums[i];
|
| 591 |
+
}
|
| 592 |
+
return acc;
|
| 593 |
+
} else {
|
| 594 |
+
double acc = nums[0];
|
| 595 |
+
for (size_t i = 1; i < nums.size(); ++i) acc -= nums[i];
|
| 596 |
+
return acc;
|
| 597 |
+
}
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
// Default arithmetic fallback: sum.
|
| 601 |
+
double acc = 0.0;
|
| 602 |
+
for (double x : nums) acc += x;
|
| 603 |
+
return acc;
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
static std::optional<double>
|
| 607 |
+
evaluate_numeric_segment(const std::vector<std::string> &segment, int def_depth) {
|
| 608 |
+
std::vector<double> nums;
|
| 609 |
+
nums.reserve(segment.size());
|
| 610 |
+
|
| 611 |
+
for (const auto &raw : segment) {
|
| 612 |
+
const std::string surf = trim_math_surface(raw);
|
| 613 |
+
double v = 0.0;
|
| 614 |
+
if (is_number_surface(surf, v)) nums.push_back(v);
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
if (nums.empty()) return std::nullopt;
|
| 618 |
+
if (nums.size() == 1) return nums[0];
|
| 619 |
+
|
| 620 |
+
const MathOp op = infer_dominant_operator(segment, def_depth);
|
| 621 |
+
|
| 622 |
+
bool reverse_sub = false;
|
| 623 |
+
for (const auto &raw : segment) {
|
| 624 |
+
const std::string t = normalize_dictionary_key(raw);
|
| 625 |
+
if (is_reversal_marker(t)) {
|
| 626 |
+
reverse_sub = true;
|
| 627 |
+
break;
|
| 628 |
+
}
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
if (op == MathOp::Unknown) {
|
| 632 |
+
// Best-effort fallback: if there are multiple numerals and no clear cue,
|
| 633 |
+
// addition is the conservative default for many count/total prompts.
|
| 634 |
+
return fold_numbers_iterative(nums, MathOp::Add, false);
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
if (op == MathOp::Sub) return fold_numbers_iterative(nums, MathOp::Sub, reverse_sub);
|
| 638 |
+
if (op == MathOp::Add) return fold_numbers_iterative(nums, MathOp::Add, false);
|
| 639 |
+
if (op == MathOp::Mul) return fold_numbers_iterative(nums, MathOp::Mul, false);
|
| 640 |
+
if (op == MathOp::Div) return fold_numbers_iterative(nums, MathOp::Div, false);
|
| 641 |
+
|
| 642 |
+
// Numeric segment for comparison-like clauses: use the arithmetic value.
|
| 643 |
+
return fold_numbers_iterative(nums, MathOp::Add, false);
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
static std::optional<bool>
|
| 647 |
+
evaluate_comparison_segment(const std::vector<std::string> &segment, int def_depth) {
|
| 648 |
+
size_t cue_pos = std::numeric_limits<size_t>::max();
|
| 649 |
+
MathOp cmp = MathOp::Unknown;
|
| 650 |
+
double best = -1.0;
|
| 651 |
+
|
| 652 |
+
for (size_t i = 0; i < segment.size(); ++i) {
|
| 653 |
+
const std::string t = trim_math_surface(segment[i]);
|
| 654 |
+
|
| 655 |
+
MathOp cand = MathOp::Unknown;
|
| 656 |
+
if (t == "<") cand = MathOp::Lt;
|
| 657 |
+
else if (t == "<=") cand = MathOp::Le;
|
| 658 |
+
else if (t == ">") cand = MathOp::Gt;
|
| 659 |
+
else if (t == ">=") cand = MathOp::Ge;
|
| 660 |
+
else if (t == "=" || t == "==") cand = MathOp::Eq;
|
| 661 |
+
else if (normalize_dictionary_key(segment[i]) == "less") cand = MathOp::Lt;
|
| 662 |
+
else if (normalize_dictionary_key(segment[i]) == "greater") cand = MathOp::Gt;
|
| 663 |
+
else if (normalize_dictionary_key(segment[i]) == "equal" || normalize_dictionary_key(segment[i]) == "equals")
|
| 664 |
+
cand = MathOp::Eq;
|
| 665 |
+
|
| 666 |
+
if (cand != MathOp::Unknown) {
|
| 667 |
+
const double s = score_operator(segment, cand, def_depth);
|
| 668 |
+
if (s > best) {
|
| 669 |
+
best = s;
|
| 670 |
+
cmp = cand;
|
| 671 |
+
cue_pos = i;
|
| 672 |
+
}
|
| 673 |
+
}
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
if (cmp == MathOp::Unknown || cue_pos == std::numeric_limits<size_t>::max()) return std::nullopt;
|
| 677 |
+
|
| 678 |
+
std::vector<std::string> left(segment.begin(), segment.begin() + static_cast<std::ptrdiff_t>(cue_pos));
|
| 679 |
+
std::vector<std::string> right(segment.begin() + static_cast<std::ptrdiff_t>(cue_pos) + 1, segment.end());
|
| 680 |
+
|
| 681 |
+
auto lv = evaluate_numeric_segment(left, def_depth);
|
| 682 |
+
auto rv = evaluate_numeric_segment(right, def_depth);
|
| 683 |
+
if (!lv.has_value() || !rv.has_value()) return std::nullopt;
|
| 684 |
+
|
| 685 |
+
switch (cmp) {
|
| 686 |
+
case MathOp::Lt: return *lv < *rv;
|
| 687 |
+
case MathOp::Le: return *lv <= *rv;
|
| 688 |
+
case MathOp::Gt: return *lv > *rv;
|
| 689 |
+
case MathOp::Ge: return *lv >= *rv;
|
| 690 |
+
case MathOp::Eq: return std::fabs(*lv - *rv) < 1e-12;
|
| 691 |
+
default: return std::nullopt;
|
| 692 |
+
}
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
static std::optional<bool>
|
| 696 |
+
evaluate_logic_chain(const std::vector<std::vector<std::string>> &segments, int def_depth) {
|
| 697 |
+
if (segments.empty()) return std::nullopt;
|
| 698 |
+
|
| 699 |
+
std::vector<bool> vals;
|
| 700 |
+
vals.reserve(segments.size());
|
| 701 |
+
std::vector<std::string> ops;
|
| 702 |
+
ops.reserve(segments.size());
|
| 703 |
+
|
| 704 |
+
for (const auto &seg : segments) {
|
| 705 |
+
bool negate = false;
|
| 706 |
+
size_t first_sig = 0;
|
| 707 |
+
while (first_sig < seg.size() &&
|
| 708 |
+
normalize_dictionary_key(seg[first_sig]).empty()) {
|
| 709 |
+
++first_sig;
|
| 710 |
+
}
|
| 711 |
+
if (first_sig < seg.size() && normalize_dictionary_key(seg[first_sig]) == "not") {
|
| 712 |
+
negate = true;
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
std::optional<bool> v = evaluate_comparison_segment(seg, def_depth);
|
| 716 |
+
if (!v.has_value()) {
|
| 717 |
+
auto n = evaluate_numeric_segment(seg, def_depth);
|
| 718 |
+
if (!n.has_value()) return std::nullopt;
|
| 719 |
+
v = (*n != 0.0);
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
bool b = *v;
|
| 723 |
+
if (negate) b = !b;
|
| 724 |
+
vals.push_back(b);
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
// If the original prompt had "and/or" tokens, combine them left-to-right.
|
| 728 |
+
// Otherwise, a single boolean clause is returned.
|
| 729 |
+
if (segments.size() == 1) return vals[0];
|
| 730 |
+
|
| 731 |
+
bool acc = vals[0];
|
| 732 |
+
size_t seg_idx = 1;
|
| 733 |
+
for (size_t i = 0; i < segments.size() - 1; ++i) {
|
| 734 |
+
// A simple left-to-right boolean fold. If the user wrote mixed logic,
|
| 735 |
+
// this is conservative and iterative.
|
| 736 |
+
const std::vector<std::string> &seg = segments[i];
|
| 737 |
+
bool saw_or = false;
|
| 738 |
+
bool saw_and = false;
|
| 739 |
+
for (const auto &raw : seg) {
|
| 740 |
+
const std::string t = normalize_dictionary_key(raw);
|
| 741 |
+
if (t == "or") saw_or = true;
|
| 742 |
+
if (t == "and") saw_and = true;
|
| 743 |
+
}
|
| 744 |
+
if (saw_or && !saw_and) acc = acc || vals[seg_idx++];
|
| 745 |
+
else acc = acc && vals[seg_idx++];
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
return acc;
|
| 749 |
+
}
|
| 750 |
+
|
| 751 |
+
static std::vector<std::vector<std::string>>
|
| 752 |
+
split_into_math_segments(const std::vector<std::string> &prompt_toks) {
|
| 753 |
+
std::vector<std::vector<std::string>> segs;
|
| 754 |
+
std::vector<std::string> cur;
|
| 755 |
+
segs.reserve(8);
|
| 756 |
+
|
| 757 |
+
for (const auto &tok : prompt_toks) {
|
| 758 |
+
const std::string t = normalize_dictionary_key(tok);
|
| 759 |
+
|
| 760 |
+
if (is_sentence_break_token(tok)) {
|
| 761 |
+
if (!cur.empty()) {
|
| 762 |
+
segs.push_back(std::move(cur));
|
| 763 |
+
cur.clear();
|
| 764 |
+
}
|
| 765 |
+
continue;
|
| 766 |
+
}
|
| 767 |
+
|
| 768 |
+
if (t == "then" || t == "after" || t == "before") {
|
| 769 |
+
if (!cur.empty()) {
|
| 770 |
+
segs.push_back(std::move(cur));
|
| 771 |
+
cur.clear();
|
| 772 |
+
}
|
| 773 |
+
continue;
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
cur.push_back(tok);
|
| 777 |
+
}
|
| 778 |
+
|
| 779 |
+
if (!cur.empty()) segs.push_back(std::move(cur));
|
| 780 |
+
return segs;
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
static std::optional<std::string>
|
| 784 |
+
try_math_branch(const std::vector<std::string> &prompt_toks, int def_depth) {
|
| 785 |
+
if (prompt_toks.empty()) return std::nullopt;
|
| 786 |
+
|
| 787 |
+
auto segments = split_into_math_segments(prompt_toks);
|
| 788 |
+
if (segments.empty()) return std::nullopt;
|
| 789 |
+
|
| 790 |
+
// If any segment contains logic/comparison language, try boolean evaluation.
|
| 791 |
+
bool has_logic = false;
|
| 792 |
+
bool has_cmp = false;
|
| 793 |
+
for (const auto &seg : segments) {
|
| 794 |
+
if (has_any_logic_token(seg)) has_logic = true;
|
| 795 |
+
if (evaluate_comparison_segment(seg, def_depth).has_value()) has_cmp = true;
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
if (has_logic || has_cmp) {
|
| 799 |
+
auto b = evaluate_logic_chain(segments, def_depth);
|
| 800 |
+
if (b.has_value()) {
|
| 801 |
+
const std::string out = *b ? "True" : "False";
|
| 802 |
+
std::cout << out << ' ' << std::flush;
|
| 803 |
+
return out;
|
| 804 |
+
}
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
// Otherwise choose the strongest arithmetic-looking segment.
|
| 808 |
+
double best_score = -1.0;
|
| 809 |
+
const std::vector<std::string> *best_seg = nullptr;
|
| 810 |
+
|
| 811 |
+
for (const auto &seg : segments) {
|
| 812 |
+
const MathOp op = infer_dominant_operator(seg, def_depth);
|
| 813 |
+
double s = score_operator(seg, op, def_depth);
|
| 814 |
+
|
| 815 |
+
// Numeric content increases confidence.
|
| 816 |
+
size_t num_count = 0;
|
| 817 |
+
for (const auto &raw : seg) {
|
| 818 |
+
double v = 0.0;
|
| 819 |
+
if (is_number_surface(trim_math_surface(raw), v)) ++num_count;
|
| 820 |
+
}
|
| 821 |
+
s += 0.03 * static_cast<double>(num_count);
|
| 822 |
+
|
| 823 |
+
if (s > best_score) {
|
| 824 |
+
best_score = s;
|
| 825 |
+
best_seg = &seg;
|
| 826 |
+
}
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
if (!best_seg || best_score < 0.12) return std::nullopt;
|
| 830 |
+
|
| 831 |
+
auto comp = evaluate_comparison_segment(*best_seg, def_depth);
|
| 832 |
+
if (comp.has_value()) return *comp ? std::string("True") : std::string("False");
|
| 833 |
+
|
| 834 |
+
auto num = evaluate_numeric_segment(*best_seg, def_depth);
|
| 835 |
+
if (num.has_value()) {
|
| 836 |
+
std::ostringstream oss;
|
| 837 |
+
const double rounded = std::round(*num);
|
| 838 |
+
if (std::fabs(*num - rounded) < 1e-12) oss << static_cast<long long>(rounded);
|
| 839 |
+
else oss << *num;
|
| 840 |
+
|
| 841 |
+
const std::string out = oss.str();
|
| 842 |
+
std::cout << out << ' ' << std::flush;
|
| 843 |
+
return out;
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
std::ostringstream oss;
|
| 847 |
+
const double rounded = std::round(*num);
|
| 848 |
+
if (std::fabs(*num - rounded) < 1e-12) oss << static_cast<long long>(rounded);
|
| 849 |
+
else oss << *num;
|
| 850 |
+
return oss.str();
|
| 851 |
+
}
|
| 852 |
+
|
| 853 |
// --------------------------- String interning (short methods) --------------
|
| 854 |
|
| 855 |
using StrPtr = const std::string*;
|
|
|
|
| 1381 |
std::vector<std::string> resp;
|
| 1382 |
if (prompt_toks.empty() || maxlen == 0) return resp;
|
| 1383 |
|
| 1384 |
+
if (auto math_out = try_math_branch(prompt_toks, static_cast<int>(kb.def_depth)); math_out.has_value()) {
|
| 1385 |
+
return std::vector<std::string>{*math_out};
|
| 1386 |
+
}
|
| 1387 |
+
|
| 1388 |
auto prompt_ptrs = intern_tokens(kb, prompt_toks);
|
| 1389 |
std::vector<StrPtr> resp_ptrs;
|
| 1390 |
std::unordered_map<std::string,int> recent_counts;
|