| import argparse |
| import json |
| import math |
| import os |
| import re |
| import warnings |
| from typing import Dict, List, Optional, Set, Tuple |
|
|
| import torch |
|
|
| from GPT_model import GPT, SimpleBPETokenizer as BPETokenizer, config_from_dict, DEFAULT_CONFIG |
|
|
# Directory two levels above this file; load_tokenizer() looks for data/bpe_vocab.json under it.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Default value of the --system-prompt command-line option.
DEFAULT_SYSTEM_PROMPT = (
    "You are Jarvis, a practical and calm AI assistant. "
    "Give clear, structured answers with enough detail, examples, and step-by-step reasoning when helpful. "
    "Stay natural and avoid unnecessary filler."
)


# Topic name -> keyword set. infer_topics() reports a topic as soon as a token set
# shares at least one keyword with it.
TOPIC_KEYWORDS = {
    "coding": {
        "python", "script", "code", "debug", "bug", "traceback", "function",
        "class", "api", "powershell", "terminal", "windows", "linux",
    },
    "ml": {
        "model", "train", "training", "dataset", "loss", "overfit", "overfitting",
        "underfit", "epoch", "gradient", "batch", "tokenizer", "prompt",
    },
    "food": {
        "cook", "cooking", "recipe", "sandwich", "rice", "salad", "egg", "eggs",
        "tea", "coffee", "lunch", "dinner", "breakfast", "meal", "snack",
    },
    "productivity": {
        "plan", "schedule", "focus", "habit", "routine", "confidence", "study",
        "learn", "motivation", "discipline", "time", "goal",
    },
}


# Lowercase phrases counted by load_retrieval_bank(); a stored answer containing
# two or more of them is flagged with is_template=True.
RETRIEVAL_TEMPLATE_MARKERS = {
    "set a clear target",
    "run one controlled test",
    "compare before and after",
    "keep only measurable improvements",
}


# Lowercase phrases that looks_meta_reply() searches for (substring match) in a reply.
META_REPLY_MARKERS = {
    "i can answer",
    "tell me if you want",
    "share your constraints and i will answer directly",
    "ask and i will",
    "tell me your exact goal",
}


# Regexes that unsafe_request_reply() applies to the lowercased, whitespace-collapsed
# user message; any match produces the refusal text.
UNSAFE_REQUEST_PATTERNS = [
    r"\b(how to|how do i|help me|ways to)\b.*\b(make|build|create|buy)\b.*\b(bomb|explosive|weapon)\b",
    r"\b(how to|how do i|help me|ways to)\b.*\b(hack|ddos|phish|crack wifi|steal password|keylogger|malware|ransomware)\b",
    r"\b(how to|how do i|help me|ways to)\b.*\b(kill|murder|poison)\b.*\b(person|people|human|someone|him|her)\b",
    r"\b(how to|how do i|help me|ways to)\b.*\b(suicide|self harm|hurt myself)\b",
]
|
|
|
|
def parse_args():
    """Declare the chat runner's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with checkpoint, sampling, threading, safety and
        retrieval settings.
    """
    # Default thread count: leave two cores free, clamped to the range 1..6.
    default_threads = max(1, min(6, (os.cpu_count() or 4) - 2))
    on_off = argparse.BooleanOptionalAction  # generates --flag / --no-flag pairs

    parser = argparse.ArgumentParser(description="CPU chat runner")
    add = parser.add_argument

    # Checkpoint and sampling controls.
    add("--ckpt", default="cpu_gpt_jarvis_v6_guarded_best.pth")
    add("--temperature", type=float, default=0.45)
    add("--top-k", type=int, default=32)
    add("--top-p", type=float, default=0.90)
    add("--repetition-penalty", type=float, default=1.12)
    add("--no-repeat-ngram", type=int, default=3)
    add("--max-new-tokens", type=int, default=64)
    add("--min-new-tokens", type=int, default=12)
    add(
        "--max-context-tokens",
        type=int,
        default=0,
        help="Max context tokens. 0 uses the checkpoint/model block_size.",
    )
    add("--system-prompt", default=DEFAULT_SYSTEM_PROMPT)
    add("--ban-empty-tokens", action=on_off, default=True)

    # Runtime controls.
    add("--threads", type=int, default=default_threads)
    add("--interop-threads", type=int, default=1)
    add("--seed", type=int, default=1337)
    add("--int8", action=on_off, default=False)

    # Candidate selection, fallback and retrieval controls.
    add("--num-candidates", type=int, default=2)
    add("--safe-fallback", action=on_off, default=True)
    add("--use-retrieval", action=on_off, default=True)
    add("--retrieval-file", default=os.path.join("data", "jarvis_refine_train.txt"))
    add("--retrieval-file-general", default=os.path.join("data", "jarvis_mix_train.txt"))
    add("--retrieval-max-rows", type=int, default=4500)

    return parser.parse_args()
