import os
import re
import random
from collections import Counter, defaultdict

from training_data import corpus
from AGWM import *
ModelName = 'AgGPT-14'
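# NOTE: world_model() below depends on names pulled in by the AGWM star import
# (load_model, train_and_save_model, and the returned chain's
# .generate(min_sentences=...) method); these are assumptions about AGWM's API.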
def world_model(length=10):
    """Generates `length` sentences from a simple world model, training it once
    and caching the trained chain to disk (AGWM.json)."""
text_file = "training_data/WM.txt"
model_file = "AGWM.json"
if os.path.exists(model_file):
chain = load_model(model_file)
else:
chain = train_and_save_model(text_file, model_file)
return chain.generate(min_sentences=length)
class AgGPT14:
def __init__(self, corpus_text, order=3, seed=None):
assert order >= 1, "order must be >= 1"
self.model_name = ModelName
self.order = order
self.rng = random.Random(seed)
self.pairs = self._parse_pairs(corpus_text)
if not self.pairs:
raise ValueError("No (user, ai) pairs found in corpus.")
self.user_docs = [self._tokenize(u) for u, _ in self.pairs]
self.ai_docs = [self._tokenize(a) for _, a in self.pairs]
self.idf_weights = self._calculate_idf(self.user_docs)
self.global_transitions = self._build_global_transitions(self.ai_docs)
self.unigram = self._build_unigram(self.ai_docs)
self.user_ai_pairs = list(zip(self.user_docs, self.ai_docs))
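    # Pipeline: parse (user, ai) pairs from the corpus -> tokenize both sides ->
    # IDF weights over user turns for retrieval, plus order-k transition counts
    # and unigram counts over ai turns for generation.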
def _calculate_idf(self, docs):
"""Calculates an aggressive IDF score to emphasize rare words."""
N = len(docs)
doc_freq = Counter()
for doc in docs:
for word in set(doc):
doc_freq[word] += 1
idf = {word: (N / (count + 1)) ** 2 for word, count in doc_freq.items()}
return idf
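    # Worked example: with N = 100 docs, a word seen in all 100 scores
    # (100 / 101) ** 2 ≈ 0.98, while a word seen in only one scores
    # (100 / 2) ** 2 = 2500, so rare words dominate the similarity score.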
def _lcs(self, a, b):
"""Finds the Longest Common Subsequence between two lists of tokens."""
        lengths = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
for i, x in enumerate(a):
for j, y in enumerate(b):
if x == y:
lengths[i + 1][j + 1] = lengths[i][j] + 1
else:
lengths[i + 1][j + 1] = max(lengths[i + 1][j], lengths[i][j + 1])
result = []
x, y = len(a), len(b)
while x != 0 and y != 0:
if lengths[x][y] == lengths[x - 1][y]:
x -= 1
elif lengths[x][y] == lengths[x][y - 1]:
y -= 1
else:
result.append(a[x - 1])
x -= 1
y -= 1
return result[::-1]
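    # Example: _lcs(["the", "red", "cat"], ["red", "cat", "sat"]) -> ["red", "cat"]:
    # tokens must appear in the same relative order but need not be contiguous.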
def _parse_pairs(self, text):
pattern = re.compile(
r"user:\s*(.*?)\s*<pad>\s*ai:\s*(.*?)\s*<eos>",
re.DOTALL | re.IGNORECASE
)
pairs = []
for u, a in pattern.findall(text):
u, a = u.strip(), a.strip()
if u and a:
pairs.append((u, a))
return pairs
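    # Expected corpus format, per the regex above (one exchange per match):
    #   user: <user text> <pad> ai: <ai text> <eos>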
    def _expand_contractions(self, s):
        # \b anchors each pattern to a word boundary so one contraction cannot
        # match inside another word (e.g. "he's" inside "she's").
        contractions = {
            "what's": "what is", "that's": "that is", "it's": "it is",
            "how's": "how is", "he's": "he is", "she's": "she is",
            "you're": "you are", "i'm": "i am", "didn't": "did not",
            "don't": "do not", "can't": "cannot",
        }
        for contraction, expansion in contractions.items():
            s = re.sub(r"\b" + contraction, expansion, s)
        return s
def _tokenize(self, s):
s = s.strip().lower()
s = self._expand_contractions(s)
tokens = re.findall(r"[a-z]+(?:'[a-z]+)?|[?.!,;:]", s)
return [t for t in tokens if t]
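    # Example: _tokenize("What's up?") -> ["what", "is", "up", "?"].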
def _with_bounds(self, tokens):
return ["<s>"] * self.order + tokens + ["</s>"]
def _similarity(self, query_tokens, doc_tokens):
if not query_tokens or not doc_tokens:
return 0.0
common_words = set(query_tokens).intersection(set(doc_tokens))
if not common_words:
return 0.0
idf_score = sum(self.idf_weights.get(word, 0.1) for word in common_words)
lcs = self._lcs(query_tokens, doc_tokens)
order_bonus_factor = 0.5
order_bonus = sum(self.idf_weights.get(word, 0.1) for word in lcs) * order_bonus_factor
return idf_score + order_bonus
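    # Total score = summed IDF of shared words plus half the summed IDF of the
    # LCS, so candidates that also match the query's word order outrank
    # bag-of-words ties.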
def _find_best_match(self, user_text):
q_tokens = self._tokenize(user_text)
if not q_tokens:
return None
best_score = -1.0
best_idx = -1
for i, user_doc in enumerate(self.user_docs):
sim = self._similarity(q_tokens, user_doc)
if sim > best_score:
best_score = sim
best_idx = i
if best_idx == -1 or best_score < 0.1:
return None
return best_idx
def _build_global_transitions(self, docs):
trans = defaultdict(Counter)
for tokens in docs:
seq = self._with_bounds(tokens)
for i in range(len(seq) - self.order):
ctx = tuple(seq[i : i + self.order])
nxt = seq[i + self.order]
trans[ctx][nxt] += 1
return trans
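    # Example (order=3): ["i", "like", "cats"] is padded to
    # ["<s>", "<s>", "<s>", "i", "like", "cats", "</s>"], producing contexts
    # ("<s>", "<s>", "<s>") -> "i" through ("i", "like", "cats") -> "</s>".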
def _build_unigram(self, docs):
uni = Counter()
for d in docs:
uni.update(d)
return uni
def _get_best_starting_context(self, user_text):
"""Finds the best match and deterministically returns its starting context."""
best_match_idx = self._find_best_match(user_text)
if best_match_idx is not None:
ai_doc = self.ai_docs[best_match_idx]
if len(ai_doc) >= self.order:
return tuple(ai_doc[:self.order])
return tuple(["<s>"] * self.order)
    def _sample_next(self, context, temperature, top_k):
        # Back off: shorten the context from the left until a known n-gram
        # context is found; the while/else falls through to a unigram model.
        ctx = context
        while len(ctx) > 0:
            if ctx in self.global_transitions and self.global_transitions[ctx]:
                counter = self.global_transitions[ctx]
                break
            ctx = ctx[1:]
        else:
            counter = Counter({k: v for k, v in self.unigram.items() if k not in ("<s>", "</s>")})
        if not counter:
            return "</s>"
        # Keep only the top_k most frequent candidates.
        items = sorted(counter.items(), key=lambda x: x[1], reverse=True)[:top_k]
        if not items:
            return "</s>"
        if temperature <= 0:
            return items[0][0]  # greedy decoding
        tokens, weights = zip(*items)
        # Temperature scaling: w ** (1/T) sharpens (T < 1) or flattens (T > 1)
        # the sampling distribution.
        scaled_weights = [w ** (1.0 / temperature) for w in weights]
        return self.rng.choices(tokens, weights=scaled_weights, k=1)[0]
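    # Example: weights (4, 1) at temperature 0.5 become (16, 1) (sharper);
    # at temperature 2.0 they become (2, 1) (flatter).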
    def _detokenize(self, tokens):
        if not tokens:
            return ""
        text = " ".join(t for t in tokens if t not in ("<s>", "</s>"))
        # Attach punctuation to the preceding word ("hi ," -> "hi,").
        text = re.sub(r"\s+([?.!,;:])", r"\1", text)
        # Rejoin any stray apostrophe tokens ("don ' t" -> "don't").
        text = re.sub(r" ([']) ", r"\1", text)
        # Capitalize the first character and the start of each sentence,
        # and always capitalize the pronoun "i".
        if text:
            text = text[0].upper() + text[1:]
        text = re.sub(r"([.!?]\s*)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text)
        text = re.sub(r"\bi\b", "I", text)
        return text
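    # Example: ["<s>", "hello", ",", "i", "am", "fine", "."] -> "Hello, I am fine."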
def respond(self, user_text, max_tokens=25, temperature=0.7, top_k=8, use_context_selection=True):
ctx = self._get_best_starting_context(user_text) if use_context_selection else tuple(["<s>"] * self.order)
out = list(ctx)
for _ in range(max_tokens):
nxt = self._sample_next(ctx, temperature, top_k)
if nxt == "</s>": break
out.append(nxt)
ctx = tuple(out[-self.order:])
return self._detokenize(out)
def ask(self, prompt, text_world_model=False, **kwargs):
"""User-friendly wrapper for the respond method."""
response = self.respond(prompt, **kwargs)
if text_world_model:
wm_response = world_model(length=10)
wm_response = "<world_model>" + wm_response + "</world_model>"
response = wm_response + " " + response
return response
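    # Usage: bot.ask("hi", temperature=0.5, top_k=5); extra keyword arguments are
    # forwarded to respond(). With text_world_model=True the reply is prefixed
    # with generated <world_model>...</world_model> text.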
def get_debug_info(self, user_text):
q_tokens = self._tokenize(user_text)
print(f"--- Debug info for: '{user_text}' ---")
print(f"Query Tokens (after normalization): {q_tokens}\n")
best_match_idx = self._find_best_match(user_text)
if best_match_idx is not None:
best_score = self._similarity(q_tokens, self.user_docs[best_match_idx])
print("Determined Best Match:")
print(f" - Corpus Entry: {' '.join(self.user_docs[best_match_idx])}")
print(f" - Score: {best_score:.2f}")
print(f" - Corresponding AI response will be used for context.")
else:
print("No suitable match found. Will use default starting context.")
if __name__ == "__main__":
print(f"Initializing model: {ModelName}")
bot = AgGPT14(corpus, order=3, seed=42)
print("\n=== Demonstrating the Fix for 'color' query ===")
bot.get_debug_info("what is your favorite color?")
print("\n=== Testing Model with Deterministic Matching ===")
tests = [
"hi",
"tell me a joke",
"do you have hobbies?",
"what is your favorite color?",
"thanks a lot",
]
for t in tests:
print(f"user: {t}")
response = bot.ask(t)
print(f"ai: {response}")
print("-" * 40)
print("====WORLD MODEL====")
print(world_model())
prompt = "hello, how are you?"
print(f"\nPrompt: {prompt}")
response = bot.ask(prompt, max_tokens=20, temperature=0.5, top_k=5, text_world_model=True)
print(f"Response: {response}")