technician1 committed on
Commit
7a6f4fc
·
verified ·
1 Parent(s): 318aeed

Update ChatIPC.cpp

Browse files
Files changed (1) hide show
  1. ChatIPC.cpp +665 -665
ChatIPC.cpp CHANGED
@@ -1,665 +1,665 @@
1
- // ChatIPC := Chat Incremental Pattern Constructor
2
-
3
- #include <algorithm>
4
- #include <atomic>
5
- #include <cctype>
6
- #include <cinttypes>
7
- #include <cstring>
8
- #include <fstream>
9
- #include <iostream>
10
- #include <iterator>
11
- #include <map>
12
- #include <mutex>
13
- #include <optional>
14
- #include <sstream>
15
- #include <stdexcept>
16
- #include <string>
17
- #include <thread>
18
- #include <unordered_map>
19
- #include <unordered_set>
20
- #include <vector>
21
-
22
// OpenMP is optional. Without it the `#pragma omp` lines in this file are
// ignored by the compiler, and these serial stand-ins keep the thread-query
// calls compiling.
#ifdef _OPENMP
#include <omp.h>
#else
inline int omp_get_max_threads()
{
    return 1;  // a serial build has exactly one worker
}
inline int omp_get_thread_num()
{
    return 0;  // ...and that worker is thread 0
}
#endif
28
-
29
- extern unsigned char dictionary_json[]; // provide dictionary.cpp to embed dictionary JSON bytes
30
- extern unsigned int dictionary_json_len;
31
-
32
- // --------------------------- Short utility functions ----------------------
33
-
34
// True when `c` is whitespace according to std::isspace. The value is routed
// through unsigned char because <cctype> functions are undefined for negative
// arguments other than EOF.
static inline bool is_space(char c)
{
    const unsigned char uc = static_cast<unsigned char>(c);
    return std::isspace(uc) != 0;
}

// Lower-cases one character with std::tolower (same unsigned-char precaution).
static inline char to_low(char c)
{
    const unsigned char uc = static_cast<unsigned char>(c);
    return static_cast<char>(std::tolower(uc));
}

// Flushes whatever `os` has buffered.
static inline void safe_flush(std::ostream &os)
{
    os.flush();
}
37
-
38
// Splits `s` into whitespace-separated tokens. Runs of whitespace never
// produce empty tokens; an empty or all-blank string yields an empty vector.
static std::vector<std::string> tokenize_whitespace(const std::string &s)
{
    std::istringstream stream(s);
    return std::vector<std::string>(std::istream_iterator<std::string>(stream),
                                    std::istream_iterator<std::string>());
}
46
-
47
// Splits `s` into lower-cased word tokens (used for dictionary definitions).
// A token is a maximal run of alphanumeric characters, '-' or '\''; every
// other character acts as a separator and is dropped.
static std::vector<std::string> tokenize_non_alnum(const std::string &s)
{
    std::vector<std::string> tokens;
    std::string word;

    // Emits the word collected so far, if any, and starts a fresh one.
    const auto flush_word = [&tokens, &word]() {
        if (word.empty()) return;
        tokens.push_back(word);
        word.clear();
    };

    for (const char ch : s) {
        const unsigned char uc = static_cast<unsigned char>(ch);
        const bool part_of_word = std::isalnum(uc) != 0 || ch == '-' || ch == '\'';
        if (part_of_word) {
            word.push_back(static_cast<char>(std::tolower(uc)));
        } else {
            flush_word();
        }
    }
    flush_word();
    return tokens;
}
60
-
61
- // --------------------------- String interning (short methods) --------------
62
-
63
// Owns one canonical copy of every distinct string handed to intern().
// std::unordered_set never relocates its elements, so a returned pointer stays
// valid until that element is erased or the interner is destroyed.
struct StringInterner {
    std::unordered_set<std::string> pool;  // canonical string storage
    std::mutex m;                          // held by intern() while it touches `pool`

    // Returns the address of the pooled string equal to `s`, adding a copy of
    // `s` to the pool first when no equal string is stored yet.
    const std::string* intern(const std::string &s){
        std::lock_guard<std::mutex> guard(m);
        // insert() leaves the set untouched when an equal element exists and
        // yields an iterator to the stored element either way.
        return &*pool.insert(s).first;
    }
};
74
-
75
- // ---------- Global parsed dictionary (populated once in main) ----------
76
- static std::unordered_map<std::string,std::string> g_raw_dict;
77
-
78
- static std::unordered_set<std::string> def_tokens_from_text(const std::string &s){
79
- auto toks = tokenize_non_alnum(s);
80
- return std::unordered_set<std::string>(toks.begin(), toks.end());
81
- }
82
-
83
- // --------------------------- Knowledge base (short methods) --------------
84
// Words are passed around as pointers to std::string objects.
using StrPtr = const std::string*;

// Hash and equality functors that look through the pointer, so a container
// keyed by StrPtr behaves as if it were keyed by the string contents.
// Neither functor accepts nullptr.
struct PtrHash {
    size_t operator()(StrPtr p) const noexcept {
        return std::hash<std::string>{}(*p);
    }
};
struct PtrEq {
    bool operator()(StrPtr a, StrPtr b) const noexcept {
        return a->compare(*b) == 0;
    }
};

