Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from dataclasses import make_dataclass | |
| from tokenizers import Tokenizer, decoders | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
# Tokenizer and checkpoint both live in ./saved_models next to this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
TOKEN_PATH = os.path.join(BASE_DIR, "saved_models", "bpe-tokenizer.json")

# Checkpoint to serve -- update this to your latest best model name!
MODEL_PATH = os.path.join(
    MODEL_DIR,
    "gemma_3_1b_distill_qwen_2.5_0.5b_distill_aai_mini+_new_260322.pth",
)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

SEQ_LENGTH = 256

# Architecture hyper-parameters for the 384-dimension model; they must match
# the checkpoint. VOCAB_SIZE is intentionally absent: it is added later from
# the tokenizer's vocabulary size. Key order is preserved (dict keeps
# insertion order), which matters because the fields of the config dataclass
# are built from these items in order.
MODEL_CONFIG_PARAMS = dict(
    MODEL_DIM=384,
    NUM_HEADS=6,
    NUM_LAYERS=6,
    DROPOUT=0.1,
    SEQ_LENGTH=SEQ_LENGTH,
)

print(f"INFO: Using device: {DEVICE}")
| # ============================================================================== | |
| # 2. AI MODEL ARCHITECTURE | |
| # ============================================================================== | |
class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer decoder block: causal self-attention + GELU MLP.

    Args:
        d_model: model (embedding) width.
        nhead: number of attention heads; must divide ``d_model``.
        dropout: dropout probability applied after the attention output
            projection and at the end of the feed-forward MLP.
    """

    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        # Sub-modules are created in the same order as before so that random
        # initialisation and state_dict key names are unchanged.
        # --- attention sub-layer ---
        self.ln1 = nn.LayerNorm(d_model)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)  # fused Q, K, V
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout1 = nn.Dropout(dropout)
        # --- feed-forward sub-layer (4x hidden expansion) ---
        hidden_width = 4 * d_model
        self.ln2 = nn.LayerNorm(d_model)
        self.ffwd = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, T, D); the output has the same shape."""
        batch, steps, width = x.shape
        head_dim = width // self.nhead

        def split_heads(t):
            # (B, T, D) -> (B, nhead, T, head_dim)
            return t.view(batch, steps, self.nhead, head_dim).transpose(1, 2)

        q, k, v = map(split_heads, self.qkv_proj(self.ln1(x)).chunk(3, dim=-1))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, nhead, T, head_dim) -> (B, T, D)
        attended = attended.transpose(1, 2).contiguous().view(batch, steps, width)
        x = x + self.dropout1(self.out_proj(attended))
        return x + self.ffwd(self.ln2(x))
