| import re | |
| import random | |
| from collections import Counter, defaultdict | |
| from training_data import corpus | |
| from AGWM import * | |
# Public identifier for this model build; exposed on instances as `model_name`
# and printed by the demo driver.
ModelName = 'AgGPT-14'
def world_model(length=10):
    """Generate text from the auxiliary "world model" Markov chain.

    Loads a previously trained chain from ``AGWM.json`` when it exists;
    otherwise trains a new one from ``training_data/WM.txt`` and saves it.
    ``load_model`` and ``train_and_save_model`` come from the AGWM module.

    Args:
        length: Minimum number of sentences to generate.

    Returns:
        The generated text, as produced by ``chain.generate``.
    """
    # Imported locally: `os` is not imported at the top of this file and is
    # only available if `from AGWM import *` happens to leak it — make the
    # dependency explicit so this function cannot raise NameError.
    import os

    text_file = "training_data/WM.txt"
    model_file = "AGWM.json"
    if os.path.exists(model_file):
        chain = load_model(model_file)
    else:
        chain = train_and_save_model(text_file, model_file)
    return chain.generate(min_sentences=length)
class AgGPT14:
    """Retrieval-seeded Markov-chain chatbot.

    The corpus is parsed into (user, ai) turn pairs. An incoming query is
    matched against the stored user turns with an IDF-weighted word-overlap
    score plus a longest-common-subsequence order bonus; the ai turn of the
    best match seeds an order-`order` Markov generator that was trained on
    all ai turns.
    """

    def __init__(self, corpus_text, order=3, seed=None):
        """Parse the corpus and precompute all lookup tables.

        Args:
            corpus_text: Raw text containing "user: ... <pad> ai: ... <eos>" pairs.
            order: Markov context length in tokens; must be >= 1.
            seed: Optional RNG seed for reproducible sampling.

        Raises:
            ValueError: If no (user, ai) pairs are found in the corpus.
        """
        assert order >= 1, "order must be >= 1"
        self.model_name = ModelName
        self.order = order
        self.rng = random.Random(seed)  # private RNG; global `random` state untouched
        self.pairs = self._parse_pairs(corpus_text)
        if not self.pairs:
            raise ValueError("No (user, ai) pairs found in corpus.")
        # Parallel token lists: user_docs[i] is the query side of ai_docs[i].
        self.user_docs = [self._tokenize(u) for u, _ in self.pairs]
        self.ai_docs = [self._tokenize(a) for _, a in self.pairs]
        # IDF is computed over the *user* side only (used for query matching).
        self.idf_weights = self._calculate_idf(self.user_docs)
        # Markov tables are built over the *ai* side only (used for generation).
        self.global_transitions = self._build_global_transitions(self.ai_docs)
        self.unigram = self._build_unigram(self.ai_docs)
        self.user_ai_pairs = list(zip(self.user_docs, self.ai_docs))

    def _calculate_idf(self, docs):
        """Return an aggressive IDF score per word to emphasize rare words.

        The ratio N/(df + 1) is squared, so rare words dominate the
        similarity score far more than under standard (log) IDF.
        """
        N = len(docs)
        doc_freq = Counter()
        for doc in docs:
            for word in set(doc):  # count each word once per document
                doc_freq[word] += 1
        idf = {word: (N / (count + 1)) ** 2 for word, count in doc_freq.items()}
        return idf

    def _lcs(self, a, b):
        """Finds the Longest Common Subsequence between two lists of tokens."""
        # Standard O(len(a) * len(b)) DP table; lengths[i][j] is the LCS
        # length of a[:i] versus b[:j].
        lengths = [[0 for j in range(len(b) + 1)] for i in range(len(a) + 1)]
        for i, x in enumerate(a):
            for j, y in enumerate(b):
                if x == y:
                    lengths[i + 1][j + 1] = lengths[i][j] + 1
                else:
                    lengths[i + 1][j + 1] = max(lengths[i + 1][j], lengths[i][j + 1])
        # Backtrack from the bottom-right corner to recover one LCS.
        result = []
        x, y = len(a), len(b)
        while x != 0 and y != 0:
            if lengths[x][y] == lengths[x - 1][y]:
                x -= 1
            elif lengths[x][y] == lengths[x][y - 1]:
                y -= 1
            else:
                result.append(a[x - 1])
                x -= 1
                y -= 1
        return result[::-1]  # collected in reverse during backtracking

    def _parse_pairs(self, text):
        """Extract (user, ai) text pairs delimited by <pad> / <eos> markers."""
        # DOTALL lets a turn span newlines; IGNORECASE accepts "User:"/"AI:" etc.
        pattern = re.compile(
            r"user:\s*(.*?)\s*<pad>\s*ai:\s*(.*?)\s*<eos>",
            re.DOTALL | re.IGNORECASE
        )
        pairs = []
        for u, a in pattern.findall(text):
            u, a = u.strip(), a.strip()
            if u and a:  # drop pairs where either side is empty
                pairs.append((u, a))
        return pairs

    def _expand_contractions(self, s):
        """Rewrite common English contractions to their expanded forms.

        NOTE(review): the patterns carry no word-boundary anchors, so they
        can also fire inside longer words (e.g. "she's" is expanded via the
        earlier "he's" rule). For these particular patterns the result still
        reads sensibly, but confirm before adding new ones.
        """
        s = re.sub(r"what's", "what is", s)
        s = re.sub(r"that's", "that is", s)
        s = re.sub(r"it's", "it is", s)
        s = re.sub(r"how's", "how is", s)
        s = re.sub(r"he's", "he is", s)
        s = re.sub(r"she's", "she is", s)
        s = re.sub(r"you're", "you are", s)
        s = re.sub(r"i'm", "i am", s)
        s = re.sub(r"didn't", "did not", s)
        s = re.sub(r"don't", "do not", s)
        s = re.sub(r"can't", "cannot", s)
        return s

    def _tokenize(self, s):
        """Lowercase, expand contractions, and split into word/punct tokens."""
        s = s.strip().lower()
        s = self._expand_contractions(s)
        # Words are lowercase letter runs with an optional apostrophe part;
        # each listed punctuation mark becomes a standalone token.
        tokens = re.findall(r"[a-z]+(?:'[a-z]+)?|[?.!,;:]", s)
        return [t for t in tokens if t]

    def _with_bounds(self, tokens):
        # Pad with `order` start markers so every position has a full context,
        # and one end marker so generation can terminate.
        return ["<s>"] * self.order + tokens + ["</s>"]

    def _similarity(self, query_tokens, doc_tokens):
        """Score query/document similarity: IDF overlap plus an order bonus."""
        if not query_tokens or not doc_tokens:
            return 0.0
        common_words = set(query_tokens).intersection(set(doc_tokens))
        if not common_words:
            return 0.0
        # Base score: summed IDF of shared words (0.1 default for unseen words).
        idf_score = sum(self.idf_weights.get(word, 0.1) for word in common_words)
        # Bonus: words shared *in the same order* count again at half weight.
        lcs = self._lcs(query_tokens, doc_tokens)
        order_bonus_factor = 0.5
        order_bonus = sum(self.idf_weights.get(word, 0.1) for word in lcs) * order_bonus_factor
        return idf_score + order_bonus

    def _find_best_match(self, user_text):
        """Return the index of the most similar stored user turn, or None."""
        q_tokens = self._tokenize(user_text)
        if not q_tokens:
            return None
        best_score = -1.0
        best_idx = -1
        for i, user_doc in enumerate(self.user_docs):
            sim = self._similarity(q_tokens, user_doc)
            if sim > best_score:  # strict '>': ties keep the earliest entry
                best_score = sim
                best_idx = i
        if best_idx == -1 or best_score < 0.1:  # reject weak matches
            return None
        return best_idx

    def _build_global_transitions(self, docs):
        """Map each `order`-token context tuple to a Counter of next tokens."""
        trans = defaultdict(Counter)
        for tokens in docs:
            seq = self._with_bounds(tokens)
            for i in range(len(seq) - self.order):
                ctx = tuple(seq[i : i + self.order])
                nxt = seq[i + self.order]
                trans[ctx][nxt] += 1
        return trans

    def _build_unigram(self, docs):
        """Count token frequencies over all docs (sampling fallback)."""
        uni = Counter()
        for d in docs:
            uni.update(d)
        return uni

    def _get_best_starting_context(self, user_text):
        """Finds the best match and deterministically returns its starting context."""
        best_match_idx = self._find_best_match(user_text)
        if best_match_idx is not None:
            ai_doc = self.ai_docs[best_match_idx]
            if len(ai_doc) >= self.order:
                # Seed generation with the opening tokens of the matched reply.
                return tuple(ai_doc[:self.order])
        # No usable match: start from the sentence-begin context.
        return tuple(["<s>"] * self.order)

    def _sample_next(self, context, temperature, top_k):
        """Sample the next token for `context` with back-off and temperature.

        Back-off: shorten the context from the left until some suffix has
        observed transitions; if none does, fall back to the unigram
        distribution (excluding the boundary markers).
        """
        ctx = context
        while len(ctx) > 0:
            if ctx in self.global_transitions and self.global_transitions[ctx]:
                counter = self.global_transitions[ctx]
                break
            ctx = ctx[1:]
        else:
            # while-else: runs only when the loop exhausted without `break`.
            counter = Counter({k: v for k, v in self.unigram.items() if k not in ["<s>", "</s>"]})
        if not counter: return "</s>"
        # Keep only the top_k most frequent candidates.
        items = sorted(counter.items(), key=lambda x: x[1], reverse=True)[:top_k]
        if not items: return "</s>"
        if temperature <= 0: return items[0][0]  # greedy decoding
        tokens, weights = zip(*items)
        # Higher temperature flattens the distribution; lower sharpens it.
        scaled_weights = [w ** (1.0 / temperature) for w in weights]
        return self.rng.choices(tokens, weights=scaled_weights, k=1)[0]

    def _detokenize(self, tokens):
        """Join tokens back into display text with basic cleanup."""
        if not tokens: return ""
        # Drop boundary markers, then join with single spaces.
        text = " ".join(t for t in tokens if t not in ["<s>", "</s>"])
        # Attach punctuation to the preceding word.
        text = re.sub(r'\s+([?.!,;:])', r'\1', text)
        # Remove the spaces around a free-standing apostrophe token.
        text = re.sub(r" ([']) ", r"\1", text)
        # Capitalize the first character and the start of each sentence.
        if text: text = text[0].upper() + text[1:]
        text = re.sub(r'([.!?]\s*)([a-z])', lambda m: m.group(1) + m.group(2).upper(), text)
        # The standalone pronoun "i" is always capitalized.
        text = re.sub(r'\bi\b', 'I', text)
        return text

    def respond(self, user_text, max_tokens=25, temperature=0.7, top_k=8, use_context_selection=True):
        """Generate a reply to `user_text`.

        Args:
            user_text: The incoming user message.
            max_tokens: Maximum number of tokens sampled beyond the seed context.
            temperature: Sampling temperature; <= 0 means greedy.
            top_k: Number of top candidates considered at each step.
            use_context_selection: If True, seed from the best-matching reply;
                otherwise start from the sentence-begin context.

        Returns:
            The detokenized reply string (seed tokens included).
        """
        ctx = self._get_best_starting_context(user_text) if use_context_selection else tuple(["<s>"] * self.order)
        out = list(ctx)  # the seed tokens are part of the output
        for _ in range(max_tokens):
            nxt = self._sample_next(ctx, temperature, top_k)
            if nxt == "</s>": break  # end marker: stop generating
            out.append(nxt)
            ctx = tuple(out[-self.order:])  # slide the context window
        return self._detokenize(out)

    def ask(self, prompt, text_world_model=False, **kwargs):
        """User-friendly wrapper for the respond method.

        Args:
            prompt: The user message.
            text_world_model: If True, prepend a <world_model>...</world_model>
                block generated by the module-level `world_model()` helper.
            **kwargs: Forwarded verbatim to `respond`.
        """
        response = self.respond(prompt, **kwargs)
        if text_world_model:
            wm_response = world_model(length=10)
            wm_response = "<world_model>" + wm_response + "</world_model>"
            response = wm_response + " " + response
        return response

    def get_debug_info(self, user_text):
        """Print matching diagnostics for `user_text` (returns nothing)."""
        q_tokens = self._tokenize(user_text)
        print(f"--- Debug info for: '{user_text}' ---")
        print(f"Query Tokens (after normalization): {q_tokens}\n")
        best_match_idx = self._find_best_match(user_text)
        if best_match_idx is not None:
            # Recompute the score for display (matching is deterministic).
            best_score = self._similarity(q_tokens, self.user_docs[best_match_idx])
            print("Determined Best Match:")
            print(f" - Corpus Entry: {' '.join(self.user_docs[best_match_idx])}")
            print(f" - Score: {best_score:.2f}")
            print(f" - Corresponding AI response will be used for context.")
        else:
            print("No suitable match found. Will use default starting context.")
if __name__ == "__main__":
    # Demo driver: build the chatbot, show match diagnostics, then sample replies.
    print(f"Initializing model: {ModelName}")
    bot = AgGPT14(corpus, order=3, seed=42)

    print("\n=== Demonstrating the Fix for 'color' query ===")
    bot.get_debug_info("what is your favorite color?")

    print("\n=== Testing Model with Deterministic Matching ===")
    sample_queries = [
        "hi",
        "tell me a joke",
        "do you have hobbies?",
        "what is your favorite color?",
        "thanks a lot",
    ]
    divider = "-" * 40
    for query in sample_queries:
        print(f"user: {query}")
        print(f"ai: {bot.ask(query)}")
        print(divider)

    print("====WORLD MODEL====")
    print(world_model())

    # One final prompt exercising explicit sampling settings plus the
    # world-model prefix.
    prompt = "hello, how are you?"
    print(f"\nPrompt: {prompt}")
    reply = bot.ask(prompt, max_tokens=20, temperature=0.5, top_k=5, text_world_model=True)
    print(f"Response: {reply}")