// List of successor words recorded for one word.
using NextSet = std::vector<StrPtr>;
89
-
90
- struct KnowledgeBase {
91
- StringInterner interner;
92
- std::unordered_map<StrPtr, NextSet, PtrHash, PtrEq> next;
93
- std::mutex m;
94
-
95
- // def-index: for each interned word pointer -> list of interned tokens (definition expansion)
96
- std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> def_index;
97
- std::mutex def_m;
98
- int def_depth = 0;
99
-
100
- void add_pair_interned(StrPtr k, StrPtr v){
101
- std::lock_guard<std::mutex> lk(m);
102
- auto &vec = next[k];
103
- for (auto p : vec) if (*p == *v) return;
104
- vec.push_back(v);
105
- }
106
-
107
- // set def depth; if changed, drop previously computed def expansions
108
- void set_def_depth(int D){
109
- std::lock_guard<std::mutex> lk(def_m);
110
- if (D != def_depth){
111
- def_index.clear();
112
- def_depth = D;
113
- }
114
- }
115
-
116
- // compute definition expansion for a single interned word (if needed)
117
- void ensure_def_for_interned(StrPtr wp){
118
- // quick no-op checks
119
- if (wp == nullptr) return;
120
- if (def_depth <= 0) return;
121
-
122
- // double-checked locking
123
- {
124
- std::lock_guard<std::mutex> lk(def_m);
125
- if (def_index.find(wp) != def_index.end()) return;
126
- }
127
-
128
- // compute expansion using global parsed dictionary g_raw_dict
129
- std::unordered_set<std::string> acc;
130
- std::vector<std::string> frontier;
131
- auto it_raw = g_raw_dict.find(*wp);
132
- if (it_raw != g_raw_dict.end()){
133
- auto toks = def_tokens_from_text(it_raw->second);
134
- for (auto &t : toks){
135
- if (acc.insert(t).second) frontier.push_back(t);
136
- }
137
- }
138
-
139
- for (int depth = 1; depth < def_depth && !frontier.empty(); ++depth){
140
- std::vector<std::string> nextf;
141
- for (auto &w : frontier){
142
- auto it2 = g_raw_dict.find(w);
143
- if (it2 == g_raw_dict.end()) continue;
144
- auto toks2 = def_tokens_from_text(it2->second);
145
- for (auto &t : toks2){
146
- if (acc.insert(t).second) nextf.push_back(t);
147
- }
148
- }
149
- frontier.swap(nextf);
150
- }
151
-
152
- // intern all accumulated tokens and store pointers
153
- std::vector<StrPtr> out;
154
- out.reserve(acc.size());
155
- for (auto &s : acc){
156
- out.push_back(interner.intern(s));
157
- }
158
-
159
- // store atomically (prevent double insertion)
160
- {
161
- std::lock_guard<std::mutex> lk(def_m);
162
- // another thread may have inserted meanwhile; do not overwrite
163
- if (def_index.find(wp) == def_index.end()){
164
- def_index.emplace(wp, std::move(out));
165
- }
166
- }
167
- }
168
-
169
- // existing public add_pair but now ensure def-expansion is built immediately
170
- void add_pair(const std::string &k, const std::string &v){
171
- StrPtr kp = interner.intern(k);
172
- StrPtr vp = interner.intern(v);
173
- // ensure definition expansion for both words as soon as they are seen
174
- ensure_def_for_interned(kp);
175
- ensure_def_for_interned(vp);
176
- add_pair_interned(kp, vp);
177
- }
178
-
179
- std::optional<NextSet> lookup_by_string(const std::string &k) const {
180
- for (auto &pr : next) if (*pr.first == k) return pr.second;
181
- return std::nullopt;
182
- }
183
- std::optional<NextSet> lookup_by_ptr(StrPtr k) const {
184
- auto it = next.find(k);
185
- if (it==next.end()) return std::nullopt;
186
- return it->second;
187
- }
188
- };
189
-
190
- // thread-safe snapshot of kb.def_index as string-based def-index
191
- static std::unordered_map<std::string,std::unordered_set<std::string>>
192
- snapshot_def_index(KnowledgeBase &kb){
193
- std::unordered_map<std::string,std::unordered_set<std::string>> out;
194
- std::lock_guard<std::mutex> lk(kb.def_m);
195
- out.reserve(kb.def_index.size());
196
- for (auto &pr : kb.def_index){
197
- std::unordered_set<std::string> s;
198
- s.reserve(pr.second.size());
199
- for (auto p : pr.second) s.insert(*p);
200
- out.emplace(*pr.first, std::move(s));
201
- }
202
- return out;
203
- }
204
-
205
- // --------------------------- Small JSON parse helpers ----------------------
206
-
207
// Bounds check used by the JSON helpers: true while index `i` addresses an
// element of a sequence of length `n`.
static inline bool json_valid_index(size_t i, size_t n)
{
    return n > i;
}
208
-
209
// Parses a double-quoted JSON string starting at text[i] and returns its
// decoded contents.
//
// On entry `i` must index the opening quote; on return it indexes the
// character after the closing quote (or text.size() when the string is
// unterminated, in which case the characters read so far are returned).
//
// Escapes: \n \t \r \b \f map to their control characters, \uXXXX (including
// surrogate pairs) is decoded to UTF-8, and any other escaped character is
// taken literally, which covers \" \\ and \/. A \u that is not followed by
// four hex digits is kept as a literal 'u'; an unpaired surrogate becomes
// U+FFFD.
//
// Throws std::runtime_error when text[i] is not a double quote.
static std::string parse_quoted_string(const std::string &text, size_t &i){
    const size_t n = text.size();
    if (i >= n || text[i] != '"') throw std::runtime_error("expected '\"'");
    ++i;

    // Reads exactly four hex digits starting at `pos` into `value`.
    const auto read_hex4 = [&text, n](size_t pos, unsigned &value) -> bool {
        if (pos > n || n - pos < 4) return false;
        value = 0;
        for (size_t k = 0; k < 4; ++k){
            const unsigned char h = static_cast<unsigned char>(text[pos + k]);
            if (!std::isxdigit(h)) return false;
            const unsigned digit = std::isdigit(h)
                ? static_cast<unsigned>(h - '0')
                : static_cast<unsigned>(std::tolower(h) - 'a') + 10u;
            value = value * 16u + digit;
        }
        return true;
    };

    // Appends the UTF-8 encoding of code point `cp` (at most U+10FFFF).
    const auto append_utf8 = [](std::string &dst, unsigned cp){
        if (cp < 0x80u){
            dst.push_back(static_cast<char>(cp));
        } else if (cp < 0x800u){
            dst.push_back(static_cast<char>(0xC0u | (cp >> 6)));
            dst.push_back(static_cast<char>(0x80u | (cp & 0x3Fu)));
        } else if (cp < 0x10000u){
            dst.push_back(static_cast<char>(0xE0u | (cp >> 12)));
            dst.push_back(static_cast<char>(0x80u | ((cp >> 6) & 0x3Fu)));
            dst.push_back(static_cast<char>(0x80u | (cp & 0x3Fu)));
        } else {
            dst.push_back(static_cast<char>(0xF0u | (cp >> 18)));
            dst.push_back(static_cast<char>(0x80u | ((cp >> 12) & 0x3Fu)));
            dst.push_back(static_cast<char>(0x80u | ((cp >> 6) & 0x3Fu)));
            dst.push_back(static_cast<char>(0x80u | (cp & 0x3Fu)));
        }
    };

    std::string out;
    while (i < n){
        const char c = text[i++];
        if (c == '"') break;
        if (c != '\\'){ out.push_back(c); continue; }
        if (i >= n) break;                       // dangling backslash at end of input
        const char e = text[i++];
        switch (e){
            case 'n': out.push_back('\n'); break;
            case 't': out.push_back('\t'); break;
            case 'r': out.push_back('\r'); break;
            case 'b': out.push_back('\b'); break;
            case 'f': out.push_back('\f'); break;
            case 'u': {
                unsigned cp = 0;
                if (!read_hex4(i, cp)){ out.push_back(e); break; }
                i += 4;
                if (cp >= 0xD800u && cp <= 0xDBFFu){
                    // High surrogate: only meaningful with a following \uDC00-\uDFFF.
                    unsigned low = 0;
                    if (n - i >= 6 && text[i] == '\\' && text[i + 1] == 'u'
                        && read_hex4(i + 2, low) && low >= 0xDC00u && low <= 0xDFFFu){
                        cp = 0x10000u + ((cp - 0xD800u) << 10) + (low - 0xDC00u);
                        i += 6;
                    } else {
                        cp = 0xFFFDu;
                    }
                } else if (cp >= 0xDC00u && cp <= 0xDFFFu){
                    cp = 0xFFFDu;                // low surrogate without a high one
                }
                append_utf8(out, cp);
                break;
            }
            default: out.push_back(e); break;
        }
    }
    return out;
}
226
-
227
// Advances `i` to the first non-whitespace character of `s` at or after its
// current position, or to s.size() when only whitespace remains. An index that
// is already past the end is left unchanged.
static void skip_spaces(const std::string &s, size_t &i)
{
    const size_t len = s.size();
    for (; i < len; ++i) {
        if (std::isspace(static_cast<unsigned char>(s[i])) == 0) return;
    }
}
230
-
231
- // Very small JSON-like parser tailored to dictionary_json structure
232
- static std::unordered_map<std::string,std::string> parse_dictionary_json(){
233
- std::unordered_map<std::string,std::string> dict;
234
- if (dictionary_json_len == 0) return dict;
235
- std::string text; text.reserve(dictionary_json_len + 1);
236
- for (unsigned int b=0; b < dictionary_json_len; ++b) text.push_back(static_cast<char>(dictionary_json[b]));
237
- size_t i = 0;
238
- skip_spaces(text,i);
239
- if (!json_valid_index(i,text.size()) || text[i] != '{') return dict;
240
- ++i;
241
- while (true){
242
- skip_spaces(text,i);
243
- if (!json_valid_index(i,text.size())) break;
244
- if (text[i] == '}'){ ++i; break; }
245
- std::string key = parse_quoted_string(text,i);
246
- skip_spaces(text,i);
247
- if (!json_valid_index(i,text.size()) || text[i] != ':') break;
248
- ++i;
249
- skip_spaces(text,i);
250
- std::string val;
251
- if (json_valid_index(i,text.size()) && text[i] == '"') val = parse_quoted_string(text,i);
252
- else {
253
- size_t start = i;
254
- while (json_valid_index(i,text.size()) && text[i] != ',' && text[i] != '}') ++i;
255
- val = text.substr(start, i-start);
256
- }
257
- dict.emplace(std::move(key), std::move(val));
258
- skip_spaces(text,i);
259
- if (json_valid_index(i,text.size()) && text[i] == ','){ ++i; continue; }
260
- if (json_valid_index(i,text.size()) && text[i] == '}'){ ++i; break; }
261
- }
262
- return dict;
263
- }
264
-
265
- // --------------------------- Similarity helpers (very small) ----------------
266
-
267
// Jaccard index of two string sets: size of the intersection divided by the
// size of the union. Two empty sets count as identical and give 1.0.
static double jaccard_similarity(const std::unordered_set<std::string> &A,
                                 const std::unordered_set<std::string> &B)
{
    if (A.empty() && B.empty()) return 1.0;

    // Walk the smaller set and probe the larger one.
    const bool a_is_smaller = A.size() < B.size();
    const std::unordered_set<std::string> &walked = a_is_smaller ? A : B;
    const std::unordered_set<std::string> &probed = a_is_smaller ? B : A;

    const size_t shared = static_cast<size_t>(
        std::count_if(walked.begin(), walked.end(),
                      [&probed](const std::string &word){ return probed.count(word) != 0; }));

    const size_t total = A.size() + B.size() - shared;
    return total == 0 ? 0.0
                      : static_cast<double>(shared) / static_cast<double>(total);
}
281
-
282
// Union of the given tokens and, for every token that has an entry in
// `def_index`, the words of that entry.
static std::unordered_set<std::string>
aggregate_sets(const std::vector<std::string> &tokens,
               const std::unordered_map<std::string,std::unordered_set<std::string>> &def_index)
{
    std::unordered_set<std::string> merged;
    for (const std::string &tok : tokens) {
        merged.insert(tok);
        const auto found = def_index.find(tok);
        if (found == def_index.end()) continue;
        for (const std::string &def_word : found->second) {
            merged.insert(def_word);
        }
    }
    return merged;
}
296
-
297
- // --------------------------- Candidate selection (short funcs) ---------------
298
-
299
- static std::string best_candidate_by_similarity(const NextSet &cands,
300
- const std::vector<std::string> &prompt_toks,
301
- const std::vector<std::string> &resp_toks,
302
- const std::unordered_map<std::string,std::unordered_set<std::string>> &def_index,
303
- const std::unordered_map<std::string,int> &recent_counts,
304
- double repeat_penalty)
305
- {
306
- if (cands.empty()) return std::string();
307
- if (cands.size() == 1) return *cands[0];
308
-
309
- auto agg = aggregate_sets(prompt_toks, def_index);
310
- for (auto &r : resp_toks){
311
- auto it = def_index.find(r);
312
- if (it != def_index.end()) for (auto &d : it->second) agg.insert(d);
313
- }
314
-
315
- double best = -1e9;
316
- std::string best_tok;
317
- size_t M = cands.size();
318
- std::vector<double> scores(M, 0.0);
319
-
320
- #pragma omp parallel for schedule(static)
321
- for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(M);++i){
322
- std::unordered_set<std::string> candset;
323
- candset.insert(*cands[(size_t)i]);
324
- auto it = def_index.find(*cands[(size_t)i]);
325
- if (it != def_index.end()) for (auto &d : it->second) candset.insert(d);
326
- double s = jaccard_similarity(agg, candset);
327
- scores[(size_t)i] = s;
328
- }
329
-
330
- for (size_t i=0;i<M;++i){
331
- const std::string &tok = *cands[i];
332
- double s = scores[i];
333
- auto rc_it = recent_counts.find(tok);
334
- int cnt = (rc_it==recent_counts.end()? 0 : rc_it->second);
335
- double adjusted = s - repeat_penalty * static_cast<double>(cnt);
336
- if (adjusted > best || (adjusted == best && tok < best_tok)){
337
- best = adjusted;
338
- best_tok = tok;
339
- }
340
- }
341
- return best_tok;
342
- }
343
-
344
- // --------------------------- Response constructor (short units) ---------------
345
-
346
- static std::vector<std::string> construct_response(KnowledgeBase &kb,
347
- const std::vector<std::string> &prompt_toks,
348
- size_t maxlen,
349
- const std::unordered_map<std::string,std::unordered_set<std::string>> &def_index,
350
- double repeat_penalty)
351
- {
352
- std::vector<std::string> resp;
353
- if (prompt_toks.empty() || maxlen == 0) return resp;
354
- std::unordered_map<std::string,int> recent_counts;
355
-
356
- auto would_create_2_cycle = [&](const std::string &cand) -> bool {
357
- if (resp.size() < 2) return false;
358
- // check alternation: X Y X Y ... then candidate == X and last == Y
359
- const std::string &last = resp.back();
360
- const std::string &prev = resp[resp.size()-2];
361
- return (cand == prev && last == resp[resp.size()-3 < resp.size() ? resp.size()-3 : 0]);
362
- // this is a cheap conservative check; main guard is repeat_penalty + single-candidate rule
363
- };
364
-
365
- std::string last_printed;
366
- for (size_t step=0; step<maxlen; ++step){
367
- NextSet candidates;
368
- bool found = false;
369
- if (step==0){
370
- for (ssize_t p = static_cast<ssize_t>(prompt_toks.size())-1; p>=0; --p){
371
- auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]);
372
- if (opt){ candidates = *opt; found = true; break; }
373
- }
374
- } else {
375
- auto opt = kb.lookup_by_string(last_printed);
376
- if (opt){ candidates = *opt; found = true; }
377
- else {
378
- for (ssize_t p = static_cast<ssize_t>(prompt_toks.size())-1; p>=0; --p){
379
- auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]);
380
- if (opt2){ candidates = *opt2; found = true; break; }
381
- }
382
- }
383
- }
384
- if (!found || candidates.empty()) break;
385
-
386
- // If only one candidate and it already appeared, stop to avoid 1-cycle.
387
- if (candidates.size()==1){
388
- std::string only = *candidates[0];
389
- if (recent_counts[only] > 0) break;
390
- resp.push_back(only);
391
- recent_counts[only] += 1;
392
- last_printed = only;
393
- std::cout << only << ' ' << std::flush; // print immediately
394
- continue;
395
- }
396
-
397
- // choose best with repeat penalty
398
- std::string chosen = best_candidate_by_similarity(candidates, prompt_toks, resp, def_index, recent_counts, repeat_penalty);
399
- if (chosen.empty()) break;
400
-
401
- // cheap 2-cycle avoider: if this would continue a trivial alternation, stop
402
- if (would_create_2_cycle(chosen)) break;
403
-
404
- resp.push_back(chosen);
405
- recent_counts[chosen] += 1;
406
- last_printed = chosen;
407
- std::cout << chosen << ' ' << std::flush; // print immediately
408
- }
409
- return resp;
410
- }
411
-
412
- // --------------------------- Learning from files (short) -------------------
413
-
414
- static void learn_from_file(KnowledgeBase &kb, const std::string &fname){
415
- std::ifstream ifs(fname);
416
- if (!ifs) return;
417
- std::string tok;
418
- std::string prev;
419
- bool have_prev = false;
420
- while (ifs >> tok){
421
- if (have_prev) kb.add_pair(prev, tok);
422
- prev = tok; have_prev = true;
423
- }
424
- }
425
-
426
- static void learn_files_parallel(KnowledgeBase &kb, const std::vector<std::string> &files){
427
- #pragma omp parallel for schedule(dynamic)
428
- for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(files.size());++i) learn_from_file(kb, files[(size_t)i]);
429
- }
430
-
431
- // --------------------------- Serialization (short functions) ----------------
432
-
433
- static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){
434
- std::ofstream ofs(fname, std::ios::binary);
435
- if (!ofs) throw std::runtime_error("cannot open save file");
436
-
437
- // interned strings snapshot (must include all tokens used by def_index)
438
- std::vector<const std::string*> interned;
439
- interned.reserve(kb.interner.pool.size());
440
- for (auto &s : kb.interner.pool) interned.push_back(&s);
441
-
442
- uint64_t N = interned.size();
443
- ofs.write(reinterpret_cast<const char*>(&N), sizeof(N));
444
- for (auto p : interned){
445
- uint64_t L = p->size();
446
- ofs.write(reinterpret_cast<const char*>(&L), sizeof(L));
447
- ofs.write(p->data(), static_cast<std::streamsize>(L));
448
- }
449
-
450
- // edges
451
- uint64_t E = kb.next.size();
452
- ofs.write(reinterpret_cast<const char*>(&E), sizeof(E));
453
- for (auto &pr : kb.next){
454
- const std::string &key = *pr.first;
455
- auto it = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == key; });
456
- if (it == interned.end()) throw std::runtime_error("save index error");
457
- uint64_t key_idx = static_cast<uint64_t>(std::distance(interned.begin(), it));
458
- ofs.write(reinterpret_cast<const char*>(&key_idx), sizeof(key_idx));
459
- uint64_t M = pr.second.size();
460
- ofs.write(reinterpret_cast<const char*>(&M), sizeof(M));
461
- for (auto nxt : pr.second){
462
- auto it2 = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == *nxt; });
463
- if (it2 == interned.end()) throw std::runtime_error("save index error2");
464
- uint64_t v_idx = static_cast<uint64_t>(std::distance(interned.begin(), it2));
465
- ofs.write(reinterpret_cast<const char*>(&v_idx), sizeof(v_idx));
466
- }
467
- }
468
-
469
- // --- write definition expansion section ---
470
- uint64_t D = static_cast<uint64_t>(kb.def_depth);
471
- ofs.write(reinterpret_cast<const char*>(&D), sizeof(D));
472
-
473
- // def entries: number of keys with a stored expansion
474
- uint64_t K = kb.def_index.size();
475
- ofs.write(reinterpret_cast<const char*>(&K), sizeof(K));
476
- for (auto &pr : kb.def_index){
477
- // key index
478
- const std::string &key = *pr.first;
479
- auto it = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == key; });
480
- if (it == interned.end()) throw std::runtime_error("save def index error");
481
- uint64_t key_idx = static_cast<uint64_t>(std::distance(interned.begin(), it));
482
- ofs.write(reinterpret_cast<const char*>(&key_idx), sizeof(key_idx));
483
-
484
- // number of tokens
485
- uint64_t M = pr.second.size();
486
- ofs.write(reinterpret_cast<const char*>(&M), sizeof(M));
487
- for (auto tokp : pr.second){
488
- auto it2 = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == *tokp; });
489
- if (it2 == interned.end()) throw std::runtime_error("save def token index error");
490
- uint64_t v_idx = static_cast<uint64_t>(std::distance(interned.begin(), it2));
491
- ofs.write(reinterpret_cast<const char*>(&v_idx), sizeof(v_idx));
492
- }
493
- }
494
-
495
- safe_flush(ofs);
496
- }
497
-
498
- static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_dict_depth){
499
- std::ifstream ifs(fname, std::ios::binary);
500
- if (!ifs) throw std::runtime_error("cannot open load file");
501
-
502
- uint64_t N;
503
- ifs.read(reinterpret_cast<char*>(&N), sizeof(N));
504
- std::vector<std::string> strings; strings.reserve((size_t)N);
505
- for (uint64_t i=0;i<N;++i){
506
- uint64_t L; ifs.read(reinterpret_cast<char*>(&L), sizeof(L));
507
- std::string s; s.resize((size_t)L);
508
- ifs.read(&s[0], static_cast<std::streamsize>(L));
509
- strings.push_back(std::move(s));
510
- }
511
- std::vector<StrPtr> ptrs; ptrs.reserve(strings.size());
512
- for (auto &s : strings) ptrs.push_back(kb.interner.intern(s));
513
-
514
- uint64_t E; ifs.read(reinterpret_cast<char*>(&E), sizeof(E));
515
- for (uint64_t i=0;i<E;++i){
516
- uint64_t key_idx; ifs.read(reinterpret_cast<char*>(&key_idx), sizeof(key_idx));
517
- uint64_t M; ifs.read(reinterpret_cast<char*>(&M), sizeof(M));
518
- StrPtr key_ptr = ptrs.at((size_t)key_idx);
519
- NextSet vec; vec.reserve((size_t)M);
520
- for (uint64_t j=0;j<M;++j){
521
- uint64_t v_idx; ifs.read(reinterpret_cast<char*>(&v_idx), sizeof(v_idx));
522
- vec.push_back(ptrs.at((size_t)v_idx));
523
- }
524
- kb.next.emplace(key_ptr, std::move(vec));
525
- }
526
-
527
- // read def-expansion section (new-format)
528
- uint64_t file_def_depth;
529
- ifs.read(reinterpret_cast<char*>(&file_def_depth), sizeof(file_def_depth));
530
- uint64_t K; ifs.read(reinterpret_cast<char*>(&K), sizeof(K));
531
- // populate kb.def_index from file
532
- {
533
- std::lock_guard<std::mutex> lk(kb.def_m);
534
- kb.def_index.clear();
535
- kb.def_depth = static_cast<int>(file_def_depth);
536
- }
537
- for (uint64_t i=0;i<K;++i){
538
- uint64_t key_idx; ifs.read(reinterpret_cast<char*>(&key_idx), sizeof(key_idx));
539
- uint64_t M; ifs.read(reinterpret_cast<char*>(&M), sizeof(M));
540
- std::vector<StrPtr> tokens; tokens.reserve((size_t)M);
541
- for (uint64_t j=0;j<M;++j){
542
- uint64_t v_idx; ifs.read(reinterpret_cast<char*>(&v_idx), sizeof(v_idx));
543
- tokens.push_back(ptrs.at((size_t)v_idx));
544
- }
545
- kb.def_index.emplace(ptrs.at((size_t)key_idx), std::move(tokens));
546
- }
547
-
548
- // If CLI requested a different dict depth, clear and recompute expansion for loaded words only
549
- if (cli_dict_depth != kb.def_depth){
550
- kb.set_def_depth(cli_dict_depth);
551
- // --- build deduplicated union of "words present" = saved strings (ptrs) ∪ KB words (keys and neighbors)
552
- std::vector<StrPtr> targets;
553
- targets.reserve(ptrs.size() + kb.next.size()*2);
554
-
555
- {
556
- std::unordered_set<StrPtr, PtrHash, PtrEq> seen;
557
- // include all strings from the saved file
558
- for (auto p : ptrs) {
559
- if (seen.insert(p).second) targets.push_back(p);
560
- }
561
- // include all words present in KB edges (keys and their neighbors)
562
- for (auto &pr : kb.next) {
563
- if (seen.insert(pr.first).second) targets.push_back(pr.first);
564
- for (auto v : pr.second) {
565
- if (seen.insert(v).second) targets.push_back(v);
566
- }
567
- }
568
- }
569
-
570
- // --- recompute definition expansion for each target in parallel
571
- #pragma omp parallel for schedule(dynamic)
572
- for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(targets.size()); ++i) {
573
- kb.ensure_def_for_interned(targets[(size_t)i]);
574
- }
575
- }
576
- }
577
-
578
- // --------------------------- CLI + Interactive loop (shorters) -----------
579
-
580
// Prints the command-line synopsis and one line per option to std::cout.
// `p` is the program name shown in the synopsis.
static void print_usage(const char *p)
{
    static const char *const option_help[] = {
        " --maxlen N Maximum number of tokens constructed in a response.\n",
        " --save FILE Save the knowledge-base and dictionary expansions to a binary file.\n",
        " --load-kb FILE Load a previously saved knowledge-base (and dictionary expansions) from a binary file.\n",
        " --dict-depth D Depth of dictionary-definition expansion used during learning.\n",
        " --learn f1 f2 ... Learn from one or more text files to update the knowledge base.\n",
        " --repeat-penalty P Penalize repeated tokens during response generation (higher values discourage repetition).\n",
        " --help Show command-line interface options for ChatIPC usage.\n",
    };

    std::cout << "Usage: " << p
              << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D] [--learn f1 f2 ...] [--repeat-penalty P] [--help]\n";
    for (const char *line : option_help) {
        std::cout << line;
    }
}
590
-
591
- int main(int argc, char **argv){
592
- size_t maxlen = 100;
593
- std::string savefile;
594
- std::string load_txt;
595
- std::string load_kb;
596
- int dict_depth = 2;
597
- double repeat_penalty = 0.7; // default λ
598
- std::vector<std::string> learn_files;
599
-
600
- for (int i=1;i<argc;++i){
601
- std::string a = argv[i];
602
- if (a=="--help"){ print_usage(argv[0]); return 0; }
603
- if (a=="--maxlen" && i+1<argc){ maxlen = std::stoul(argv[++i]); continue; }
604
- if (a=="--save" && i+1<argc){ savefile = argv[++i]; continue; }
605
- if (a=="--load-kb" && i+1<argc){ load_kb = argv[++i]; continue; }
606
- if (a=="--dict-depth" && i+1<argc){ dict_depth = std::stoi(argv[++i]); continue; }
607
- if (a=="--repeat-penalty" && i+1<argc){ repeat_penalty = std::stod(argv[++i]); continue; }
608
- if (a=="--learn"){
609
- ++i;
610
- for (; i<argc; ++i){
611
- if (!argv[i]) break;
612
- std::string s = argv[i];
613
- if (!s.empty() && s[0]=='-'){ --i; break; }
614
- learn_files.push_back(s);
615
- }
616
- continue;
617
- }
618
- learn_files.push_back(a);
619
- }
620
-
621
- KnowledgeBase kb;
622
-
623
- // parse the embedded dictionary once for use by per-word expansion
624
- g_raw_dict = parse_dictionary_json();
625
- // set KB def depth (clears any previous expansion)
626
- kb.set_def_depth(dict_depth);
627
-
628
-
629
- if (!load_kb.empty()){
630
- try { load_kb_binary(kb, load_kb, dict_depth); std::cerr << "Loaded KB: " << load_kb << "\n"; }
631
- catch (const std::exception &e){ std::cerr << "Load KB error: " << e.what() << "\n"; }
632
- }
633
-
634
- if (!learn_files.empty()){
635
- std::cerr << "Learning from file/s (" << learn_files.size() << ") using threads=" << omp_get_max_threads() << "\n";
636
- learn_files_parallel(kb, learn_files);
637
- }
638
-
639
- std::string line;
640
- std::cout << "Ready. Enter prompts.\n";
641
- while (std::cout << "> " , std::getline(std::cin, line)){
642
- if (line.empty()){ std::cout << "\n"; continue; }
643
- auto prompt_toks = tokenize_whitespace(line);
644
- for (size_t i=1;i<prompt_toks.size();++i) kb.add_pair(prompt_toks[i-1], prompt_toks[i]);
645
- auto def_index = snapshot_def_index(kb);
646
- auto resp = construct_response(kb, prompt_toks, maxlen, def_index, repeat_penalty);
647
- std::cout << "\n";
648
- if (!resp.empty()){
649
- std::vector<std::string> combined = prompt_toks;
650
- combined.insert(combined.end(), resp.begin(), resp.end());
651
- for (size_t i=1;i<combined.size();++i) kb.add_pair(combined[i-1], combined[i]);
652
- }
653
- if (!savefile.empty()){
654
- try { save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
655
- catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
656
- }
657
- }
658
-
659
- if (!savefile.empty()){
660
- try { save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
661
- catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
662
- }
663
-
664
- return 0;
665
- }
 
1
+ // ChatIPC := Chat Incremental Pattern Constructor
2
+
3
+ #include <algorithm>
4
+ #include <atomic>
5
+ #include <cctype>
6
+ #include <cinttypes>
7
+ #include <cstring>
8
+ #include <fstream>
9
+ #include <iostream>
10
+ #include <iterator>
11
+ #include <map>
12
+ #include <mutex>
13
+ #include <optional>
14
+ #include <sstream>
15
+ #include <stdexcept>
16
+ #include <string>
17
+ #include <thread>
18
+ #include <unordered_map>
19
+ #include <unordered_set>
20
+ #include <vector>
21
+
22
// OpenMP is optional. Without it the `#pragma omp` lines in this file are
// ignored by the compiler, and these serial stand-ins keep the thread-query
// calls compiling.
#ifdef _OPENMP
#include <omp.h>
#else
inline int omp_get_max_threads()
{
    return 1;  // a serial build has exactly one worker
}
inline int omp_get_thread_num()
{
    return 0;  // ...and that worker is thread 0
}
#endif
28
+
29
+ extern unsigned char dictionary_json[]; // provide dictionary.cpp to embed dictionary JSON bytes
30
+ extern unsigned int dictionary_json_len;
31
+
32
+ // --------------------------- Short utility functions ----------------------
33
+
34
// True when `c` is whitespace according to std::isspace. The value is routed
// through unsigned char because <cctype> functions are undefined for negative
// arguments other than EOF.
static inline bool is_space(char c)
{
    const unsigned char uc = static_cast<unsigned char>(c);
    return std::isspace(uc) != 0;
}

// Lower-cases one character with std::tolower (same unsigned-char precaution).
static inline char to_low(char c)
{
    const unsigned char uc = static_cast<unsigned char>(c);
    return static_cast<char>(std::tolower(uc));
}

// Flushes whatever `os` has buffered.
static inline void safe_flush(std::ostream &os)
{
    os.flush();
}
37
+
38
// Splits `s` into whitespace-separated tokens. Runs of whitespace never
// produce empty tokens; an empty or all-blank string yields an empty vector.
static std::vector<std::string> tokenize_whitespace(const std::string &s)
{
    std::istringstream stream(s);
    return std::vector<std::string>(std::istream_iterator<std::string>(stream),
                                    std::istream_iterator<std::string>());
}
46
+
47
// Splits `s` into lower-cased word tokens (used for dictionary definitions).
// A token is a maximal run of alphanumeric characters, '-' or '\''; every
// other character acts as a separator and is dropped.
static std::vector<std::string> tokenize_non_alnum(const std::string &s)
{
    std::vector<std::string> tokens;
    std::string word;

    // Emits the word collected so far, if any, and starts a fresh one.
    const auto flush_word = [&tokens, &word]() {
        if (word.empty()) return;
        tokens.push_back(word);
        word.clear();
    };

    for (const char ch : s) {
        const unsigned char uc = static_cast<unsigned char>(ch);
        const bool part_of_word = std::isalnum(uc) != 0 || ch == '-' || ch == '\'';
        if (part_of_word) {
            word.push_back(static_cast<char>(std::tolower(uc)));
        } else {
            flush_word();
        }
    }
    flush_word();
    return tokens;
}
60
+
61
+ // --------------------------- String interning (short methods) --------------
62
+
63
// Owns one canonical copy of every distinct string handed to intern().
// std::unordered_set never relocates its elements, so a returned pointer stays
// valid until that element is erased or the interner is destroyed.
struct StringInterner {
    std::unordered_set<std::string> pool;  // canonical string storage
    std::mutex m;                          // held by intern() while it touches `pool`

    // Returns the address of the pooled string equal to `s`, adding a copy of
    // `s` to the pool first when no equal string is stored yet.
    const std::string* intern(const std::string &s){
        std::lock_guard<std::mutex> guard(m);
        // insert() leaves the set untouched when an equal element exists and
        // yields an iterator to the stored element either way.
        return &*pool.insert(s).first;
    }
};
74
+
75
+ // ---------- Global parsed dictionary (populated once in main) ----------
76
+ static std::unordered_map<std::string,std::string> g_raw_dict;
77
+
78
+ static std::unordered_set<std::string> def_tokens_from_text(const std::string &s){
79
+ auto toks = tokenize_non_alnum(s);
80
+ return std::unordered_set<std::string>(toks.begin(), toks.end());
81
+ }
82
+
83
+ // --------------------------- Knowledge base (short methods) --------------
84
// Words are passed around as pointers to std::string objects.
using StrPtr = const std::string*;

// Hash and equality functors that look through the pointer, so a container
// keyed by StrPtr behaves as if it were keyed by the string contents.
// Neither functor accepts nullptr.
struct PtrHash {
    size_t operator()(StrPtr p) const noexcept {
        return std::hash<std::string>{}(*p);
    }
};
struct PtrEq {
    bool operator()(StrPtr a, StrPtr b) const noexcept {
        return a->compare(*b) == 0;
    }
};

// List of successor words recorded for one word.
using NextSet = std::vector<StrPtr>;
89
+
90
+ struct KnowledgeBase {
91
+ StringInterner interner;
92
+ std::unordered_map<StrPtr, NextSet, PtrHash, PtrEq> next;
93
+ std::mutex m;
94
+
95
+ // def-index: for each interned word pointer -> list of interned tokens (definition expansion)
96
+ std::unordered_map<StrPtr, std::vector<StrPtr>, PtrHash, PtrEq> def_index;
97
+ std::mutex def_m;
98
+ int def_depth = 0;
99
+
100
+ void add_pair_interned(StrPtr k, StrPtr v){
101
+ std::lock_guard<std::mutex> lk(m);
102
+ auto &vec = next[k];
103
+ for (auto p : vec) if (*p == *v) return;
104
+ vec.push_back(v);
105
+ }
106
+
107
+ // set def depth; if changed, drop previously computed def expansions
108
+ void set_def_depth(int D){
109
+ std::lock_guard<std::mutex> lk(def_m);
110
+ if (D != def_depth){
111
+ def_index.clear();
112
+ def_depth = D;
113
+ }
114
+ }
115
+
116
+ // compute definition expansion for a single interned word (if needed)
117
+ void ensure_def_for_interned(StrPtr wp){
118
+ // quick no-op checks
119
+ if (wp == nullptr) return;
120
+ if (def_depth <= 0) return;
121
+
122
+ // double-checked locking
123
+ {
124
+ std::lock_guard<std::mutex> lk(def_m);
125
+ if (def_index.find(wp) != def_index.end()) return;
126
+ }
127
+
128
+ // compute expansion using global parsed dictionary g_raw_dict
129
+ std::unordered_set<std::string> acc;
130
+ std::vector<std::string> frontier;
131
+ auto it_raw = g_raw_dict.find(*wp);
132
+ if (it_raw != g_raw_dict.end()){
133
+ auto toks = def_tokens_from_text(it_raw->second);
134
+ for (auto &t : toks){
135
+ if (acc.insert(t).second) frontier.push_back(t);
136
+ }
137
+ }
138
+
139
+ for (int depth = 1; depth < def_depth && !frontier.empty(); ++depth){
140
+ std::vector<std::string> nextf;
141
+ for (auto &w : frontier){
142
+ auto it2 = g_raw_dict.find(w);
143
+ if (it2 == g_raw_dict.end()) continue;
144
+ auto toks2 = def_tokens_from_text(it2->second);
145
+ for (auto &t : toks2){
146
+ if (acc.insert(t).second) nextf.push_back(t);
147
+ }
148
+ }
149
+ frontier.swap(nextf);
150
+ }
151
+
152
+ // intern all accumulated tokens and store pointers
153
+ std::vector<StrPtr> out;
154
+ out.reserve(acc.size());
155
+ for (auto &s : acc){
156
+ out.push_back(interner.intern(s));
157
+ }
158
+
159
+ // store atomically (prevent double insertion)
160
+ {
161
+ std::lock_guard<std::mutex> lk(def_m);
162
+ // another thread may have inserted meanwhile; do not overwrite
163
+ if (def_index.find(wp) == def_index.end()){
164
+ def_index.emplace(wp, std::move(out));
165
+ }
166
+ }
167
+ }
168
+
169
+ // existing public add_pair but now ensure def-expansion is built immediately
170
+ void add_pair(const std::string &k, const std::string &v){
171
+ StrPtr kp = interner.intern(k);
172
+ StrPtr vp = interner.intern(v);
173
+ // ensure definition expansion for both words as soon as they are seen
174
+ ensure_def_for_interned(kp);
175
+ ensure_def_for_interned(vp);
176
+ add_pair_interned(kp, vp);
177
+ }
178
+
179
+ std::optional<NextSet> lookup_by_string(const std::string &k) const {
180
+ for (auto &pr : next) if (*pr.first == k) return pr.second;
181
+ return std::nullopt;
182
+ }
183
+ std::optional<NextSet> lookup_by_ptr(StrPtr k) const {
184
+ auto it = next.find(k);
185
+ if (it==next.end()) return std::nullopt;
186
+ return it->second;
187
+ }
188
+ };
189
+
190
+ // thread-safe snapshot of kb.def_index as string-based def-index
191
+ static std::unordered_map<std::string,std::unordered_set<std::string>>
192
+ snapshot_def_index(KnowledgeBase &kb){
193
+ std::unordered_map<std::string,std::unordered_set<std::string>> out;
194
+ std::lock_guard<std::mutex> lk(kb.def_m);
195
+ out.reserve(kb.def_index.size());
196
+ for (auto &pr : kb.def_index){
197
+ std::unordered_set<std::string> s;
198
+ s.reserve(pr.second.size());
199
+ for (auto p : pr.second) s.insert(*p);
200
+ out.emplace(*pr.first, std::move(s));
201
+ }
202
+ return out;
203
+ }
204
+
205
+ // --------------------------- Small JSON parse helpers ----------------------
206
+
207
// True when position i addresses an element of a buffer holding n elements.
static inline bool json_valid_index(size_t i, size_t n){
    return n > i;
}
208
+
209
// Parse a double-quoted JSON string starting at text[i].
//
// On entry text[i] must be the opening quote; otherwise std::runtime_error is
// thrown. On return i points just past the closing quote, or at the end of
// the text when the string is unterminated (the characters read so far are
// returned in that case).
//
// Escapes: \n \t \r \b \f map to their control characters; \uXXXX (including
// UTF-16 surrogate pairs) is decoded and appended as UTF-8; any other escaped
// character is appended as-is (this covers \" \\ and \/). A \u that is not
// followed by four hex digits is kept as a literal 'u'.
static std::string parse_quoted_string(const std::string &text, size_t &i){
    const size_t n = text.size();
    if (i >= n || text[i] != '"') throw std::runtime_error("expected '\"'");
    ++i;

    auto hex_value = [](char c) -> int {
        if (c >= '0' && c <= '9') return c - '0';
        if (c >= 'a' && c <= 'f') return c - 'a' + 10;
        if (c >= 'A' && c <= 'F') return c - 'A' + 10;
        return -1;
    };
    // Reads four hex digits starting at pos; false when they are not there.
    auto read_hex4 = [&](size_t pos, unsigned &value) -> bool {
        if (pos > n || n - pos < 4) return false;
        unsigned v = 0;
        for (size_t k = 0; k < 4; ++k){
            const int h = hex_value(text[pos + k]);
            if (h < 0) return false;
            v = (v << 4) | static_cast<unsigned>(h);
        }
        value = v;
        return true;
    };
    auto append_utf8 = [](std::string &dst, unsigned cp){
        if (cp < 0x80){
            dst.push_back(static_cast<char>(cp));
        } else if (cp < 0x800){
            dst.push_back(static_cast<char>(0xC0 | (cp >> 6)));
            dst.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
        } else if (cp < 0x10000){
            dst.push_back(static_cast<char>(0xE0 | (cp >> 12)));
            dst.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
            dst.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
        } else {
            dst.push_back(static_cast<char>(0xF0 | (cp >> 18)));
            dst.push_back(static_cast<char>(0x80 | ((cp >> 12) & 0x3F)));
            dst.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
            dst.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
        }
    };

    std::string out;
    while (i < n){
        const char c = text[i++];
        if (c == '"') break;
        if (c != '\\'){ out.push_back(c); continue; }
        if (i >= n) break;                       // dangling backslash at end of text
        const char e = text[i++];
        switch (e){
            case 'n': out.push_back('\n'); break;
            case 't': out.push_back('\t'); break;
            case 'r': out.push_back('\r'); break;
            case 'b': out.push_back('\b'); break;
            case 'f': out.push_back('\f'); break;
            case 'u': {
                unsigned cp = 0;
                if (!read_hex4(i, cp)){ out.push_back('u'); break; }
                i += 4;
                if (cp >= 0xD800 && cp <= 0xDBFF){
                    // High surrogate: needs a following \uDC00-\uDFFF low surrogate.
                    unsigned lo = 0;
                    if (i + 1 < n && text[i] == '\\' && text[i + 1] == 'u' &&
                        read_hex4(i + 2, lo) && lo >= 0xDC00 && lo <= 0xDFFF){
                        cp = 0x10000 + ((cp - 0xD800) << 10) + (lo - 0xDC00);
                        i += 6;
                    } else {
                        cp = 0xFFFD;             // unpaired surrogate -> replacement character
                    }
                } else if (cp >= 0xDC00 && cp <= 0xDFFF){
                    cp = 0xFFFD;                 // lone low surrogate
                }
                append_utf8(out, cp);
                break;
            }
            default: out.push_back(e); break;
        }
    }
    return out;
}
226
+
227
// Advance i past consecutive whitespace characters of s, stopping at the
// first non-space character or at the end of the string.
static void skip_spaces(const std::string &s, size_t &i){
    const size_t limit = s.size();
    for (; i < limit; ++i){
        if (std::isspace(static_cast<unsigned char>(s[i])) == 0) return;
    }
}
230
+
231
+ // Very small JSON-like parser tailored to dictionary_json structure
232
+ static std::unordered_map<std::string,std::string> parse_dictionary_json(){
233
+ std::unordered_map<std::string,std::string> dict;
234
+ if (dictionary_json_len == 0) return dict;
235
+ std::string text; text.reserve(dictionary_json_len + 1);
236
+ for (unsigned int b=0; b < dictionary_json_len; ++b) text.push_back(static_cast<char>(dictionary_json[b]));
237
+ size_t i = 0;
238
+ skip_spaces(text,i);
239
+ if (!json_valid_index(i,text.size()) || text[i] != '{') return dict;
240
+ ++i;
241
+ while (true){
242
+ skip_spaces(text,i);
243
+ if (!json_valid_index(i,text.size())) break;
244
+ if (text[i] == '}'){ ++i; break; }
245
+ std::string key = parse_quoted_string(text,i);
246
+ skip_spaces(text,i);
247
+ if (!json_valid_index(i,text.size()) || text[i] != ':') break;
248
+ ++i;
249
+ skip_spaces(text,i);
250
+ std::string val;
251
+ if (json_valid_index(i,text.size()) && text[i] == '"') val = parse_quoted_string(text,i);
252
+ else {
253
+ size_t start = i;
254
+ while (json_valid_index(i,text.size()) && text[i] != ',' && text[i] != '}') ++i;
255
+ val = text.substr(start, i-start);
256
+ }
257
+ dict.emplace(std::move(key), std::move(val));
258
+ skip_spaces(text,i);
259
+ if (json_valid_index(i,text.size()) && text[i] == ','){ ++i; continue; }
260
+ if (json_valid_index(i,text.size()) && text[i] == '}'){ ++i; break; }
261
+ }
262
+ return dict;
263
+ }
264
+
265
+ // --------------------------- Similarity helpers (very small) ----------------
266
+
267
// Jaccard similarity |A ∩ B| / |A ∪ B| of two string sets.
// Two empty sets are defined to be identical and yield 1.0.
static double jaccard_similarity(const std::unordered_set<std::string> &A,
                                 const std::unordered_set<std::string> &B)
{
    if (A.empty() && B.empty()) return 1.0;

    // Walk the smaller set and probe the larger one.
    const bool a_smaller = A.size() < B.size();
    const auto &probe_from = a_smaller ? A : B;
    const auto &probe_into = a_smaller ? B : A;

    size_t inter = 0;
    for (const auto &x : probe_from){
        if (probe_into.count(x)) ++inter;
    }

    // At least one set is non-empty here, so the union size is never zero.
    const size_t uni = A.size() + B.size() - inter;
    return static_cast<double>(inter) / static_cast<double>(uni);
}
281
+
282
// Union of the given tokens and, for each token that has an entry in
// def_index, all tokens of that entry.
static std::unordered_set<std::string>
aggregate_sets(const std::vector<std::string> &tokens,
               const std::unordered_map<std::string,std::unordered_set<std::string>> &def_index)
{
    std::unordered_set<std::string> merged;
    for (const std::string &tok : tokens){
        merged.insert(tok);
        const auto hit = def_index.find(tok);
        if (hit == def_index.end()) continue;
        merged.insert(hit->second.begin(), hit->second.end());
    }
    return merged;
}
296
+
297
+ // --------------------------- Candidate selection (short funcs) ---------------
298
+
299
+ static std::string best_candidate_by_similarity(const NextSet &cands,
300
+ const std::vector<std::string> &prompt_toks,
301
+ const std::vector<std::string> &resp_toks,
302
+ const std::unordered_map<std::string,std::unordered_set<std::string>> &def_index,
303
+ const std::unordered_map<std::string,int> &recent_counts,
304
+ double repeat_penalty)
305
+ {
306
+ if (cands.empty()) return std::string();
307
+ if (cands.size() == 1) return *cands[0];
308
+
309
+ auto agg = aggregate_sets(prompt_toks, def_index);
310
+ for (auto &r : resp_toks){
311
+ auto it = def_index.find(r);
312
+ if (it != def_index.end()) for (auto &d : it->second) agg.insert(d);
313
+ }
314
+
315
+ double best = -1e9;
316
+ std::string best_tok;
317
+ size_t M = cands.size();
318
+ std::vector<double> scores(M, 0.0);
319
+
320
+ #pragma omp parallel for schedule(static)
321
+ for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(M);++i){
322
+ std::unordered_set<std::string> candset;
323
+ candset.insert(*cands[(size_t)i]);
324
+ auto it = def_index.find(*cands[(size_t)i]);
325
+ if (it != def_index.end()) for (auto &d : it->second) candset.insert(d);
326
+ double s = jaccard_similarity(agg, candset);
327
+ scores[(size_t)i] = s;
328
+ }
329
+
330
+ for (size_t i=0;i<M;++i){
331
+ const std::string &tok = *cands[i];
332
+ double s = scores[i];
333
+ auto rc_it = recent_counts.find(tok);
334
+ int cnt = (rc_it==recent_counts.end()? 0 : rc_it->second);
335
+ double adjusted = s - repeat_penalty * static_cast<double>(cnt);
336
+ if (adjusted > best || (adjusted == best && tok < best_tok)){
337
+ best = adjusted;
338
+ best_tok = tok;
339
+ }
340
+ }
341
+ return best_tok;
342
+ }
343
+
344
+ // --------------------------- Response constructor (short units) ---------------
345
+
346
+ static std::vector<std::string> construct_response(KnowledgeBase &kb,
347
+ const std::vector<std::string> &prompt_toks,
348
+ size_t maxlen,
349
+ const std::unordered_map<std::string,std::unordered_set<std::string>> &def_index,
350
+ double repeat_penalty)
351
+ {
352
+ std::vector<std::string> resp;
353
+ if (prompt_toks.empty() || maxlen == 0) return resp;
354
+ std::unordered_map<std::string,int> recent_counts;
355
+
356
+ auto would_create_2_cycle = [&](const std::string &cand) -> bool {
357
+ if (resp.size() < 2) return false;
358
+ // check alternation: X Y X Y ... then candidate == X and last == Y
359
+ const std::string &last = resp.back();
360
+ const std::string &prev = resp[resp.size()-2];
361
+ return (cand == prev && last == resp[resp.size()-3 < resp.size() ? resp.size()-3 : 0]);
362
+ // this is a cheap conservative check; main guard is repeat_penalty + single-candidate rule
363
+ };
364
+
365
+ std::string last_printed;
366
+ for (size_t step=0; step<maxlen; ++step){
367
+ NextSet candidates;
368
+ bool found = false;
369
+ if (step==0){
370
+ for (ssize_t p = static_cast<ssize_t>(prompt_toks.size())-1; p>=0; --p){
371
+ auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]);
372
+ if (opt){ candidates = *opt; found = true; break; }
373
+ }
374
+ } else {
375
+ auto opt = kb.lookup_by_string(last_printed);
376
+ if (opt){ candidates = *opt; found = true; }
377
+ else {
378
+ for (ssize_t p = static_cast<ssize_t>(prompt_toks.size())-1; p>=0; --p){
379
+ auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]);
380
+ if (opt2){ candidates = *opt2; found = true; break; }
381
+ }
382
+ }
383
+ }
384
+ if (!found || candidates.empty()) break;
385
+
386
+ // If only one candidate and it already appeared, stop to avoid 1-cycle.
387
+ if (candidates.size()==1){
388
+ std::string only = *candidates[0];
389
+ if (recent_counts[only] > 0) break;
390
+ resp.push_back(only);
391
+ recent_counts[only] += 1;
392
+ last_printed = only;
393
+ std::cout << only << ' ' << std::flush; // print immediately
394
+ continue;
395
+ }
396
+
397
+ // choose best with repeat penalty
398
+ std::string chosen = best_candidate_by_similarity(candidates, prompt_toks, resp, def_index, recent_counts, repeat_penalty);
399
+ if (chosen.empty()) break;
400
+
401
+ // cheap 2-cycle avoider: if this would continue a trivial alternation, stop
402
+ if (would_create_2_cycle(chosen)) break;
403
+
404
+ resp.push_back(chosen);
405
+ recent_counts[chosen] += 1;
406
+ last_printed = chosen;
407
+ std::cout << chosen << ' ' << std::flush; // print immediately
408
+ }
409
+ return resp;
410
+ }
411
+
412
+ // --------------------------- Learning from files (short) -------------------
413
+
414
+ static void learn_from_file(KnowledgeBase &kb, const std::string &fname){
415
+ std::ifstream ifs(fname);
416
+ if (!ifs) return;
417
+ std::string tok;
418
+ std::string prev;
419
+ bool have_prev = false;
420
+ while (ifs >> tok){
421
+ if (have_prev) kb.add_pair(prev, tok);
422
+ prev = tok; have_prev = true;
423
+ }
424
+ }
425
+
426
+ static void learn_files_parallel(KnowledgeBase &kb, const std::vector<std::string> &files){
427
+ #pragma omp parallel for schedule(dynamic)
428
+ for (ptrdiff_t i=0;i<static_cast<ptrdiff_t>(files.size());++i) learn_from_file(kb, files[(size_t)i]);
429
+ }
430
+
431
+ // --------------------------- Serialization (short functions) ----------------
432
+
433
+ static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){
434
+ std::ofstream ofs(fname, std::ios::binary);
435
+ if (!ofs) throw std::runtime_error("cannot open save file");
436
+
437
+ // interned strings snapshot (must include all tokens used by def_index)
438
+ std::vector<const std::string*> interned;
439
+ interned.reserve(kb.interner.pool.size());
440
+ for (auto &s : kb.interner.pool) interned.push_back(&s);
441
+
442
+ uint64_t N = interned.size();
443
+ ofs.write(reinterpret_cast<const char*>(&N), sizeof(N));
444
+ for (auto p : interned){
445
+ uint64_t L = p->size();
446
+ ofs.write(reinterpret_cast<const char*>(&L), sizeof(L));
447
+ ofs.write(p->data(), static_cast<std::streamsize>(L));
448
+ }
449
+
450
+ // edges
451
+ uint64_t E = kb.next.size();
452
+ ofs.write(reinterpret_cast<const char*>(&E), sizeof(E));
453
+ for (auto &pr : kb.next){
454
+ const std::string &key = *pr.first;
455
+ auto it = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == key; });
456
+ if (it == interned.end()) throw std::runtime_error("save index error");
457
+ uint64_t key_idx = static_cast<uint64_t>(std::distance(interned.begin(), it));
458
+ ofs.write(reinterpret_cast<const char*>(&key_idx), sizeof(key_idx));
459
+ uint64_t M = pr.second.size();
460
+ ofs.write(reinterpret_cast<const char*>(&M), sizeof(M));
461
+ for (auto nxt : pr.second){
462
+ auto it2 = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == *nxt; });
463
+ if (it2 == interned.end()) throw std::runtime_error("save index error2");
464
+ uint64_t v_idx = static_cast<uint64_t>(std::distance(interned.begin(), it2));
465
+ ofs.write(reinterpret_cast<const char*>(&v_idx), sizeof(v_idx));
466
+ }
467
+ }
468
+
469
+ // --- write definition expansion section ---
470
+ uint64_t D = static_cast<uint64_t>(kb.def_depth);
471
+ ofs.write(reinterpret_cast<const char*>(&D), sizeof(D));
472
+
473
+ // def entries: number of keys with a stored expansion
474
+ uint64_t K = kb.def_index.size();
475
+ ofs.write(reinterpret_cast<const char*>(&K), sizeof(K));
476
+ for (auto &pr : kb.def_index){
477
+ // key index
478
+ const std::string &key = *pr.first;
479
+ auto it = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == key; });
480
+ if (it == interned.end()) throw std::runtime_error("save def index error");
481
+ uint64_t key_idx = static_cast<uint64_t>(std::distance(interned.begin(), it));
482
+ ofs.write(reinterpret_cast<const char*>(&key_idx), sizeof(key_idx));
483
+
484
+ // number of tokens
485
+ uint64_t M = pr.second.size();
486
+ ofs.write(reinterpret_cast<const char*>(&M), sizeof(M));
487
+ for (auto tokp : pr.second){
488
+ auto it2 = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == *tokp; });
489
+ if (it2 == interned.end()) throw std::runtime_error("save def token index error");
490
+ uint64_t v_idx = static_cast<uint64_t>(std::distance(interned.begin(), it2));
491
+ ofs.write(reinterpret_cast<const char*>(&v_idx), sizeof(v_idx));
492
+ }
493
+ }
494
+
495
+ safe_flush(ofs);
496
+ }
497
+
498
+ static void load_kb_binary(KnowledgeBase &kb, const std::string &fname, int cli_dict_depth){
499
+ std::ifstream ifs(fname, std::ios::binary);
500
+ if (!ifs) throw std::runtime_error("cannot open load file");
501
+
502
+ uint64_t N;
503
+ ifs.read(reinterpret_cast<char*>(&N), sizeof(N));
504
+ std::vector<std::string> strings; strings.reserve((size_t)N);
505
+ for (uint64_t i=0;i<N;++i){
506
+ uint64_t L; ifs.read(reinterpret_cast<char*>(&L), sizeof(L));
507
+ std::string s; s.resize((size_t)L);
508
+ ifs.read(&s[0], static_cast<std::streamsize>(L));
509
+ strings.push_back(std::move(s));
510
+ }
511
+ std::vector<StrPtr> ptrs; ptrs.reserve(strings.size());
512
+ for (auto &s : strings) ptrs.push_back(kb.interner.intern(s));
513
+
514
+ uint64_t E; ifs.read(reinterpret_cast<char*>(&E), sizeof(E));
515
+ for (uint64_t i=0;i<E;++i){
516
+ uint64_t key_idx; ifs.read(reinterpret_cast<char*>(&key_idx), sizeof(key_idx));
517
+ uint64_t M; ifs.read(reinterpret_cast<char*>(&M), sizeof(M));
518
+ StrPtr key_ptr = ptrs.at((size_t)key_idx);
519
+ NextSet vec; vec.reserve((size_t)M);
520
+ for (uint64_t j=0;j<M;++j){
521
+ uint64_t v_idx; ifs.read(reinterpret_cast<char*>(&v_idx), sizeof(v_idx));
522
+ vec.push_back(ptrs.at((size_t)v_idx));
523
+ }
524
+ kb.next.emplace(key_ptr, std::move(vec));
525
+ }
526
+
527
+ // read def-expansion section (new-format)
528
+ uint64_t file_def_depth;
529
+ ifs.read(reinterpret_cast<char*>(&file_def_depth), sizeof(file_def_depth));
530
+ uint64_t K; ifs.read(reinterpret_cast<char*>(&K), sizeof(K));
531
+ // populate kb.def_index from file
532
+ {
533
+ std::lock_guard<std::mutex> lk(kb.def_m);
534
+ kb.def_index.clear();
535
+ kb.def_depth = static_cast<int>(file_def_depth);
536
+ }
537
+ for (uint64_t i=0;i<K;++i){
538
+ uint64_t key_idx; ifs.read(reinterpret_cast<char*>(&key_idx), sizeof(key_idx));
539
+ uint64_t M; ifs.read(reinterpret_cast<char*>(&M), sizeof(M));
540
+ std::vector<StrPtr> tokens; tokens.reserve((size_t)M);
541
+ for (uint64_t j=0;j<M;++j){
542
+ uint64_t v_idx; ifs.read(reinterpret_cast<char*>(&v_idx), sizeof(v_idx));
543
+ tokens.push_back(ptrs.at((size_t)v_idx));
544
+ }
545
+ kb.def_index.emplace(ptrs.at((size_t)key_idx), std::move(tokens));
546
+ }
547
+
548
+ // If CLI requested a different dict depth, clear and recompute expansion for loaded words only
549
+ if (cli_dict_depth != kb.def_depth){
550
+ kb.set_def_depth(cli_dict_depth);
551
+ // --- build deduplicated union of "words present" = saved strings (ptrs) ∪ KB words (keys and neighbors)
552
+ std::vector<StrPtr> targets;
553
+ targets.reserve(ptrs.size() + kb.next.size()*2);
554
+
555
+ {
556
+ std::unordered_set<StrPtr, PtrHash, PtrEq> seen;
557
+ // include all strings from the saved file
558
+ for (auto p : ptrs) {
559
+ if (seen.insert(p).second) targets.push_back(p);
560
+ }
561
+ // include all words present in KB edges (keys and their neighbors)
562
+ for (auto &pr : kb.next) {
563
+ if (seen.insert(pr.first).second) targets.push_back(pr.first);
564
+ for (auto v : pr.second) {
565
+ if (seen.insert(v).second) targets.push_back(v);
566
+ }
567
+ }
568
+ }
569
+
570
+ // --- recompute definition expansion for each target in parallel
571
+ #pragma omp parallel for schedule(dynamic)
572
+ for (ptrdiff_t i = 0; i < static_cast<ptrdiff_t>(targets.size()); ++i) {
573
+ kb.ensure_def_for_interned(targets[(size_t)i]);
574
+ }
575
+ }
576
+ }
577
+
578
+ // --------------------------- CLI + Interactive loop (shorters) -----------
579
+
580
// Print the command-line synopsis followed by one help line per option to
// std::cout. p is the program name shown in the synopsis.
static void print_usage(const char *p){
    static const char *const option_help[] = {
        " --maxlen N Maximum number of tokens constructed in a response.\n",
        " --save FILE Save the knowledge-base and dictionary expansions to a binary file.\n",
        " --load-kb FILE Load a previously saved knowledge-base (and dictionary expansions) from a binary file.\n",
        " --dict-depth D Depth of dictionary-definition expansion used during learning.\n",
        " --learn f1 f2 ... Learn from one or more text files to update the knowledge base.\n",
        " --repeat-penalty P Penalize repeated tokens during response generation (higher values discourage repetition).\n",
        " --help Show command-line interface options for ChatIPC usage.\n",
    };

    std::cout << "Usage: " << p
              << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D] [--learn f1 f2 ...] [--repeat-penalty P] [--help]\n";
    for (const char *entry : option_help) std::cout << entry;
}
590
+
591
+ int main(int argc, char **argv){
592
+ size_t maxlen = 100;
593
+ std::string savefile;
594
+ std::string load_txt;
595
+ std::string load_kb;
596
+ int dict_depth = 2;
597
+ double repeat_penalty = 0.7; // default λ
598
+ std::vector<std::string> learn_files;
599
+
600
+ for (int i=1;i<argc;++i){
601
+ std::string a = argv[i];
602
+ if (a=="--help"){ print_usage(argv[0]); return 0; }
603
+ if (a=="--maxlen" && i+1<argc){ maxlen = std::stoul(argv[++i]); continue; }
604
+ if (a=="--save" && i+1<argc){ savefile = argv[++i]; continue; }
605
+ if (a=="--load-kb" && i+1<argc){ load_kb = argv[++i]; continue; }
606
+ if (a=="--dict-depth" && i+1<argc){ dict_depth = std::stoi(argv[++i]); continue; }
607
+ if (a=="--repeat-penalty" && i+1<argc){ repeat_penalty = std::stod(argv[++i]); continue; }
608
+ if (a=="--learn"){
609
+ ++i;
610
+ for (; i<argc; ++i){
611
+ if (!argv[i]) break;
612
+ std::string s = argv[i];
613
+ if (!s.empty() && s[0]=='-'){ --i; break; }
614
+ learn_files.push_back(s);
615
+ }
616
+ continue;
617
+ }
618
+ learn_files.push_back(a);
619
+ }
620
+
621
+ KnowledgeBase kb;
622
+
623
+ // parse the embedded dictionary once for use by per-word expansion
624
+ g_raw_dict = parse_dictionary_json();
625
+ // set KB def depth (clears any previous expansion)
626
+ kb.set_def_depth(dict_depth);
627
+
628
+
629
+ if (!load_kb.empty()){
630
+ try { load_kb_binary(kb, load_kb, dict_depth); std::cerr << "Loaded KB: " << load_kb << "\n"; }
631
+ catch (const std::exception &e){ std::cerr << "Load KB error: " << e.what() << "\n"; }
632
+ }
633
+
634
+ if (!learn_files.empty()){
635
+ std::cerr << "Learning from file/s (" << learn_files.size() << ") using threads=" << omp_get_max_threads() << "\n";
636
+ learn_files_parallel(kb, learn_files);
637
+ }
638
+
639
+ std::string line;
640
+ std::cout << "Ready. Enter prompts.\n";
641
+ while (std::cout << "> " , std::getline(std::cin, line)){
642
+ if (line.empty()){ std::cout << "\n"; continue; }
643
+ auto prompt_toks = tokenize_whitespace(line);
644
+ for (size_t i=1;i<prompt_toks.size();++i) kb.add_pair(prompt_toks[i-1], prompt_toks[i]);
645
+ auto def_index = snapshot_def_index(kb);
646
+ auto resp = construct_response(kb, prompt_toks, maxlen, def_index, repeat_penalty);
647
+ std::cout << "\n";
648
+ if (!resp.empty()){
649
+ std::vector<std::string> combined = prompt_toks;
650
+ combined.insert(combined.end(), resp.begin(), resp.end());
651
+ for (size_t i=1;i<combined.size();++i) kb.add_pair(combined[i-1], combined[i]);
652
+ }
653
+ if (!savefile.empty()){
654
+ try { save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
655
+ catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
656
+ }
657
+ }
658
+
659
+ if (!savefile.empty()){
660
+ try { save_kb_binary(kb, savefile); std::cerr << "Saved KB: " << savefile << "\n"; }
661
+ catch (const std::exception &e){ std::cerr << "Save KB error: " << e.what() << "\n"; }
662
+ }
663
+
664
+ return 0;
665
+ }