|
|
|
|
def load_tokenizer():
    """Create a BPETokenizer and fill it from the saved ``bpe_vocab.json``.

    The file is looked up at ``<PROJECT_ROOT>/data/bpe_vocab.json`` first and
    falls back to ``bpe_vocab.json`` relative to the current working directory.

    Returns:
        The tokenizer with ``merges`` and ``vocab`` replaced by the file's
        contents and its encode cache cleared.
    """
    tokenizer = BPETokenizer()

    preferred = os.path.join(PROJECT_ROOT, "data", "bpe_vocab.json")
    vocab_path = preferred if os.path.exists(preferred) else "bpe_vocab.json"
    with open(vocab_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    # Merge keys are stored as "a,b" strings; turn them back into int tuples.
    merges = {}
    for pair_text, merged in payload["merges"].items():
        merges[tuple(map(int, pair_text.split(",")))] = merged
    tokenizer.merges = merges

    # Vocab values are stored as latin1 text; convert back to raw bytes.
    vocab = {}
    for token_id, encoded in payload["vocab"].items():
        vocab[int(token_id)] = bytes(encoded, "latin1")
    tokenizer.vocab = vocab

    # Drop anything cached before merges/vocab were replaced.
    tokenizer._encode_cached.cache_clear()
    return tokenizer
|
|
|
|
def apply_top_p(logits, top_p):
    """Apply nucleus (top-p) filtering along the last dimension of ``logits``.

    Args:
        logits: Tensor of unnormalized scores.
        top_p: Cumulative probability budget. ``None`` or a value >= 1.0
            disables filtering and returns ``logits`` unchanged.

    Returns:
        A tensor shaped like ``logits`` in which every entry outside the
        nucleus is set to -1e9. The highest-scoring entry is always kept.
    """
    if top_p is None or top_p >= 1.0:
        return logits

    ranked, order = torch.sort(logits, descending=True)
    cumulative = torch.softmax(ranked, dim=-1).cumsum(dim=-1)

    # Shift the "over budget" flags right by one so the entry that crosses the
    # threshold is still kept; the wrapped-around first flag is then cleared.
    drop = torch.roll(cumulative > top_p, shifts=1, dims=-1)
    drop[..., 0] = False

    ranked = ranked.masked_fill(drop, -1e9)
    # Undo the sort: write each ranked score back to its original position.
    return torch.full_like(logits, -1e9).scatter(-1, order, ranked)
|
|
|
|
def collect_banned_token_ids(tokenizer, ban_empty_tokens):
    """List the ids of vocabulary entries that decode to an empty string.

    Args:
        tokenizer: Object whose ``vocab`` maps token id -> bytes.
        ban_empty_tokens: When falsy, nothing is banned.

    Returns:
        Token ids, in vocab iteration order, whose bytes decode (UTF-8,
        undecodable bytes ignored) to ``""``.
    """
    if not ban_empty_tokens:
        return []
    return [
        token_id
        for token_id, raw in tokenizer.vocab.items()
        if not raw.decode("utf-8", errors="ignore")
    ]
|
|
|
|
def blocked_tokens_for_ngram(tokens, ngram_size):
    """Find next tokens that would repeat an n-gram already present in ``tokens``.

    Args:
        tokens: Sequence of token ids generated so far.
        ngram_size: Size of the n-gram that must not repeat; ``None`` or a
            value <= 1 disables the check.

    Returns:
        Set of token ids that previously followed the current trailing
        ``ngram_size - 1`` tokens.
    """
    if ngram_size is None or ngram_size <= 1:
        return set()

    context_len = ngram_size - 1
    if len(tokens) < context_len:
        return set()

    tail = tuple(tokens[-context_len:])
    # Only start positions that leave room for a full n-gram are considered.
    start_count = len(tokens) - ngram_size + 1
    return {
        tokens[start + context_len]
        for start in range(start_count)
        if tuple(tokens[start:start + context_len]) == tail
    }
|
|
|
|
def cleanup_reply(text):
    """Trim raw model output down to the assistant's own reply.

    Removes carriage returns, cuts everything from the first "\\nUser:" on,
    strips any leading "Assistant:" prefixes, and collapses runs of three or
    more newlines to two.
    """
    reply = text.replace("\r", "").partition("\nUser:")[0].strip()

    # The model sometimes echoes the role tag more than once.
    while reply.startswith("Assistant:"):
        reply = reply[len("Assistant:"):].lstrip()

    collapsed = re.sub(r"\n{3,}", "\n\n", reply)
    return collapsed.strip()
|
|
|
|
def looks_valid_numeric_answer(text: str) -> bool:
    """Report whether ``text`` contains a recognizable numeric statement.

    Three shapes are recognized (case-insensitively): "N% of M",
    "A <op> B = C" with op in + - * / x, and "N c = M" / "N f = M".
    """
    candidate = text.strip().lower()
    if not candidate:
        return False

    shapes = (
        r"\b\d+(?:\.\d+)?\s*%\s*of\s*\d+(?:\.\d+)?\b",
        r"\b\d+(?:\.\d+)?\s*([+\-*/x])\s*\d+(?:\.\d+)?\s*=\s*-?\d",
        r"\b\d+(?:\.\d+)?\s*(c|f)\s*=\s*-?\d",
    )
    return any(re.search(shape, candidate) for shape in shapes)
|
|
|
|
def likely_gibberish(text: str) -> bool:
    """Heuristically decide whether ``text`` is unusable model output.

    Returns True for very short text, leaked training markers or role tags,
    implausible long words, heavy digit/symbol noise, too many non-printable
    characters, or fewer than six letters. A recognizable numeric answer is
    accepted before any of the noise checks run.
    """
    body = text.strip() if text else ""
    if len(body) < 6:
        return True
    if looks_valid_numeric_answer(body):
        return False

    # Leaked dataset markers or role/context tags.
    if re.search(r"(SCENE_|CHAR_|Dialogue_|emotion_|conflict_)", body, flags=re.I):
        return True
    if "Assistant:" in body or "Context:" in body:
        return True

    # Two or more 18+ letter runs that each use more than 12 distinct letters.
    very_long = re.findall(r"[A-Za-z]{18,}", body)
    if sum(1 for word in very_long if len(set(word.lower())) > 12) >= 2:
        return True

    # Any 12+ letter word with few vowels or more than 9 distinct letters.
    for word in re.findall(r"\b[A-Za-z]{12,}\b", body):
        vowel_share = sum(ch in "aeiouAEIOU" for ch in word) / max(1, len(word))
        if vowel_share < 0.28 or len(set(word.lower())) > 9:
            return True

    size = len(body)
    noise = sum(ch.isdigit() or ch in "/\\|[]{}_=*#~`" for ch in body)
    if noise > max(6, int(size * 0.2)):
        return True

    printable = sum((31 < ord(ch) < 127) or ch in "\n\t\r" for ch in body)
    if printable < max(1, int(0.9 * size)):
        return True

    return sum(ch.isalpha() for ch in body) < 6
|
|
|
|
def response_quality_score(text: str) -> float:
    """Score a candidate reply; higher is better.

    Returns -10.0 for empty text. Otherwise starts at 0.0 and applies a fixed
    bonus or penalty for every rule in the table below that holds.
    """
    reply = text.strip()
    if not reply:
        return -10.0

    rules = (
        ("\nUser:" in reply or "\nAssistant:" in reply, -2.0),   # leaked role tags
        (likely_gibberish(reply), -5.0),
        (len(reply) >= 18, 1.5),
        (len(re.findall(r"[A-Za-z]+", reply)) < 5, -3.0),         # too few words
        (reply.count(" ") < 2, -2.0),
        (len(reply) > 600, -1.0),
        (re.search(r"[.!?]$", reply) is not None, 0.5),           # ends a sentence
        (re.search(r"\b\d{4,}\b", reply) is not None, -0.5),      # long digit runs
        (reply.count("- ") >= 5, -0.5),                           # bullet-heavy
    )

    score = 0.0
    for applies, delta in rules:
        if applies:
            score += delta
    return score
|
|
|
|
def normalize_token(token: str) -> str:
    """Lowercase a word, fix a few known typos and strip one common suffix.

    After lowercasing, trimming whitespace/apostrophes, applying the typo
    table and removing a trailing "'s", at most one suffix rule is applied:
    "ing" (words longer than 5), then "ed", "ies" -> "y", "es", or a plural
    "s" not preceded by "s"/"u" (words longer than 4).
    """
    known_typos = {
        "sandwitch": "sandwich",
        "sandwhich": "sandwich",
        "recipie": "recipe",
        "recepie": "recipe",
    }
    word = token.lower().strip().strip("'")
    word = known_typos.get(word, word)
    if word.endswith("'s"):
        word = word[:-2]

    size = len(word)
    if size > 5 and word.endswith("ing"):
        return word[:-3]
    if size > 4:
        if word.endswith("ed"):
            return word[:-2]
        if word.endswith("ies"):
            return word[:-3] + "y"
        if word.endswith("es"):
            return word[:-2]
        if word.endswith("s") and not word.endswith(("ss", "us")):
            return word[:-1]
    return word
|
|
|
|
def normalize_for_retrieval(text: str) -> List[str]:
    """Turn free text into an ordered list of normalized content tokens.

    Words are extracted from the lowercased text, pure-digit words are
    skipped, each remaining word goes through normalize_token(), and results
    shorter than three characters or present in the stop list are dropped.
    """
    stopwords = {
        "the", "a", "an", "is", "are", "to", "and", "or", "for", "of", "in",
        "on", "with", "me", "you", "i", "it", "this", "that", "my", "your",
        "be", "can", "do", "how", "what", "why", "when", "where", "who",
        "should", "would", "could", "please", "tell", "about", "from", "into",
        "have", "has", "had", "will", "just", "need", "want", "like", "than",
        "there", "their", "them", "then", "also", "only", "very", "much",
    }
    kept = []
    for raw_word in re.findall(r"[a-zA-Z0-9+']+", text.lower()):
        if raw_word.isdigit():
            continue
        token = normalize_token(raw_word)
        if len(token) >= 3 and token not in stopwords:
            kept.append(token)
    return kept
|
|
|
|
def infer_topics(tokens: Set[str]) -> Set[str]:
    """Return every topic in TOPIC_KEYWORDS that shares a keyword with ``tokens``."""
    return {
        topic
        for topic, keywords in TOPIC_KEYWORDS.items()
        if tokens & keywords
    }
|
|
|
|
def extract_numeric_tokens(text: str) -> Set[str]:
    """Collect the distinct standalone digit runs found in ``text`` (as strings)."""
    digit_run = r"\b\d+\b"
    return {match.group(0) for match in re.finditer(digit_run, text)}
|
|
|
|
def stable_variant_index(text: str, count: int) -> int:
    """Map ``text`` to a repeatable index in ``range(count)``.

    The index comes from a position-weighted sum of the lowercased
    characters' code points, so the same text always picks the same variant.
    Returns 0 when ``count`` is 1 or less.
    """
    if count <= 1:
        return 0
    checksum = sum(
        position * ord(ch)
        for position, ch in enumerate(text.lower(), start=1)
    )
    return checksum % count
|
|
|
|
def canonical_reply(text: str) -> str:
    """Reduce a reply to lowercase words and digits separated by single spaces.

    Used as a signature for comparing replies regardless of punctuation and casing.
    """
    lowered = text.lower()
    spaced = re.sub(r"[^a-z0-9]+", " ", lowered)
    return spaced.strip()
|
|
|
|
def normalize_user_text(text: str) -> str:
    """Lowercase, collapse whitespace, and expand a few typos and chat shorthands.

    Replacements are whole-word only and are applied one after another in the
    order listed below.
    """
    word_fixes = (
        ("sandwitch", "sandwich"),
        ("sandwhich", "sandwich"),
        ("recipie", "recipe"),
        ("recepie", "recipe"),
        ("pls", "please"),
        ("plz", "please"),
        ("u", "you"),
        ("luv", "love"),
    )
    result = re.sub(r"\s+", " ", text.lower().strip())
    for wrong, right in word_fixes:
        result = re.sub(rf"\b{re.escape(wrong)}\b", right, result)
    return result
|
|
|
|
def looks_noisy_help_request(user: str) -> bool:
    """Detect a message that asks for "help" but is otherwise keyboard-mash.

    True when the normalized message contains "help" and either has a run of
    eight or more consonants, or has a word of ten or more letters that is
    vowel-poor (< 22% vowels) or uses more than nine distinct letters.
    """
    message = normalize_user_text(user)
    if "help" not in message:
        return False
    if re.search(r"[bcdfghjklmnpqrstvwxyz]{8,}", message):
        return True

    for word in re.findall(r"[a-z]+", message):
        if len(word) < 10:
            continue
        vowel_share = sum(ch in "aeiou" for ch in word) / max(1, len(word))
        if vowel_share < 0.22 or len(set(word)) > 9:
            return True
    return False
|
|
|
|
def format_number(x: float) -> str:
    """Format a number compactly for display in a reply.

    Values within 1e-9 of an integer are shown as that integer. Other values
    are shown with at most six decimals and no trailing zeros.

    Args:
        x: Number to format.

    Returns:
        The formatted text. Non-finite values are returned as ``str(x)``
        ("inf", "-inf", "nan") instead of raising inside ``round()``.
    """
    # round() raises on inf/nan, so handle those before any arithmetic.
    if not math.isfinite(x):
        return str(x)

    nearest = round(x)
    if abs(x - nearest) < 1e-9:
        return str(int(nearest))

    text = f"{x:.6f}".rstrip("0").rstrip(".")
    # A tiny negative value (e.g. -4e-7) rounds to "-0.000000" and would be
    # trimmed to "-0"; report it as plain zero.
    return "0" if text == "-0" else text
|
|
|
|
def try_simple_math_reply(user: str) -> Optional[str]:
    """Answer a simple "N% of M" or two-operand arithmetic question.

    Args:
        user: User message; it is normalized with normalize_user_text() first.

    Returns:
        A sentence with the computed result, a fixed message for division by
        zero, or None when no supported expression is found. The percentage
        form is checked before the binary-operator form.
    """
    message = normalize_user_text(user)

    pct = re.search(r"\b(-?\d+(?:\.\d+)?)\s*%\s*of\s*(-?\d+(?:\.\d+)?)\b", message)
    if pct:
        share = float(pct.group(1))
        whole = float(pct.group(2))
        value = (share / 100.0) * whole
        return f"{format_number(share)}% of {format_number(whole)} is {format_number(value)}."

    expr = re.search(r"\b(-?\d+(?:\.\d+)?)\s*([+\-*/x])\s*(-?\d+(?:\.\d+)?)\b", message)
    if not expr:
        return None

    left = float(expr.group(1))
    op = expr.group(2)
    right = float(expr.group(3))

    if op == "/" and abs(right) < 1e-12:
        return "Division by zero is undefined."

    if op in {"*", "x"}:
        value = left * right
    elif op == "/":
        value = left / right
    elif op == "+":
        value = left + right
    else:
        value = left - right

    # Multiplication is always displayed with "x".
    shown = "x" if op == "*" else op
    return f"{format_number(left)} {shown} {format_number(right)} = {format_number(value)}."
|
|
|
|
def looks_meta_reply(text: str) -> bool:
    """Report whether ``text`` contains any META_REPLY_MARKERS phrase (case-insensitive)."""
    lowered = text.lower()
    for marker in META_REPLY_MARKERS:
        if marker in lowered:
            return True
    return False
|
|
|
|
def unsafe_request_reply(user: str) -> Optional[str]:
    """Return a refusal when the message matches any UNSAFE_REQUEST_PATTERNS regex.

    The message is lowercased and its whitespace collapsed before matching.
    Returns None when no pattern matches.
    """
    message = re.sub(r"\s+", " ", user.lower().strip())
    if not any(re.search(pattern, message) for pattern in UNSAFE_REQUEST_PATTERNS):
        return None
    return (
        "I cannot help with harmful or illegal actions. "
        "I can help with safety, prevention, or legal alternatives."
    )
|
|
|
|
def definition_stub(topic: str, topics: Set[str]) -> str:
    """Produce a one-sentence definition-style answer for ``topic``.

    Args:
        topic: The thing being asked about; surrounding whitespace and
            trailing/leading ".", "!" and "?" are removed.
        topics: Topic labels inferred for the request ("coding", "ml",
            "food", "productivity"); used to choose a generic template when
            no built-in definition applies.

    Returns:
        A built-in definition for API, recursion, machine learning or
        photosynthesis, otherwise a topic-specific or generic template
        sentence that embeds the cleaned topic.
    """
    t = topic.strip().strip(".!?")
    low = t.lower()
    # Match "api"/"apis" as a whole word only: a plain substring test also
    # fires on unrelated words such as "capital", "rapid" or "therapist".
    if re.search(r"\bapis?\b", low):
        return "An API is a defined interface that lets one program communicate with another."
    if "recursion" in low:
        return "Recursion is when a function solves a problem by calling itself on a smaller case."
    if "machine learning" in low:
        return "Machine learning is training models on data so they can predict or classify new examples."
    if "photosynthesis" in low:
        return "Photosynthesis is how plants use sunlight, water, and carbon dioxide to produce food."
    if {"coding", "ml"} & topics:
        return f"{t} is a software or ML concept best learned by definition, one example, and one practical use-case."
    if "food" in topics:
        return f"{t} is best understood through ingredients, method, and timing."
    if "productivity" in topics:
        return f"{t} is a practical habit system: clear goal, consistent action, and measurable review."
    return f"{t} is best understood as what it is, why it matters, and one practical example."
|
|
|
|
def practical_default_answer(user: str) -> str:
    """Build a generic but topic-aware answer that quotes the user's request.

    The request is whitespace-collapsed and cut to 120 characters, its topics
    are inferred, and the first matching template (coding/ML, food,
    productivity, otherwise general) is returned.
    """
    subject = re.sub(r"\s+", " ", user).strip()[:120]
    topics = infer_topics(set(normalize_for_retrieval(subject)))

    if "coding" in topics or "ml" in topics:
        return (
            f"For '{subject}', a solid path is: 1) make a minimal reproducible example, "
            "2) inspect the exact error or mismatch, 3) change one thing at a time, "
            "4) keep the change that measurably improves the result."
        )
    if "food" in topics:
        return (
            f"For '{subject}', think in three steps: prep ingredients, cook in short controlled stages, "
            "then taste and adjust seasoning at the end."
        )
    if "productivity" in topics:
        return (
            f"For '{subject}', use a simple loop: choose one measurable goal, do one focused block of work, "
            "then review what actually changed."
        )
    return (
        f"For '{subject}', break it into: 1) what you want to achieve, 2) the main constraints, "
        "and 3) one concrete next action you can take right now."
    )
|
|
|
|
def polish_reply(text: str, max_chars: int = 800) -> str:
    """Clean a reply and make it a single tidy line.

    Runs cleanup_reply(), removes whitespace before , . ! ?, collapses all
    whitespace to single spaces, cuts text longer than ``max_chars`` at a
    word boundary and appends "...", and finally makes sure non-empty output
    ends with ., ! or ?.
    """
    polished = cleanup_reply(text)
    polished = re.sub(r"\s+([,.!?])", r"\1", polished)
    polished = re.sub(r"\s+", " ", polished).strip()

    if len(polished) > max_chars:
        clipped = polished[:max_chars]
        # Prefer cutting at the last space; fall back to the hard cut.
        at_word = clipped.rsplit(" ", 1)[0].strip()
        polished = (at_word or clipped).rstrip(" ,;:") + "..."

    if polished and not polished.endswith((".", "!", "?")):
        polished += "."
    return polished
|
|
|
|
def finalize_reply(
    user: str,
    reply: str,
    last_reply_signature: str,
    safe_fallback: bool = True,
    allow_repeat: bool = False,
) -> str:
    """Turn a raw candidate reply into the text that is shown to the user.

    Args:
        user: The user's message, used to build fallback answers.
        reply: Raw candidate reply; may be empty or None.
        last_reply_signature: canonical_reply() form of the previous reply.
        safe_fallback: Allow replacing bad or repeated replies with template answers.
        allow_repeat: Skip the repeated-reply check.

    Returns:
        The polished reply. Gibberish or meta replies are replaced by
        heuristic_answer() when that is usable, else (with ``safe_fallback``)
        by practical_default_answer(). A reply identical to the previous one
        is swapped for a generic fallback variant when that variant differs.
    """
    def unusable(text):
        return likely_gibberish(text) or looks_meta_reply(text)

    candidate = cleanup_reply(reply or "") or practical_default_answer(user)

    if unusable(candidate):
        rescue = heuristic_answer(user)
        if rescue and not unusable(rescue):
            candidate = rescue
        elif safe_fallback:
            candidate = practical_default_answer(user)

    candidate = polish_reply(candidate)

    guard_repeats = safe_fallback and not allow_repeat
    if guard_repeats and canonical_reply(candidate) == last_reply_signature:
        replacement = polish_reply(generic_fallback_reply(user, variant_offset=5))
        if canonical_reply(replacement) != last_reply_signature:
            candidate = replacement
    return candidate
|
|
|
|
def safe_rule_reply(user: str) -> Optional[str]:
    """Return a canned reply when the message matches a hand-written rule.

    The message is normalized with normalize_user_text() and then tested
    against an ordered chain of regex and substring checks. The first rule
    that matches wins, so the order of the checks below is significant. Most
    checks are plain substring tests, not whole-word matches.

    Args:
        user: Raw user message.

    Returns:
        The reply of the first matching rule, or None when no rule matches.
    """
    u = normalize_user_text(user)
    # u_fixed is just a second name for the same normalized string.
    u_fixed = u


    # --- Identity / origin questions ---
    if re.search(r"\bwho\s+(made|created|built)\s+you\b", u_fixed):
        return "You did. This local Jarvis model was built and trained in your project on your laptop."
    if "why made you" in u_fixed or re.search(r"\bwhy\b.*\b(made|created|built)\b.*\byou\b", u_fixed):
        return "I was made to be your practical offline assistant for coding, learning, and everyday tasks."


    # --- Distress / noisy help requests and affection ---
    if looks_noisy_help_request(user) or ("crazy" in u_fixed and "help" in u_fixed):
        return (
            "I hear you. Take one slow breath. Tell me one thing that is going wrong right now, "
            "and I will give one clear next step."
        )
    if re.search(r"\b(i am|im|i feel)\s+(crazy|overwhelmed|stressed)\b", u_fixed):
        return (
            "I hear you. Take one slow breath. Tell me one thing that is going wrong right now, "
            "and I will give one clear next step."
        )
    if re.search(r"\bi (love|really love|like) you\b", u_fixed):
        return "Love you too. I am here for you. Tell me one thing you want help with right now."
    # --- Example lists and templates ---
    if re.search(r"\b(example|sample)\b.*\bcountry\b", u_fixed):
        return "Example countries: Japan, Brazil, Canada, Egypt, and Norway."
    if re.search(r"\b(example|sample)\b.*\bfruit\b", u_fixed):
        return "Example fruits: apple, banana, mango, orange, and grapes."
    if re.search(r"\b(example|sample)\b.*\bcit(y|ies)\b", u_fixed):
        return "Example cities: Tokyo, Paris, Cairo, Toronto, and Sao Paulo."
    if "todo list" in u_fixed or "to-do list" in u_fixed:
        return (
            "Simple to-do template: 1) top priority, 2) second priority, 3) quick task under 10 minutes, "
            "4) deadline, 5) done check."
        )
    if "daily routine" in u_fixed or "morning routine" in u_fixed:
        return (
            "Daily routine template: fixed wake time, one focused work block, one exercise block, "
            "and a short evening review."
        )
    # --- Arithmetic: runs before every rule below, so any "A op B" pattern wins over them ---
    math_reply = try_simple_math_reply(u_fixed)
    if math_reply:
        return math_reply
    # --- Food how-tos ---
    if "sandwich" in u_fixed and any(k in u_fixed for k in ["recipe", "make", "how to", "how do i"]):
        return (
            "Simple sandwich recipe: 1) toast or warm bread, 2) add protein (egg/chicken/cheese), "
            "3) add vegetables and sauce, 4) close, cut, and serve."
        )
    # NOTE(review): appears unreachable — any message containing "make a sandwich"
    # already contains both "sandwich" and "make" and is caught by the rule above.
    if "how do i make a sandwich" in u_fixed or "make a sandwich" in u_fixed:
        return "Basic sandwich: toast bread, add protein, add vegetables, add sauce, close, and cut."
    # In this and the next rule the first operand is redundant: it contains the second.
    if "how do i cook rice" in u or "cook rice" in u:
        return "Rinse rice, use 1 cup rice to 2 cups water, simmer covered 12 to 15 minutes, then rest 5 minutes."
    if "how do i make tea" in u or "make tea" in u:
        return "Boil water, steep tea for 3 to 5 minutes, remove tea, then add milk, lemon, or honey if needed."
    # --- Process management ---
    if ("kill" in u and "process" in u) and ("python" in u or "windows" in u or "powershell" in u):
        return (
            "On Windows PowerShell, list processes with `Get-Process python` and stop one with "
            "`Stop-Process -Id <PID> -Force`."
        )
    # --- About the assistant ---
    # normalize_user_text() rewrites the standalone word "u" to "you", so "who u"
    # can only match here when it is followed by more letters (e.g. "who uses").
    if "who are you" in u or "who u" in u or "what are you" in u:
        return "I am Jarvis, your practical offline assistant for coding and daily tasks."
    if "what can you do" in u or "what are your features" in u:
        return "I can help with coding, debugging, learning plans, everyday how-to questions, and task planning."
    if "can you keep answers short" in u or ("answers" in u and "short" in u):
        return "Yes. I default to concise, actionable replies."
    if "how should i ask for help" in u:
        return "Share your goal, relevant code, exact error, and constraints like time or hardware."
    # Greeting: the whole message is a greeting, or starts with one followed by a space.
    if any(
        u == g or u.startswith(g + " ")
        for g in ["hi", "hello", "hey", "yo", "greetings", "good morning", "good afternoon", "good evening", "sup", "what's up"]
    ):
        return "Hi. Give me one specific question and I will answer directly."
    # --- Time-boxed planning ---
    minutes_match = re.search(r"\b(\d+)\s*[- ]?minutes?\b", u)
    if minutes_match and ("plan" in u or "schedule" in u):
        mins = minutes_match.group(1)
        return f"Use {mins} minutes as: 10% planning, 75% execution, and 15% review with one concrete next action."
    # --- Self-improvement and general knowledge ---
    if "build confidence" in u or ("confidence" in u and ("how" in u or "improve" in u)):
        return (
            "Build confidence with a 7-day loop: 1) one small daily challenge, 2) log one win per day, "
            "3) review proof of progress weekly."
        )
    if "astronomy" in u:
        return (
            "Astronomy basics: stars, planets, gravity, and light. "
            "Start with the solar system, then learn how telescopes observe distant objects."
        )
    if "discipline" in u and ("improve" in u or "build" in u):
        return (
            "Improve discipline with one fixed daily routine: same start time, one priority task first, "
            "and a simple completion tracker."
        )
    if "apology" in u and ("message" in u or "email" in u):
        return (
            "Template: 'Sorry for the delay. I should have replied sooner. "
            "Here is the update: <one clear status line>. Next step: <specific action and date>.'"
        )
    # --- Coding and ML ---
    # "ci" is a substring test, so it also matches inside longer words.
    if ("code works locally" in u and "ci" in u) or ("fails in ci" in u) or ("fail" in u and "ci" in u):
        return (
            "CI debug checklist: lock dependency versions, match Python/OS versions, print env vars, "
            "run tests with the same command as CI, then diff failing logs."
        )
    if "recursion" in u:
        return (
            "Recursion means a function calls itself on a smaller version of the same problem "
            "until it reaches a base case that stops."
        )
    # The second operand is redundant: "30 minutes" contains "30 minute".
    if "30 minute" in u or "30 minutes" in u:
        return "Use 30 minutes as: 3 minutes plan, 22 minutes focused work, 5 minutes review and next action."
    if "c++" in u and "python" in u and ("learn" in u or "first" in u):
        return "Start with Python first for faster progress, then add C++ when you need performance or low-level control."
    # The second operand is redundant: "overfitting" contains "overfit".
    if "why" in u and ("overfit" in u or "overfitting" in u):
        return (
            "Overfitting usually means the model learned training details instead of general patterns. "
            "Common causes: too little diverse data, too many training steps, or model capacity too high."
        )
    if "traceback" in u or "error" in u or "bug" in u:
        return (
            "Debug order: 1) paste full traceback, 2) show failing code block, "
            "3) state expected behavior, 4) list recent changes."
        )
    if "machine learning" in u:
        return (
            "Machine learning is training a model from examples to make predictions. "
            "Workflow: clean data, train, validate on unseen data, then iterate."
        )
    if "learn python" in u:
        return "Learn Python with a loop: basics, short scripts, one mini project weekly, then error-driven practice."
    # --- Preferences and small talk ---
    if "favorite color" in u:
        return "I do not have personal preferences, but I can help pick colors for your project."
    if "favorite movie" in u:
        return "I do not have favorites, but I can suggest movies by genre and mood."
    if "what should i eat" in u or "lunch" in u:
        return "Quick lunch: protein + carbs + vegetables. Example: egg sandwich, fruit, and yogurt."
    # Reached for "overfitting" only when the message has no "why" (that case is handled above).
    if "overfitting" in u:
        return "Overfitting means the model memorizes training data and performs worse on new data."
    # --- Training / performance advice ---
    if "dataloader" in u or ("optimize" in u and "cpu" in u):
        return (
            "For CPU data loading: pre-tokenize once, keep tensors contiguous, avoid heavy __getitem__ logic, "
            "and reduce Python overhead per step."
        )
    if "training loop" in u or ("cleaner" in u and "train" in u):
        return "Use: zero grad, forward, loss, backward, clip, step, log; keep eval and checkpoints in helpers."
    if ("help" in u and "code" in u) or "coding help" in u:
        return "Paste the code and error, and I will give a direct fix plus a cleaner version."
    if "train" in u and "model" in u:
        return (
            "For CPU training: keep model compact, clean duplicate-heavy data, train in stages, "
            "and validate every 100 steps."
        )
    # Catch-all planning rule; substring test, so it also matches words that contain "plan".
    if "plan" in u:
        return "Plan: define one goal, do one focused block, test output, then do one short review pass."
    return None
|
|
|
|
def generic_fallback_reply(user: str, variant_offset: int = 0) -> str:
    """Pick one of three topic-aware fallback sentences for the request.

    Args:
        user: User message; whitespace-collapsed and cut to 120 characters
            before it is quoted in the reply.
        variant_offset: Added to the text-derived variant index (modulo the
            number of variants) so a caller can ask for a different wording.

    Returns:
        The selected fallback sentence. The choice is repeatable for the same
        text and offset.
    """
    subject = re.sub(r"\s+", " ", user).strip()[:120]
    topics = infer_topics(set(normalize_for_retrieval(subject)))

    if "coding" in topics or "ml" in topics:
        options = (
            f"For '{subject}', start with a minimal example, print key variables around the bug, "
            "and compare current vs expected output line by line.",
            f"To tackle '{subject}', first isolate one failing case, then change only one input, setting, or line of code at a time.",
            f"For '{subject}', write down the exact error message, locate the line that triggers it, and reason from inputs to outputs step by step.",
        )
    elif "food" in topics:
        options = (
            f"For '{subject}', choose a base (rice, pasta, bread), add one protein, and finish with vegetables plus a simple sauce.",
            f"For '{subject}', keep it simple: short prep, medium heat, and one final taste-and-adjust step before serving.",
            f"For '{subject}', decide on cooking time, pick ingredients that fit that window, and avoid more than three main steps.",
        )
    elif "productivity" in topics:
        options = (
            f"For '{subject}', define one daily action under 20 minutes that moves you forward and track it for a week.",
            f"For '{subject}', use a simple routine: same start time, one clear task, and a 2-minute review at the end.",
            f"To improve '{subject}', pick one metric you can count, one habit that affects it, and review progress every few days.",
        )
    else:
        options = (
            f"For '{subject}', think in three layers: simple explanation, key reasons it matters, and one example from daily life.",
            f"To handle '{subject}', decide what success looks like, list three small steps toward it, and start with the easiest.",
            f"For '{subject}', write down your goal in one sentence, then list obstacles and how you will handle each one.",
        )

    option_count = len(options)
    base_index = stable_variant_index(subject, option_count)
    return options[(base_index + variant_offset) % option_count]
|
|
|
|
def heuristic_answer(user: str) -> Optional[str]:
    """Build a template answer from the shape of the question.

    The message is whitespace-collapsed with trailing "?" removed, then
    checked against an ordered list of prefixes ("how do i", "what is",
    "explain", "why", ...) and topic fallbacks. The first match wins.

    Args:
        user: Raw user message.

    Returns:
        A template answer, or None when nothing matches and the message has
        fewer than four words.
    """
    cleaned = re.sub(r"\s+", " ", user).strip().rstrip("?")
    lower = cleaned.lower()
    # lower_norm additionally has typos/shorthands expanded by normalize_user_text().
    lower_norm = normalize_user_text(lower)
    tokens = set(normalize_for_retrieval(cleaned))
    topics = infer_topics(tokens)


    if looks_noisy_help_request(cleaned):
        return (
            "I can help. First, send one short sentence about the main problem, "
            "then I will give one direct fix."
        )
    if re.search(r"\bi (love|really love|like) you\b", lower_norm):
        return "Appreciate it. I am with you. What should we fix or build next?"


    # "give me (an) example of X"
    if lower_norm.startswith("give me an example of ") or lower_norm.startswith("give me example of "):
        topic = re.sub(r"^give me (an )?example of\s+", "", lower_norm, flags=re.I).strip()
        if "city" in topic:
            return "Example cities: Tokyo, Paris, Cairo, Toronto, and Sao Paulo."
        return f"Example of {topic}: start with one simple real-world case, then expand from there."


    if "recipe" in lower_norm and "sandwich" in lower_norm:
        return (
            "Simple sandwich recipe: 1) toast bread, 2) add protein, 3) add vegetables and sauce, "
            "4) close and cut."
        )


    # "how do i / how to / how can i <task>"
    if lower_norm.startswith("how do i ") or lower_norm.startswith("how to ") or lower_norm.startswith("how can i "):
        # The task keeps the user's original casing (taken from cleaned, not lower_norm),
        # so the known typos are fixed again here by plain replacement.
        task = re.sub(r"^(how do i|how to|how can i)\s+", "", cleaned, flags=re.I).strip()
        task = (
            task.replace("sandwitch", "sandwich")
            .replace("sandwhich", "sandwich")
            .replace("recipie", "recipe")
            .replace("recepie", "recipe")
        )
        mins_match = re.search(r"\b(\d+)\s*minutes?\b", lower)
        mins = mins_match.group(1) if mins_match else None
        if "food" in topics:
            if mins:
                return (
                    f"Quick way to {task} in {mins} minutes: 1) prep ingredients first, "
                    "2) cook on medium heat in short stages, 3) taste and finish."
                )
            return f"Quick way to {task}: 1) prep ingredients, 2) cook with medium heat, 3) taste and adjust, 4) serve."
        if {"coding", "ml"} & topics:
            return (
                f"To {task}: 1) reproduce on a minimal example, 2) change one variable at a time, "
                "3) measure before/after, 4) keep only changes that improve results."
            )
        return f"To {task}: set a clear outcome, split into 3 short steps, do step one now, then verify."


    # Definition-style questions are delegated to definition_stub().
    if lower.startswith("tell me about "):
        topic = re.sub(r"^tell me about\s+", "", cleaned, flags=re.I).strip()
        return definition_stub(topic, topics)


    if lower.startswith("what is "):
        # 8 == len("what is ")
        topic = cleaned[8:].strip()
        return definition_stub(topic, topics)


    # "can you explain" must be tested before the broader "can you " prefix below.
    if lower.startswith("can you explain "):
        topic = re.sub(r"^can you explain\s+", "", cleaned, flags=re.I).strip()
        return definition_stub(topic, topics)
    if lower.startswith("explain "):
        topic = re.sub(r"^explain\s+", "", cleaned, flags=re.I).strip()
        return definition_stub(topic, topics)


    if lower.startswith("can you "):
        # 8 == len("can you ")
        ask = cleaned[8:].strip()
        return f"Yes. For '{ask}', give me one concrete constraint and I will provide direct steps."


    if lower.startswith("why "):
        if "ml" in topics:
            return "Likely cause is a mismatch between data quality, model size, and training length. Share metrics and I will isolate it."
        return "Usually there is one root cause and a few contributors. Share context and I will break down cause, effect, and fix."


    if lower.startswith("should i "):
        # 9 == len("should i ")
        decision = cleaned[9:].strip()
        return f"For '{decision}', compare effort, risk, and payoff. I can give a direct recommendation if you share your constraints."


    # Any message that mentions "<N> minute(s)" gets a time-split suggestion.
    if re.search(r"\b\d+\s*minutes?\b", lower):
        mins_match = re.search(r"\b(\d+)\s*minutes?\b", lower)
        # The "30" default looks unreachable: the guard above already matched the same text.
        mins = mins_match.group(1) if mins_match else "30"
        return f"Use {mins} minutes with 10% planning, 75% execution, and 15% review so you finish with a next action."


    # Topic-only fallbacks, then the generic default for messages of four or more words.
    if {"coding", "ml"} & topics:
        return (
            f"For '{cleaned}', use this: 1) reproduce once, 2) isolate one variable, "
            "3) patch, 4) retest the same case."
        )
    if "food" in topics:
        return f"For '{cleaned}', start with ingredients, then 3 cooking steps, then timing adjustments."
    if "productivity" in topics:
        return f"For '{cleaned}', define one daily action, one trigger, and one progress check."
    if len(cleaned.split()) >= 4:
        return practical_default_answer(cleaned)
    return None
|
|
|
|
def load_retrieval_bank(path: str, max_rows: int):
    """Load "User: ... / Assistant: ..." pairs from a text file into retrieval rows.

    Args:
        path: Text file of dialogue pairs; a pair ends at a blank line
            followed by "User:" or at the end of the file.
        max_rows: Maximum number of parsed pairs to consider. The limit is
            applied before filtering, so fewer rows may be returned.

    Returns:
        A list of row dicts with keys ``user``, ``assistant``,
        ``user_tokens``, ``answer_tokens``, ``answer_words``,
        ``user_bigrams``, ``numbers``, ``topics`` and ``is_template``.
        Returns an empty list when ``path`` does not exist.
    """
    if not os.path.exists(path):
        return []
    # Read through a context manager so the file handle is closed immediately
    # instead of being left open until garbage collection.
    with open(path, "r", encoding="utf-8", errors="ignore") as handle:
        text = handle.read()
    pairs = re.findall(r"User:\s*(.*?)\nAssistant:\s*(.*?)(?=\n\nUser:|\Z)", text, flags=re.S)
    bank = []
    for user, assistant in pairs[:max_rows]:
        u = re.sub(r"\s+", " ", user).strip()
        a = re.sub(r"\s+", " ", assistant).strip()
        # Skip pairs too short to be useful.
        if len(u) < 4 or len(a) < 8:
            continue
        user_tokens_seq = normalize_for_retrieval(u)
        user_tokens = set(user_tokens_seq)
        # A question with no content tokens can never be matched.
        if not user_tokens:
            continue
        answer_tokens = set(normalize_for_retrieval(a))
        marker_hits = sum(marker in a.lower() for marker in RETRIEVAL_TEMPLATE_MARKERS)
        bank.append(
            {
                "user": u,
                "assistant": a,
                "user_tokens": user_tokens,
                "answer_tokens": answer_tokens,
                "answer_words": len(re.findall(r"[A-Za-z0-9']+", a)),
                "user_bigrams": set(zip(user_tokens_seq, user_tokens_seq[1:])),
                "numbers": extract_numeric_tokens(u),
                "topics": infer_topics(user_tokens | answer_tokens),
                # Two or more template phrases mark the answer as boilerplate.
                "is_template": marker_hits >= 2,
            }
        )
    return bank
|
|
|
|
def merge_retrieval_banks(*banks):
    """Concatenate retrieval banks, dropping duplicate question/answer pairs.

    Two rows are duplicates when their ``user`` and ``assistant`` texts are
    equal ignoring case; the first occurrence is kept and input order is
    preserved.
    """
    first_seen = {}
    for bank in banks:
        for row in bank:
            signature = (row["user"].lower(), row["assistant"].lower())
            # setdefault keeps the earliest row for each signature.
            first_seen.setdefault(signature, row)
    return list(first_seen.values())
|
|
|
|
def build_retrieval_idf(bank: List[dict]) -> Dict[str, float]:
    """Compute smoothed inverse-document-frequency weights for a bank.

    Each row counts as one document whose vocabulary is the union of its
    ``user_tokens`` and ``answer_tokens`` sets.

    Args:
        bank: Retrieval rows carrying ``user_tokens`` and ``answer_tokens``.

    Returns:
        Mapping of token to ``log((1 + N) / (1 + df)) + 1.0`` where ``N`` is
        the number of rows and ``df`` the number of rows containing the
        token. An empty bank yields an empty mapping.
    """
    if not bank:
        return {}

    doc_freq = {}
    for entry in bank:
        vocabulary = entry["user_tokens"] | entry["answer_tokens"]
        for token in vocabulary:
            doc_freq[token] = doc_freq.get(token, 0) + 1

    # The bank is non-empty here, so the row count is at least one.
    row_count = len(bank)
    weights = {}
    for token, count in doc_freq.items():
        weights[token] = math.log((1 + row_count) / (1 + count)) + 1.0
    return weights
|
|
|
|
def weighted_overlap(query_tokens: Set[str], target_tokens: Set[str], idf: Dict[str, float]) -> float:
    """Return the IDF-weighted share of query tokens found in the target.

    Args:
        query_tokens: Tokens of the incoming query.
        target_tokens: Tokens to match against.
        idf: Per-token weights; tokens missing from the map weigh ``1.0``.

    Returns:
        Matched weight divided by total query weight, or ``0.0`` when either
        token set is empty.
    """
    if not (query_tokens and target_tokens):
        return 0.0

    matched_weight = 0.0
    total_weight = 0.0
    for token in query_tokens:
        token_weight = idf.get(token, 1.0)
        total_weight += token_weight
        if token in target_tokens:
            matched_weight += token_weight

    # Floor the denominator so all-zero weights cannot divide by zero.
    return matched_weight / max(1e-9, total_weight)
|
|
|
|
def topic_alignment_bonus(query_topics: Set[str], row_topics: Set[str]) -> float:
    """Score adjustment for how well a row's topics match the query's topics.

    Returns:
        ``0.0`` when either side has no topics, ``-0.16`` when both have
        topics but share none, ``0.05`` for exactly one shared topic and
        ``0.10`` for two or more.
    """
    if not (query_topics and row_topics):
        return 0.0

    shared = len(query_topics & row_topics)
    if shared >= 2:
        return 0.10
    return 0.05 if shared == 1 else -0.16
|
|
|
|
def retrieve_reply(user: str, bank: List[dict], idf: Dict[str, float]) -> Optional[str]:
    """Pick the best stored answer for ``user`` from the retrieval bank.

    Every row sharing at least one token with the query (two when the query
    has four or more distinct tokens) is scored from weighted token overlap,
    bigram overlap and topic alignment, minus penalties for templated, overly
    long or very short answers and for mismatched numbers.

    Args:
        user: Raw user message.
        bank: Rows as produced by ``load_retrieval_bank``.
        idf: Token weights as produced by ``build_retrieval_idf``.

    Returns:
        The winning row's ``assistant`` text, or ``None`` when the bank or
        the normalized query is empty, no row clears the score threshold,
        the top two rows are too close to call, or the winner looks like
        gibberish.
    """
    if not bank:
        return None

    query_tokens_seq = normalize_for_retrieval(user)
    query_tokens = set(query_tokens_seq)
    if not query_tokens:
        return None
    query_bigrams = set(zip(query_tokens_seq, query_tokens_seq[1:]))
    query_topics = infer_topics(query_tokens)
    query_numbers = extract_numeric_tokens(user)

    # Loop invariants, computed once instead of once per row.
    query_size = len(query_tokens)
    bigram_denominator = max(1, len(query_bigrams))

    best_score = -1.0
    second_score = -1.0
    best_answer = None
    for row in bank:
        overlap_tokens = query_tokens & row["user_tokens"]
        if not overlap_tokens:
            continue
        # Longer queries need at least two shared tokens to be considered.
        if query_size >= 4 and len(overlap_tokens) < 2:
            continue

        score_user = weighted_overlap(query_tokens, row["user_tokens"], idf)
        score_answer = weighted_overlap(query_tokens, row["answer_tokens"], idf)
        score_bigrams = len(query_bigrams & row["user_bigrams"]) / bigram_denominator
        # The overlap set is query_tokens & row["user_tokens"], so weighting
        # the query against it matches exactly the same tokens as score_user.
        # Reuse that value rather than walking the query a third time.
        rare_overlap = score_user
        topic_bonus = topic_alignment_bonus(query_topics, row["topics"])
        template_penalty = 0.12 if row["is_template"] else 0.0
        # Answers longer than 240 characters are penalized proportionally.
        length_penalty = max(0.0, (len(row["assistant"]) - 240) / 240.0) * 0.08
        short_penalty = 0.10 if row["answer_words"] < 5 and query_size >= 3 else 0.0
        number_penalty = 0.0
        # Both sides mention numbers but none of them agree.
        if query_numbers and row["numbers"] and not (query_numbers & row["numbers"]):
            number_penalty = 0.12

        score = (
            0.45 * score_user
            + 0.18 * score_answer
            + 0.18 * score_bigrams
            + 0.19 * rare_overlap
            + topic_bonus
            - template_penalty
            - length_penalty
            - short_penalty
            - number_penalty
        )
        if score > best_score:
            second_score = best_score
            best_score = score
            best_answer = row["assistant"]
        elif score > second_score:
            second_score = score

    if not best_answer:
        return None
    # Very short queries must match more strongly to be trusted.
    min_threshold = 0.52 if query_size >= 3 else 0.62
    if best_score < min_threshold:
        return None
    # Reject ambiguous wins unless the best score is strong on its own.
    if second_score > 0 and (best_score - second_score) < 0.06 and best_score < 0.72:
        return None
    if likely_gibberish(best_answer):
        return None
    return best_answer
|
|
|
|
@torch.inference_mode()
def generate(
    model,
    tokenizer,
    prompt_tokens,
    max_new_tokens,
    min_new_tokens,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram,
    max_context_tokens,
    banned_token_ids,
    model_block_size: int,
):
    """Sample a reply token by token on CPU.

    Each step feeds the last ``model_block_size`` tokens to ``model`` (which
    must return a 2-tuple whose first item is the logits tensor), then
    applies, in order: temperature scaling, NaN/inf sanitizing, banned-token
    masking, repetition penalty over the last 96 tokens, n-gram blocking,
    top-k filtering and top-p filtering before sampling one token.

    Sampling stops after ``max_new_tokens`` tokens, or earlier once at least
    ``min_new_tokens`` have been produced and the decoded text contains a
    newline followed by ``User:``.

    Returns:
        ``(reply, generated)`` where ``reply`` is the decoded text passed
        through ``cleanup_reply`` and ``generated`` is the list of sampled
        token ids.
    """
    context = prompt_tokens[-max_context_tokens:]
    idx = torch.tensor(context, dtype=torch.long, device="cpu").unsqueeze(0)
    generated = []

    # These settings do not change between steps.
    safe_temperature = max(temperature, 1e-6)
    penalize_repeats = bool(repetition_penalty and repetition_penalty > 1.0)
    filter_top_k = top_k is not None and top_k > 0

    for _step in range(max_new_tokens):
        window = idx[:, -model_block_size:]
        logits, _unused = model(window)
        logits = logits[:, -1, :] / safe_temperature
        logits = torch.nan_to_num(logits, nan=-1e9, posinf=-1e9, neginf=-1e9)

        if banned_token_ids:
            logits[:, banned_token_ids] = -1e9

        if penalize_repeats:
            # Shrink positive logits and push negative ones further down for
            # every distinct token seen in the recent window.
            for token_id in set(idx[0, -96:].tolist()):
                current = logits[0, token_id]
                if current >= 0:
                    logits[0, token_id] = current / repetition_penalty
                else:
                    logits[0, token_id] = current * repetition_penalty

        ngram_blocked = blocked_tokens_for_ngram(idx[0].tolist(), no_repeat_ngram)
        if ngram_blocked:
            logits[0, list(ngram_blocked)] = -1e9

        if filter_top_k:
            kth_values, _indices = torch.topk(logits, min(top_k, logits.size(-1)))
            # Mask everything strictly below the k-th largest logit.
            logits[logits < kth_values[:, [-1]]] = -1e9

        logits = apply_top_p(logits, top_p)

        # If every token ended up masked, fall back to a uniform distribution.
        if torch.all(logits < -1e8):
            logits = torch.zeros_like(logits)

        probs = torch.softmax(logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-9)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        idx_next = torch.multinomial(probs, 1)
        idx = torch.cat([idx, idx_next], dim=1)
        generated.append(int(idx_next.item()))

        # Stop when the model starts writing the next user turn.
        if len(generated) >= min_new_tokens and "\nUser:" in tokenizer.decode(generated):
            break

    reply = cleanup_reply(tokenizer.decode(generated))
    return reply, generated
|
|
|
|
def generate_best_of_n(
    model,
    tokenizer,
    prompt_tokens: List[int],
    max_new_tokens: int,
    min_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram: int,
    max_context_tokens: int,
    banned_token_ids: List[int],
    num_candidates: int,
    model_block_size: int,
) -> Tuple[str, List[int]]:
    """Generate several candidate replies and keep the highest-scoring one.

    At least one candidate is always generated. Candidate ``i`` uses the
    average of ``temperature`` and a scheduled value cycling through
    0.45, 0.55, 0.65 and 0.75, clamped to the range 0.35 to 0.85. Candidates
    are ranked with ``response_quality_score``; on ties the earliest wins.

    Returns:
        ``(reply, token_ids)`` of the best candidate, or ``("", [])`` when
        even the best candidate scores below 0.25.
    """
    base_temperatures = (0.45, 0.55, 0.65, 0.75)
    scored = []
    for attempt in range(max(1, num_candidates)):
        scheduled = base_temperatures[attempt % len(base_temperatures)]
        blended = 0.5 * temperature + 0.5 * scheduled
        sample_temperature = max(0.35, min(0.85, blended))
        reply, generated = generate(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            temperature=sample_temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram=no_repeat_ngram,
            max_context_tokens=max_context_tokens,
            banned_token_ids=banned_token_ids,
            model_block_size=model_block_size,
        )
        scored.append((response_quality_score(reply), reply, generated))

    # Stable descending sort: equal scores keep their generation order.
    ranked = sorted(scored, key=lambda item: item[0], reverse=True)
    top_score, top_reply, top_tokens = ranked[0]
    if top_score < 0.25:
        return "", []
    return top_reply, top_tokens
|
|
|
|
# Default checkpoint names mapped to the older checkpoint that is tried when
# the requested one cannot be found. Only a single fallback step is taken.
_CHECKPOINT_FALLBACKS = {
    "cpu_gpt_jarvis_v6_guarded_best.pth": "cpu_gpt_jarvis_v5_guarded_best.pth",
    "cpu_gpt_jarvis_v5_guarded_best.pth": "cpu_gpt_jarvis_v4_mix_best.pth",
    "cpu_gpt_jarvis_v4_mix_best.pth": "cpu_gpt_jarvis_rebuild_l6_v2048_best.pth",
    "cpu_gpt_jarvis_rebuild_l6_v2048_best.pth": "cpu_gpt_jarvis_godmode_l6_v2048_best.pth",
}


def _resolve_checkpoint_path(requested: str) -> str:
    """Return an existing checkpoint path for ``requested`` or raise.

    Lookup order: the path as given, then ``<PROJECT_ROOT>/Models/<requested>``,
    then the single predecessor listed in ``_CHECKPOINT_FALLBACKS``.

    Raises:
        FileNotFoundError: If none of the candidates exists.
    """
    if os.path.exists(requested):
        return requested

    in_models_dir = os.path.join(PROJECT_ROOT, "Models", requested)
    if os.path.exists(in_models_dir):
        return in_models_dir

    fallback_name = _CHECKPOINT_FALLBACKS.get(requested)
    if fallback_name is not None:
        fallback_path = os.path.join(PROJECT_ROOT, "Models", fallback_name)
        if os.path.exists(fallback_path):
            print(f"Checkpoint not found: {requested}")
            print(f"Falling back to: {fallback_path}")
            return fallback_path

    raise FileNotFoundError(f"Checkpoint not found: {requested}")


def main():
    """Run the interactive CPU chat loop.

    Parses CLI arguments, loads the tokenizer and checkpoint (optionally
    applying dynamic INT8 quantization), optionally loads the retrieval
    banks, then reads user turns from stdin until ``exit``/``quit`` is typed
    or stdin is closed or interrupted. ``/reset`` clears the chat history.

    Raises:
        FileNotFoundError: If no usable checkpoint file is found.
        RuntimeError: If the checkpoint does not match the tokenizer
            vocabulary or the model architecture.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    torch.set_num_threads(args.threads)
    torch.set_num_interop_threads(args.interop_threads)

    tokenizer = load_tokenizer()
    vocab_size = len(tokenizer.vocab)
    print("Vocab size:", vocab_size)

    args.ckpt = _resolve_checkpoint_path(args.ckpt)

    # NOTE(security): torch.load may unpickle arbitrary objects depending on
    # the installed torch version and its weights_only default. Only load
    # checkpoints from trusted sources.
    ckpt = torch.load(args.ckpt, map_location="cpu")
    ckpt_vocab = ckpt.get("vocab_size")
    if ckpt_vocab is not None and int(ckpt_vocab) != vocab_size:
        raise RuntimeError(
            f"Checkpoint/tokenizer mismatch: ckpt vocab_size={ckpt_vocab}, tokenizer vocab_size={vocab_size}. "
            "Use the matching tokenizer or checkpoint."
        )
    cfg = config_from_dict(ckpt.get("model_config"))
    model = GPT(vocab_size, cfg=cfg).to("cpu")
    try:
        model.load_state_dict(ckpt["model"], strict=True)
    except Exception as exc:
        raise RuntimeError(
            "Checkpoint is incompatible with current model/tokenizer settings. "
            "Use a matching checkpoint such as "
            "'cpu_gpt_jarvis_rebuild_l6_v2048_best.pth'. "
            f"Original error: {exc}"
        ) from exc

    model.eval()
    print(
        f"Loaded checkpoint: step={ckpt.get('step', 'n/a')} "
        f"best_val={ckpt.get('best_val', 'n/a')}"
    )

    if args.int8:
        warnings.filterwarnings(
            "ignore",
            message="torch.ao.quantization is deprecated*",
            category=DeprecationWarning,
        )
        # Quantization is best-effort: on failure keep the FP32 model.
        try:
            model = torch.ao.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear},
                dtype=torch.qint8,
            )
            model.eval()
            print("INT8 CHAT READY")
        except Exception as exc:
            print(f"INT8 quantization skipped: {exc}")
    else:
        print("FP32 CHAT READY")

    model_block_size = int(getattr(model, "cfg", DEFAULT_CONFIG).block_size)
    # A non-positive --max_context_tokens means "use the model's block size".
    requested_ctx = int(args.max_context_tokens) if int(args.max_context_tokens) > 0 else model_block_size
    max_ctx = max(32, min(requested_ctx, model_block_size))
    banned_token_ids = collect_banned_token_ids(tokenizer, args.ban_empty_tokens)

    retrieval_bank = []
    retrieval_idf = {}
    if args.use_retrieval:
        refine_bank = []
        general_bank = []
        if os.path.exists(args.retrieval_file):
            refine_bank = load_retrieval_bank(args.retrieval_file, args.retrieval_max_rows)
        if os.path.exists(args.retrieval_file_general):
            general_bank = load_retrieval_bank(args.retrieval_file_general, args.retrieval_max_rows)
        retrieval_bank = merge_retrieval_banks(refine_bank, general_bank)
        retrieval_idf = build_retrieval_idf(retrieval_bank)
        print(
            "Retrieval bank loaded: "
            f"{len(retrieval_bank)} rows "
            f"(refine={len(refine_bank)}, general={len(general_bank)})"
        )

    # The system prompt is injected as a fake first exchange in the history.
    bootstrap = ""
    if args.system_prompt.strip():
        bootstrap = f"User: {args.system_prompt.strip()}\nAssistant: Understood.\n"
    history_tokens = tokenizer.encode(bootstrap)

    last_reply_signature = ""
    print("\nType 'exit' to quit. Use '/reset' to clear chat history.\n")

    while True:
        # Closed stdin (piped input, Ctrl-D) or Ctrl-C ends the session
        # cleanly instead of surfacing a traceback.
        try:
            user = input("\nUser: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break

        command = user.lower()
        if command in {"exit", "quit"}:
            break
        if command == "/reset":
            history_tokens = tokenizer.encode(bootstrap)
            last_reply_signature = ""
            print("\nAssistant: History cleared.")
            continue
        if not user:
            continue

        if args.safe_fallback:
            # Refusals for unsafe requests bypass the model entirely.
            blocked = unsafe_request_reply(user)
            if blocked:
                blocked = polish_reply(blocked)
                print(f"\nAssistant: {blocked}")
                turn_tokens = tokenizer.encode(f"\nUser: {user}\nAssistant: {blocked}")
                history_tokens = (history_tokens + turn_tokens)[-max_ctx:]
                last_reply_signature = canonical_reply(blocked)
                continue

            # Canned rule-based replies also skip generation.
            rule = safe_rule_reply(user)
            if rule:
                rule = finalize_reply(
                    user,
                    rule,
                    last_reply_signature,
                    args.safe_fallback,
                    allow_repeat=True,
                )
                print(f"\nAssistant: {rule}")
                turn_tokens = tokenizer.encode(f"\nUser: {user}\nAssistant: {rule}")
                history_tokens = (history_tokens + turn_tokens)[-max_ctx:]
                last_reply_signature = canonical_reply(rule)
                continue

        retrieved = None
        if args.use_retrieval:
            retrieved_candidate = retrieve_reply(user, retrieval_bank, retrieval_idf)
            if retrieved_candidate:
                # Cap retrieved context at 800 characters including "...".
                if len(retrieved_candidate) > 800:
                    retrieved_candidate = retrieved_candidate[:797].rstrip() + "..."
                retrieved = retrieved_candidate

        heuristic = heuristic_answer(user) if args.safe_fallback else None

        if retrieved:
            turn_prefix = f"\nContext: {retrieved}\nUser: {user}\nAssistant:"
        else:
            turn_prefix = f"\nUser: {user}\nAssistant:"

        prompt_tokens = (history_tokens + tokenizer.encode(turn_prefix))[-max_ctx:]

        # The sampled token ids are not reused: history stores the tokens of
        # the finalized reply text instead.
        reply, _sampled_tokens = generate_best_of_n(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_new_tokens=args.max_new_tokens,
            min_new_tokens=args.min_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            no_repeat_ngram=args.no_repeat_ngram,
            max_context_tokens=max_ctx,
            banned_token_ids=banned_token_ids,
            num_candidates=args.num_candidates,
            model_block_size=model_block_size,
        )

        if (not reply or likely_gibberish(reply)) and args.safe_fallback:
            # A rule reply would already have ended this turn above, so only
            # the retrieved, heuristic and generic fallbacks remain.
            reply = retrieved or heuristic or generic_fallback_reply(user)
        elif not reply:
            reply = "I need more context. Please restate your request in one sentence."
        reply = finalize_reply(user, reply, last_reply_signature, args.safe_fallback)
        generated_tokens = tokenizer.encode(reply)

        print(f"\nAssistant: {reply}")

        history_tokens = (prompt_tokens + generated_tokens)[-max_ctx:]
        last_reply_signature = canonical_reply(reply)
|
|
|
|
# Start the interactive chat only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|