import os
import re
import sys
from dataclasses import make_dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer, decoders
from flask import Flask, request, jsonify
from flask_cors import CORS

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
# Every path is resolved relative to this source file, not the working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
TOKEN_PATH = os.path.join(MODEL_DIR, "bpe-tokenizer.json")
# Point this at the checkpoint you want to serve (update it when a newer best model is saved).
MODEL_PATH = os.path.join(
    MODEL_DIR,
    "gemma_3_1b_distill_qwen_2.5_0.5b_distill_aai_mini+_new_260322.pth",
)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Context window; it sizes the positional-embedding table, so it has to agree with
# the checkpoint being loaded.
SEQ_LENGTH = 256

# Hyper-parameters of the 384-dimension architecture. VOCAB_SIZE is appended at
# startup once the tokenizer has been read.
MODEL_CONFIG_PARAMS = dict(
    MODEL_DIM=384,
    NUM_HEADS=6,
    NUM_LAYERS=6,
    DROPOUT=0.1,
    SEQ_LENGTH=SEQ_LENGTH,
)

print(f"INFO: Using device: {DEVICE}")
# ==============================================================================
# 2. AI MODEL ARCHITECTURE
# ==============================================================================
class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer layer with causal multi-head self-attention.

    Computes ``x + Dropout(out_proj(Attn(ln1(x))))`` and then
    ``x + ffwd(ln2(x))``, where ``ffwd`` is Linear(4x) -> GELU -> Linear -> Dropout.

    Args:
        d_model: Width of the residual stream.
        nhead: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability used after attention and inside ``ffwd``.
    """

    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        # Keep this construction order: attribute names are the state_dict keys the
        # checkpoints use, and the order fixes the RNG sequence used at initialization.
        self.ln1 = nn.LayerNorm(d_model)
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(d_model)
        hidden_width = d_model * 4
        self.ffwd = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, seq, d_model); the shape is preserved."""
        batch, seq_len, width = x.shape
        head_dim = width // self.nhead
        packed = self.qkv_proj(self.ln1(x))
        # (B, T, 3D) -> (3, B, H, T, head_dim): the same split as chunk(3) followed by a
        # per-head view + transpose, expressed as one reshape/permute.
        q, k, v = (
            packed.reshape(batch, seq_len, 3, self.nhead, head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, width)
        x = x + self.dropout1(self.out_proj(merged))
        return x + self.ffwd(self.ln2(x))


class GPTTransformerLM(nn.Module):
    """Decoder-only language model with learned absolute positions and tied embeddings.

    ``config`` must provide VOCAB_SIZE, MODEL_DIM, SEQ_LENGTH, DROPOUT, NUM_HEADS
    and NUM_LAYERS attributes.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        dim = config.MODEL_DIM
        self.tok_emb = nn.Embedding(config.VOCAB_SIZE, dim)
        self.pos_emb = nn.Embedding(config.SEQ_LENGTH, dim)
        self.dropout = nn.Dropout(config.DROPOUT)
        stack = [DecoderBlock(dim, config.NUM_HEADS, config.DROPOUT) for _ in range(config.NUM_LAYERS)]
        self.layers = nn.ModuleList(stack)
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, config.VOCAB_SIZE, bias=False)
        # Weight tying: the token embedding reuses the output projection's matrix.
        self.tok_emb.weight = self.head.weight
        # Position ids 0..SEQ_LENGTH-1, kept as a buffer so .to(device) moves them too.
        self.register_buffer('pos', torch.arange(config.SEQ_LENGTH, dtype=torch.long))

    def forward(self, x):
        """Map token ids of shape (B, T) to logits of shape (B, T, VOCAB_SIZE)."""
        _, seq_len = x.size()
        hidden = self.dropout(self.tok_emb(x) + self.pos_emb(self.pos[:seq_len]))
        for block in self.layers:
            hidden = block(x=hidden)
        return self.head(self.ln_f(hidden))
# ==============================================================================
# 3. NANO-REASONING GENERATION LOGIC
# ==============================================================================
_PENALTY_WINDOW = 128        # number of most recent ids the frequency penalty inspects
_PENALTY_PER_REPEAT = 0.03   # extra penalty per occurrence of an id inside that window
_STOP_TAIL_TOKENS = 8        # number of generated ids decoded for stop-string checks
# Strings whose token ids are exempt from the frequency penalty.
_PROTECTED_STRINGS = ['\n', '\n\n', '.', ',', '!', '?', ' ', ' \n', '"', "'", ':', ';']


def _eos_token_id(tokenizer):
    """Return the id of '[EOS]', or None if the tokenizer has no such token/lookup."""
    token_to_id = getattr(tokenizer, 'token_to_id', None)
    return token_to_id('[EOS]') if token_to_id is not None else None


def _protected_token_ids(tokenizer, eos_id):
    """Return the set of ids (EOS, punctuation, whitespace) the frequency penalty must skip."""
    protected = set()
    if eos_id is not None:
        protected.add(eos_id)
    for text in _PROTECTED_STRINGS:
        protected.update(tokenizer.encode(text).ids)
    return protected


def _apply_frequency_penalty(logits, history, protected, base_penalty):
    """Penalize ids seen in the last ``_PENALTY_WINDOW`` entries of ``history``.

    The penalty for an id is ``base_penalty + count * _PENALTY_PER_REPEAT``. Negative
    logits are multiplied and non-negative ones divided, so both lose probability.
    Mutates and returns ``logits`` (shape (1, vocab)).
    """
    from collections import Counter

    counts = Counter(t for t in history[-_PENALTY_WINDOW:] if t not in protected)
    if not counts:
        return logits
    # One gather/scatter instead of a Python loop with a device sync per id.
    ids = torch.tensor(list(counts), dtype=torch.long, device=logits.device)
    penalties = torch.tensor(
        [base_penalty + n * _PENALTY_PER_REPEAT for n in counts.values()],
        dtype=logits.dtype,
        device=logits.device,
    )
    selected = logits[0, ids]
    logits[0, ids] = torch.where(selected < 0, selected * penalties, selected / penalties)
    return logits


def _filter_top_k_top_p(logits, top_k, top_p):
    """Mask (set to -inf) everything outside the top-k ids and the top-p nucleus."""
    if top_k > 0:
        kth_value = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
        logits[logits < kth_value] = -float('Inf')
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses top_p is kept; always keep the best one.
        to_remove[..., 1:] = to_remove[..., :-1].clone()
        to_remove[..., 0] = False
        logits[to_remove.scatter(1, sorted_indices, to_remove)] = -float('Inf')
    return logits


def _nano_lookahead_choice(model, x, probs, temp, nano_samples, autocast_kwargs):
    """Pick the next id by sampling plus one step of lookahead between the top-2 samples.

    Draws ``nano_samples`` ids (at least one) from ``probs``. If they all agree, that id
    wins. Otherwise the two most frequent ids are scored as
    ``p(candidate) * max p(next | context + candidate)``; on a tie the more frequent one wins.
    """
    seq_limit = model.config.SEQ_LENGTH
    simulations = torch.multinomial(probs, num_samples=max(1, nano_samples), replacement=True)
    unique_ids, counts = torch.unique(simulations[0], return_counts=True)
    if unique_ids.numel() == 1:
        return unique_ids[0].item()
    candidates = unique_ids[torch.argsort(counts, descending=True)[:2]]
    # Score both candidates in one batched forward pass. Cropping to the context window
    # keeps lookahead active even after the context is full.
    branches = torch.cat((x.expand(candidates.numel(), -1), candidates.unsqueeze(1)), dim=1)
    with torch.autocast(**autocast_kwargs):
        next_logits = model(branches[:, -seq_limit:])[:, -1, :].float()
    next_max_prob = F.softmax(next_logits / temp, dim=-1).max(dim=-1).values
    scores = probs[0, candidates] * next_max_prob
    # torch.argmax returns the first maximal index, i.e. ties favour the more frequent id.
    return candidates[torch.argmax(scores)].item()


def _should_stop(tokenizer, generated, next_id, eos_id, step, max_new_tokens):
    """Return True if generation should end after ``next_id`` was appended.

    Only the generated continuation is inspected, so stop strings already present in
    the prompt (an earlier "User:" turn, a blank line) cannot end generation at once.
    """
    tail_text = tokenizer.decode(generated[-_STOP_TAIL_TOKENS:])
    # Blank line (spaces / carriage returns between the newlines are ignored).
    if "\n\n" in tail_text.replace(" ", "").replace("\r", ""):
        return True
    # The model started writing the next user turn.
    if "\nUser:" in tail_text or "User:" in tail_text.replace("\n", "")[-6:]:
        return True
    if eos_id is not None and next_id == eos_id:
        return True
    # Past the soft budget: stop at the first sentence-ending token.
    if step >= max_new_tokens:
        new_text = tokenizer.decode([next_id])
        if any(punct in new_text for punct in ['.', '!', '?', '\n']):
            return True
    return False


@torch.inference_mode()
def generate_full_response(model, tokenizer, prompt, max_new_tokens=256, temp=0.45, top_k=50,
                           top_p=0.9, repetition_penalty=1.15, nano_samples=25):
    """Generate a continuation of ``prompt`` and return prompt + continuation as one string.

    Args:
        model: Language model returning (B, T, vocab) logits and exposing ``config.SEQ_LENGTH``.
        tokenizer: Object with ``encode(text).ids`` and ``decode(ids)`` (optionally ``token_to_id``).
        prompt: Input text.
        max_new_tokens: Soft budget. After it is reached, generation stops at the next
            sentence-ending token; it never exceeds ``max_new_tokens + 100`` tokens.
        temp: Sampling temperature. ``<= 0`` means greedy argmax (no top-k/top-p/lookahead).
        top_k: Keep only the k most likely ids (``<= 0`` disables).
        top_p: Nucleus threshold in (0, 1); other values disable it.
        repetition_penalty: Base frequency penalty over the last 128 ids (prompt included);
            EOS, punctuation and whitespace ids are exempt. ``1.0`` disables it.
        nano_samples: Number of draws used by the lookahead vote.

    Returns:
        The decoded prompt plus generated tokens, including the token that triggered a stop.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    model.eval()
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("Prompt encodes to zero tokens; at least one token is required.")
    gen_ids = list(prompt_ids)
    x = torch.tensor(prompt_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
    seq_limit = model.config.SEQ_LENGTH

    eos_id = _eos_token_id(tokenizer)
    protected_tokens = _protected_token_ids(tokenizer, eos_id)

    bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    autocast_kwargs = {
        'device_type': DEVICE.type,
        'dtype': torch.bfloat16 if bf16_supported else torch.float16,
        'enabled': DEVICE.type == 'cuda',
    }

    hard_limit = max_new_tokens + 100
    for step in range(hard_limit):
        with torch.autocast(**autocast_kwargs):
            logits = model(x[:, -seq_limit:])[:, -1, :]
        # Autocast may yield fp16/bf16 logits; sample in float32 for a stable top-p cumsum.
        logits = logits.float()

        if repetition_penalty != 1.0:
            logits = _apply_frequency_penalty(logits, gen_ids, protected_tokens, repetition_penalty)

        if temp > 0.0:
            logits = _filter_top_k_top_p(logits / temp, top_k, top_p)
            probs = F.softmax(logits, dim=-1)
            next_id = _nano_lookahead_choice(model, x, probs, temp, nano_samples, autocast_kwargs)
        else:
            next_id = torch.argmax(logits, dim=-1).item()

        gen_ids.append(next_id)
        x = torch.cat((x, torch.tensor([[next_id]], dtype=torch.long, device=DEVICE)), dim=1)

        if _should_stop(tokenizer, gen_ids[len(prompt_ids):], next_id, eos_id, step, max_new_tokens):
            break

    # Return the full, unbroken string (prompt + continuation).
    return tokenizer.decode(gen_ids)


# ==============================================================================
# 4. LOAD MODEL & TOKENIZER ON STARTUP
# ==============================================================================
try:
    print("INFO: Loading tokenizer...")
    tokenizer = Tokenizer.from_file(TOKEN_PATH)
    vocab_size = tokenizer.get_vocab_size()
    MODEL_CONFIG_PARAMS['VOCAB_SIZE'] = vocab_size
    print(f"INFO: Tokenizer loaded. Vocab size: {vocab_size}")

    TuneConfig = make_dataclass("TuneConfig", [(k, type(v)) for k, v in MODEL_CONFIG_PARAMS.items()])
    # NOTE(review): presumably exposed on __main__ so checkpoints that pickled a
    # __main__.TuneConfig can be unpickled — confirm against the training script.
    setattr(sys.modules['__main__'], 'TuneConfig', TuneConfig)
    model_config_instance = TuneConfig(**MODEL_CONFIG_PARAMS)
    model = GPTTransformerLM(model_config_instance).to(DEVICE)

    print(f"INFO: Loading checkpoint dictionary from: {MODEL_PATH}")
    # SECURITY: weights_only=False uses the full pickle loader; only load trusted checkpoints.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    state_dict = checkpoint
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    # Strip the '_orig_mod.' prefix that torch.compile adds to parameter names.
    cleaned_state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    load_result = model.load_state_dict(cleaned_state_dict, strict=False)
    # strict=False skips mismatched names silently; report them so a wrong checkpoint is visible.
    if load_result.missing_keys:
        print(f"WARNING: Keys missing from checkpoint (kept at initial values): {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"WARNING: Unexpected checkpoint keys (ignored): {load_result.unexpected_keys}")
    model.eval()
    print("✅ INFO: Backend Nano-Reasoning Model loaded and ready for API requests!")
except Exception as e:
    print(f"❌ FATAL ERROR during model loading: {e}")
    # Non-zero status so process supervisors see the failure (bare exit() returned 0).
    sys.exit(1)
# ==============================================================================
# 5. FLASK SERVER APPLICATION
# ==============================================================================
app = Flask(__name__)
CORS(app)


@app.route("/generate", methods=["POST"])
def handle_generation():
    """Generate text for a JSON body of the form ``{"prompt": "<text>"}``.

    Returns ``{"response": <prompt + continuation>}`` on success. Returns 400 when the
    body is not a JSON object or the prompt is missing/empty/not a string, and 500 if
    generation raises.
    """
    # silent=True yields None (instead of raising) for a missing or malformed JSON body;
    # the isinstance check rejects JSON that is not an object (list, string, null, ...).
    data = request.get_json(silent=True)
    prompt = data.get("prompt") if isinstance(data, dict) else None
    if not prompt:
        return jsonify({"error": "Missing 'prompt'"}), 400
    if not isinstance(prompt, str):
        return jsonify({"error": "'prompt' must be a string"}), 400

    try:
        print(f"INFO: Received prompt: '{prompt}'")
        # Generate clean response using hidden Nano-Reasoning math
        generated_text = generate_full_response(model, tokenizer, prompt)
        print(f"INFO: Generated response: '{generated_text}'")
        return jsonify({"response": generated_text})
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"❌ ERROR during generation: {e}")
        return jsonify({"error": "Failed to generate response"}), 500


@app.route("/")
def health_check():
    """Plain-text liveness message."""
    return "Nano-Reasoning Backend server is running. Use the /generate endpoint to interact with the AI."