class GPTTransformerLM(nn.Module):
    """Decoder-only transformer LM with learned positional embeddings.

    ``config`` must provide VOCAB_SIZE, MODEL_DIM, SEQ_LENGTH, DROPOUT,
    NUM_HEADS and NUM_LAYERS. The token embedding shares its weight matrix
    with the output projection (weight tying).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        vocab, dim, max_len = config.VOCAB_SIZE, config.MODEL_DIM, config.SEQ_LENGTH
        # Creation order is preserved so random init and state_dict keys match.
        self.tok_emb = nn.Embedding(vocab, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        self.dropout = nn.Dropout(config.DROPOUT)
        self.layers = nn.ModuleList(
            DecoderBlock(dim, config.NUM_HEADS, config.DROPOUT)
            for _ in range(config.NUM_LAYERS)
        )
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab, bias=False)
        # Weight tying: the embedding reuses the output projection's matrix.
        self.tok_emb.weight = self.head.weight
        # Position indices 0..SEQ_LENGTH-1, kept as a buffer so they move with
        # the module on .to(device).
        self.register_buffer("pos", torch.arange(max_len, dtype=torch.long))

    def forward(self, x):
        """Map token ids ``x`` of shape (B, T) to logits of shape (B, T, VOCAB_SIZE).

        T must not exceed ``config.SEQ_LENGTH``; positions are taken from the
        ``pos`` buffer, which only has SEQ_LENGTH entries.
        """
        _, seq_len = x.size()
        hidden = self.tok_emb(x) + self.pos_emb(self.pos[:seq_len])
        hidden = self.dropout(hidden)
        for block in self.layers:
            hidden = block(x=hidden)
        return self.head(self.ln_f(hidden))
| # ============================================================================== | |
| # 3. NANO-REASONING GENERATION LOGIC | |
| # ============================================================================== | |
| def generate_full_response(model, tokenizer, prompt, max_new_tokens=256, temp=0.45, top_k=50, top_p=0.9, repetition_penalty=1.15, nano_samples=25): | |
| model.eval() | |
| prompt_ids = tokenizer.encode(prompt).ids | |
| gen_ids = prompt_ids[:] | |
| x = torch.tensor(prompt_ids, dtype=torch.long, device=DEVICE).unsqueeze(0) | |
| # --- PROTECTED TOKENS --- | |
| protected_tokens = set() | |
| if getattr(tokenizer, 'token_to_id', None): | |
| eos = tokenizer.token_to_id('[EOS]') | |
| if eos is not None: protected_tokens.add(eos) | |
| for char in ['\n', '\n\n', '.', ',', '!', '?', ' ', ' \n', '"', "'", ':', ';']: | |
| protected_tokens.update(tokenizer.encode(char).ids) | |
| bf16_supported = (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) | |
| autocast_dtype = torch.bfloat16 if bf16_supported else torch.float16 | |
| autocast_kwargs = {'device_type': DEVICE.type, 'dtype': autocast_dtype, 'enabled': (DEVICE.type == 'cuda')} | |
| hard_limit = max_new_tokens + 100 | |
| for step in range(hard_limit): | |
| x_cond = x if x.size(1) <= model.config.SEQ_LENGTH else x[:, -model.config.SEQ_LENGTH:] | |
| with torch.autocast(**autocast_kwargs): | |
| logits = model(x_cond)[:, -1, :] | |
| # 1. DYNAMIC Frequency Penalty | |
| if repetition_penalty != 1.0: | |
| penalty_window = gen_ids[-128:] | |
| for token_id in set(penalty_window): | |
| if token_id not in protected_tokens: | |
| count = penalty_window.count(token_id) | |
| dynamic_penalty = repetition_penalty + (count * 0.03) | |
| if logits[0, token_id] < 0: | |
| logits[0, token_id] *= dynamic_penalty | |
| else: | |
| logits[0, token_id] /= dynamic_penalty | |
| # 2. Temperature & Sampling | |
| if temp > 0.0: | |
| logits = logits / temp | |
| if top_k > 0: | |
| v, _ = torch.topk(logits, min(top_k, logits.size(-1))) | |
| logits[logits < v[:, [-1]]] = -float('Inf') | |
| if top_p > 0.0 and top_p < 1.0: | |
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |
| cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |
| sorted_indices_to_remove = cumulative_probs > top_p | |
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |
| sorted_indices_to_remove[..., 0] = 0 | |
| indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | |
| logits[indices_to_remove] = -float('Inf') | |
| probs = F.softmax(logits, dim=-1) | |
| # --- NANO-LOOKAHEAD REASONING --- | |
| simulations = torch.multinomial(probs, num_samples=nano_samples, replacement=True) | |
| unique_ids, counts = torch.unique(simulations[0], return_counts=True) | |
| if len(unique_ids) == 1: | |
| next_id = unique_ids[0].item() | |
| else: | |
| sorted_indices = torch.argsort(counts, descending=True) | |
| top_2_ids = unique_ids[sorted_indices[:2]] | |
| best_score = -1.0 | |
| best_id = top_2_ids[0].item() | |
| for candidate_id in top_2_ids: | |
| candidate_prob = probs[0, candidate_id].item() | |
| if x.size(1) < model.config.SEQ_LENGTH: | |
| temp_x = torch.cat((x, torch.tensor([[candidate_id]], device=DEVICE)), dim=1) | |
| temp_x_cond = temp_x if temp_x.size(1) <= model.config.SEQ_LENGTH else temp_x[:, -model.config.SEQ_LENGTH:] | |
| with torch.autocast(**autocast_kwargs): | |
| next_logits = model(temp_x_cond)[:, -1, :] | |
| next_max_prob = torch.max(F.softmax(next_logits / temp, dim=-1)).item() | |
| score = candidate_prob * next_max_prob | |
| if score > best_score: | |
| best_score = score | |
| best_id = candidate_id.item() | |
| next_id = best_id | |
| # --------------------------- | |
| else: | |
| next_id = torch.argmax(logits, dim=-1).item() | |
| gen_ids.append(next_id) | |
| x = torch.cat((x, torch.tensor([[next_id]], device=DEVICE)), dim=1) | |
| # --- STOP CONDITIONS --- | |
| # Stop on double newline | |
| tail_text = tokenizer.decode(gen_ids[-8:]) | |
| if "\n\n" in tail_text.replace(" ", "").replace("\r", ""): | |
| break | |
| # Stop on User turn hallucination | |
| if "\nUser:" in tail_text or "User:" in tail_text.replace("\n", "")[-6:]: | |
| break | |
| # Stop on EOS token | |
| if next_id == getattr(tokenizer, 'token_to_id', lambda x: -1)('[EOS]'): | |
| break | |
| # Graceful Sentence Cutoff | |
| if step >= max_new_tokens: | |
| new_text = tokenizer.decode([next_id]) | |
| if any(punct in new_text for punct in ['.', '!', '?', '\n']): | |
| break | |
| # 🚀 FIXED: Return the full, unbroken string (Prompt + Continuation) | |
| return tokenizer.decode(gen_ids) | |
| # ============================================================================== | |
| # 4. LOAD MODEL & TOKENIZER ON STARTUP | |
| # ============================================================================== | |
try:
    print("INFO: Loading tokenizer...")
    tokenizer = Tokenizer.from_file(TOKEN_PATH)
    vocab_size = tokenizer.get_vocab_size()
    MODEL_CONFIG_PARAMS['VOCAB_SIZE'] = vocab_size
    print(f"INFO: Tokenizer loaded. Vocab size: {vocab_size}")

    # Config dataclass whose fields (name, type) mirror MODEL_CONFIG_PARAMS.
    TuneConfig = make_dataclass("TuneConfig", [(k, type(v)) for k, v in MODEL_CONFIG_PARAMS.items()])
    # Publish it as __main__.TuneConfig -- presumably so a checkpoint pickled
    # by a training script run as __main__ can be unpickled. TODO confirm.
    setattr(sys.modules['__main__'], 'TuneConfig', TuneConfig)
    model_config_instance = TuneConfig(**MODEL_CONFIG_PARAMS)
    model = GPTTransformerLM(model_config_instance).to(DEVICE)

    print(f"INFO: Loading checkpoint dictionary from: {MODEL_PATH}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # ever point MODEL_PATH at checkpoint files you trust.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    # Accept either a bare state_dict or a training checkpoint that wraps it.
    state_dict = checkpoint
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    # Drop the "_orig_mod." key prefix (presumably added when the model was
    # wrapped by torch.compile during training).
    cleaned_state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}

    # strict=False tolerates key mismatches; report them so a wrong checkpoint
    # cannot silently leave randomly initialised weights in place.
    incompatible = model.load_state_dict(cleaned_state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"WARNING: {len(incompatible.missing_keys)} key(s) missing from checkpoint, "
              f"e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"WARNING: {len(incompatible.unexpected_keys)} unexpected key(s) in checkpoint, "
              f"e.g. {incompatible.unexpected_keys[:5]}")
    model.eval()
    print("✅ INFO: Backend Nano-Reasoning Model loaded and ready for API requests!")
except Exception as e:
    print(f"❌ FATAL ERROR during model loading: {e}")
    # Non-zero status so process supervisors see the failure; the previous
    # bare exit() reported success (status 0) and relied on the site module.
    sys.exit(1)
| # ============================================================================== | |
| # 5. FLASK SERVER APPLICATION | |
| # ============================================================================== | |
# The Flask application object. CORS is enabled app-wide with flask_cors
# defaults (no origin restriction is passed here).
app = Flask(import_name=__name__)
CORS(app=app)
def handle_generation():
    """Handle a generation request whose JSON body is ``{"prompt": "<text>"}``.

    Returns:
        ``{"response": <text>}`` with the output of ``generate_full_response``.
        A JSON error with status 400 when the body is not a JSON object or
        ``prompt`` is missing, empty, or not a string. A JSON error with
        status 500 if generation raises.

    NOTE(review): no route decorator is visible on this function; confirm it
    is registered with ``app`` (the health-check text points clients at
    ``/generate``).
    """
    # silent=True returns None for malformed / non-JSON bodies instead of
    # raising. The isinstance check covers JSON that is not an object (null,
    # a list, ...): previously ``data.get`` raised there, outside the try.
    data = request.get_json(silent=True)
    prompt = data.get("prompt") if isinstance(data, dict) else None
    if not prompt:
        return jsonify({"error": "Missing 'prompt'"}), 400
    if not isinstance(prompt, str):
        # Previously reached the tokenizer and surfaced as a 500.
        return jsonify({"error": "'prompt' must be a string"}), 400
    try:
        print(f"INFO: Received prompt: '{prompt}'")
        # Generate clean response using hidden Nano-Reasoning math
        generated_text = generate_full_response(model, tokenizer, prompt)
        print(f"INFO: Generated response: '{generated_text}'")
        return jsonify({"response": generated_text})
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"❌ ERROR during generation: {e}")
        return jsonify({"error": "Failed to generate response"}), 500
def health_check():
    """Return a plain-text liveness message pointing clients at ``/generate``.

    NOTE(review): no route decorator is visible on this function; confirm it
    is registered with ``app``.
    """
    message = (
        "Nano-Reasoning Backend server is running. "
        "Use the /generate endpoint to interact with the AI."
    )
    return message