// chatipc_modular.cpp // Compile: g++ -std=c++17 -O2 -fopenmp -o chatipc_modular chatipc_modular.cpp // Requires dictionary.cpp providing: extern unsigned char dictionary_json[]; extern unsigned int dictionary_json_len; #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef _OPENMP #include #else inline int omp_get_max_threads(){ return 1; } inline int omp_get_thread_num(){ return 0; } #endif extern unsigned char dictionary_json[]; // provide dictionary.cpp to embed dictionary JSON bytes extern unsigned int dictionary_json_len; // --------------------------- Short utility functions ---------------------- static inline bool is_space(char c){ return std::isspace(static_cast(c)) != 0; } static inline char to_low(char c){ return static_cast(std::tolower(static_cast(c))); } static inline void safe_flush(std::ostream &os){ os.flush(); } // Tokenize by whitespace static std::vector tokenize_whitespace(const std::string &s){ std::istringstream iss(s); std::vector out; std::string t; while (iss >> t) out.push_back(t); return out; } // Tokenize by non-alphanumeric characters (for definitions) static std::vector tokenize_non_alnum(const std::string &s){ std::vector out; std::string cur; for (char ch : s){ if (std::isalnum(static_cast(ch)) || ch=='-' || ch=='\''){ cur.push_back(to_low(ch)); } else { if (!cur.empty()){ out.push_back(cur); cur.clear(); } } } if (!cur.empty()) out.push_back(cur); return out; } // --------------------------- String interning (short methods) -------------- struct StringInterner { std::unordered_set pool; std::mutex m; const std::string* intern(const std::string &s){ std::lock_guard lk(m); auto it = pool.find(s); if (it != pool.end()) return &*it; auto pr = pool.insert(s); return &*pr.first; } }; // --------------------------- Knowledge base (short methods) -------------- using StrPtr = const std::string*; struct PtrHash { size_t operator()(StrPtr p) const noexcept { return std::hash()(*p); } }; struct PtrEq { bool operator()(StrPtr a, StrPtr b) const noexcept { return *a == *b; } }; using NextSet = std::vector; struct KnowledgeBase { StringInterner interner; std::unordered_map next; std::mutex m; void add_pair_interned(StrPtr k, StrPtr v){ std::lock_guard lk(m); auto &vec = next[k]; for (auto p : vec) if (*p == *v) return; vec.push_back(v); } void add_pair(const std::string &k, const std::string &v){ StrPtr kp = interner.intern(k); StrPtr vp = interner.intern(v); add_pair_interned(kp, vp); } std::optional lookup_by_string(const std::string &k) const { for (auto &pr : next) if (*pr.first == k) return pr.second; return std::nullopt; } std::optional lookup_by_ptr(StrPtr k) const { auto it = next.find(k); if (it==next.end()) return std::nullopt; return it->second; } }; // --------------------------- Small JSON parse helpers ---------------------- static inline bool json_valid_index(size_t i, size_t n){ return i < n; } static std::string parse_quoted_string(const std::string &text, size_t &i){ std::string out; if (!json_valid_index(i, text.size()) || text[i] != '"') throw std::runtime_error("expected '\"'"); ++i; while (json_valid_index(i, text.size())){ char c = text[i++]; if (c == '"') break; if (c == '\\'){ if (!json_valid_index(i, text.size())) break; char e = text[i++]; if (e=='n') out.push_back('\n'); else if (e=='t') out.push_back('\t'); else out.push_back(e); } else out.push_back(c); } return out; } static void skip_spaces(const std::string &s, size_t &i){ while (json_valid_index(i, s.size()) && is_space(s[i])) ++i; } // Very small JSON-like parser tailored to dictionary_json structure static std::unordered_map parse_dictionary_json(){ std::unordered_map dict; if (dictionary_json_len == 0) return dict; std::string text; text.reserve(dictionary_json_len + 1); for (unsigned int b=0; b < dictionary_json_len; ++b) text.push_back(static_cast(dictionary_json[b])); size_t i = 0; skip_spaces(text,i); if (!json_valid_index(i,text.size()) || text[i] != '{') return dict; ++i; while (true){ skip_spaces(text,i); if (!json_valid_index(i,text.size())) break; if (text[i] == '}'){ ++i; break; } std::string key = parse_quoted_string(text,i); skip_spaces(text,i); if (!json_valid_index(i,text.size()) || text[i] != ':') break; ++i; skip_spaces(text,i); std::string val; if (json_valid_index(i,text.size()) && text[i] == '"') val = parse_quoted_string(text,i); else { size_t start = i; while (json_valid_index(i,text.size()) && text[i] != ',' && text[i] != '}') ++i; val = text.substr(start, i-start); } dict.emplace(std::move(key), std::move(val)); skip_spaces(text,i); if (json_valid_index(i,text.size()) && text[i] == ','){ ++i; continue; } if (json_valid_index(i,text.size()) && text[i] == '}'){ ++i; break; } } return dict; } // --------------------------- Build definition index (small funcs) --------- static std::unordered_set def_tokens_from_text(const std::string &s){ auto toks = tokenize_non_alnum(s); return std::unordered_set(toks.begin(), toks.end()); } static void expand_def_index(const std::unordered_map> &direct, std::unordered_map> &out, int depth) { for (auto &pr : direct){ const std::string &word = pr.first; std::unordered_set acc = pr.second; if (depth > 1){ std::vector frontier(acc.begin(), acc.end()); for (int d=1; d nextf; for (auto &w : frontier){ auto it = direct.find(w); if (it==direct.end()) continue; for (auto &t : it->second){ if (acc.insert(t).second) nextf.push_back(t); } } if (nextf.empty()) break; frontier.swap(nextf); } } out.emplace(word, std::move(acc)); } } static std::unordered_map> build_definition_index(int depth) { std::unordered_map> out; if (depth <= 0) return out; auto raw = parse_dictionary_json(); std::unordered_map> direct; for (auto &pr : raw) direct.emplace(pr.first, def_tokens_from_text(pr.second)); expand_def_index(direct, out, depth); return out; } // --------------------------- Similarity helpers (very small) ---------------- static double jaccard_similarity(const std::unordered_set &A, const std::unordered_set &B) { if (A.empty() && B.empty()) return 1.0; size_t inter = 0; if (A.size() < B.size()){ for (const auto &x : A) if (B.count(x)) ++inter; } else { for (const auto &x : B) if (A.count(x)) ++inter; } size_t uni = A.size() + B.size() - inter; if (uni == 0) return 0.0; return static_cast(inter) / static_cast(uni); } static std::unordered_set aggregate_sets(const std::vector &tokens, const std::unordered_map> &def_index) { std::unordered_set agg; for (auto &t : tokens){ agg.insert(t); auto it = def_index.find(t); if (it != def_index.end()){ for (auto &d : it->second) agg.insert(d); } } return agg; } // --------------------------- Candidate selection (short funcs) --------------- static std::string best_candidate_by_similarity(const NextSet &cands, const std::vector &prompt_toks, const std::vector &resp_toks, const std::unordered_map> &def_index, const std::unordered_map &recent_counts, double repeat_penalty) { if (cands.empty()) return std::string(); if (cands.size() == 1) return *cands[0]; auto agg = aggregate_sets(prompt_toks, def_index); for (auto &r : resp_toks){ auto it = def_index.find(r); if (it != def_index.end()) for (auto &d : it->second) agg.insert(d); } double best = -1e9; std::string best_tok; size_t M = cands.size(); std::vector scores(M, 0.0); #pragma omp parallel for schedule(static) for (ptrdiff_t i=0;i(M);++i){ std::unordered_set candset; candset.insert(*cands[(size_t)i]); auto it = def_index.find(*cands[(size_t)i]); if (it != def_index.end()) for (auto &d : it->second) candset.insert(d); double s = jaccard_similarity(agg, candset); scores[(size_t)i] = s; } for (size_t i=0;isecond); double adjusted = s - repeat_penalty * static_cast(cnt); if (adjusted > best || (adjusted == best && tok < best_tok)){ best = adjusted; best_tok = tok; } } return best_tok; } // --------------------------- Response generator (short units) --------------- static std::vector generate_response(KnowledgeBase &kb, const std::vector &prompt_toks, size_t maxlen, const std::unordered_map> &def_index, double repeat_penalty) { std::vector resp; if (prompt_toks.empty() || maxlen == 0) return resp; std::unordered_map recent_counts; auto would_create_2_cycle = [&](const std::string &cand) -> bool { if (resp.size() < 2) return false; // check alternation: X Y X Y ... then candidate == X and last == Y const std::string &last = resp.back(); const std::string &prev = resp[resp.size()-2]; return (cand == prev && last == resp[resp.size()-3 < resp.size() ? resp.size()-3 : 0]); // this is a cheap conservative check; main guard is repeat_penalty + single-candidate rule }; std::string last_printed; for (size_t step=0; step(prompt_toks.size())-1; p>=0; --p){ auto opt = kb.lookup_by_string(prompt_toks[(size_t)p]); if (opt){ candidates = *opt; found = true; break; } } } else { auto opt = kb.lookup_by_string(last_printed); if (opt){ candidates = *opt; found = true; } else { for (ssize_t p = static_cast(prompt_toks.size())-1; p>=0; --p){ auto opt2 = kb.lookup_by_string(prompt_toks[(size_t)p]); if (opt2){ candidates = *opt2; found = true; break; } } } } if (!found || candidates.empty()) break; // If only one candidate and it already appeared, stop to avoid 1-cycle. if (candidates.size()==1){ std::string only = *candidates[0]; if (recent_counts[only] > 0) break; resp.push_back(only); recent_counts[only] += 1; last_printed = only; continue; } // choose best with repeat penalty std::string chosen = best_candidate_by_similarity(candidates, prompt_toks, resp, def_index, recent_counts, repeat_penalty); if (chosen.empty()) break; // cheap 2-cycle avoider: if this would continue a trivial alternation, stop if (would_create_2_cycle(chosen)) break; resp.push_back(chosen); recent_counts[chosen] += 1; last_printed = chosen; } return resp; } // --------------------------- Learning from files (short) ------------------- static void learn_from_file(KnowledgeBase &kb, const std::string &fname){ std::ifstream ifs(fname); if (!ifs) return; std::string tok; std::string prev; bool have_prev = false; while (ifs >> tok){ if (have_prev) kb.add_pair(prev, tok); prev = tok; have_prev = true; } } static void learn_files_parallel(KnowledgeBase &kb, const std::vector &files){ #pragma omp parallel for schedule(dynamic) for (ptrdiff_t i=0;i(files.size());++i) learn_from_file(kb, files[(size_t)i]); } // --------------------------- Serialization (short functions) ---------------- // File format documented in comments near functions static void save_kb_binary(const KnowledgeBase &kb, const std::string &fname){ std::ofstream ofs(fname, std::ios::binary); if (!ofs) throw std::runtime_error("cannot open save file"); std::vector interned; interned.reserve(kb.interner.pool.size()); for (auto &s : kb.interner.pool) interned.push_back(&s); uint64_t N = interned.size(); ofs.write(reinterpret_cast(&N), sizeof(N)); for (auto p : interned){ uint64_t L = p->size(); ofs.write(reinterpret_cast(&L), sizeof(L)); ofs.write(p->data(), static_cast(L)); } uint64_t E = kb.next.size(); ofs.write(reinterpret_cast(&E), sizeof(E)); for (auto &pr : kb.next){ // find index of key const std::string &key = *pr.first; auto it = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == key; }); if (it == interned.end()) throw std::runtime_error("save index error"); uint64_t key_idx = static_cast(std::distance(interned.begin(), it)); ofs.write(reinterpret_cast(&key_idx), sizeof(key_idx)); uint64_t M = pr.second.size(); ofs.write(reinterpret_cast(&M), sizeof(M)); for (auto nxt : pr.second){ auto it2 = std::find_if(interned.begin(), interned.end(), [&](const std::string* s){ return *s == *nxt; }); if (it2 == interned.end()) throw std::runtime_error("save index error2"); uint64_t v_idx = static_cast(std::distance(interned.begin(), it2)); ofs.write(reinterpret_cast(&v_idx), sizeof(v_idx)); } } safe_flush(ofs); } static void load_kb_binary(KnowledgeBase &kb, const std::string &fname){ std::ifstream ifs(fname, std::ios::binary); if (!ifs) throw std::runtime_error("cannot open load file"); uint64_t N; ifs.read(reinterpret_cast(&N), sizeof(N)); std::vector strings; strings.reserve((size_t)N); for (uint64_t i=0;i(&L), sizeof(L)); std::string s; s.resize((size_t)L); ifs.read(&s[0], static_cast(L)); strings.push_back(std::move(s)); } std::vector ptrs; ptrs.reserve(strings.size()); for (auto &s : strings) ptrs.push_back(kb.interner.intern(s)); uint64_t E; ifs.read(reinterpret_cast(&E), sizeof(E)); for (uint64_t i=0;i(&key_idx), sizeof(key_idx)); uint64_t M; ifs.read(reinterpret_cast(&M), sizeof(M)); StrPtr key_ptr = ptrs.at((size_t)key_idx); NextSet vec; vec.reserve((size_t)M); for (uint64_t j=0;j(&v_idx), sizeof(v_idx)); vec.push_back(ptrs.at((size_t)v_idx)); } kb.next.emplace(key_ptr, std::move(vec)); } } // --------------------------- CLI + Interactive loop (shorters) ----------- static void print_usage(const char *p){ std::cout << "Usage: " << p << " [--maxlen N] [--save FILE] [--load-kb FILE] [--dict-depth D] [--learn f1 f2 ...]\n"; } int main(int argc, char **argv){ size_t maxlen = 100; std::string savefile; std::string load_txt; std::string load_kb; int dict_depth = 2; double repeat_penalty = 0.7; // default λ std::vector learn_files; for (int i=1;i " , std::getline(std::cin, line)){ if (line.empty()){ std::cout << "\n"; continue; } auto prompt_toks = tokenize_whitespace(line); for (size_t i=1;i combined = prompt_toks; combined.insert(combined.end(), resp.begin(), resp.end()); for (size_t i=1;i