# scratch-ai / server.py
# 16dvnk's picture
# Update server.py
# 725a466 verified
import os
import re
import sys
from collections import Counter
from dataclasses import make_dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS
from tokenizers import Tokenizer, decoders
# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
# Every artefact lives in a "saved_models" directory next to this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
TOKEN_PATH = os.path.join(MODEL_DIR, "bpe-tokenizer.json")
# Point this at the newest / best checkpoint before deploying.
MODEL_PATH = os.path.join(
    MODEL_DIR,
    "gemma_3_1b_distill_qwen_2.5_0.5b_distill_aai_mini+_new_260322.pth",
)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Context window (tokens) of the model; also the positional-embedding table size.
SEQ_LENGTH = 256
# Hyper-parameters of the 384-dimension architecture. Key order matters: it becomes
# the field order of the TuneConfig dataclass built at startup (VOCAB_SIZE is added then).
MODEL_CONFIG_PARAMS = dict(
    MODEL_DIM=384,
    NUM_HEADS=6,
    NUM_LAYERS=6,
    DROPOUT=0.1,
    SEQ_LENGTH=SEQ_LENGTH,
)
print(f"INFO: Using device: {DEVICE}")
# ==============================================================================
# 2. AI MODEL ARCHITECTURE
# ==============================================================================
class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer decoder block: causal self-attention, then an MLP.

    The submodule attribute names (ln1, qkv_proj, out_proj, dropout1, ln2, ffwd) are
    the state_dict keys of saved checkpoints, so they must not be renamed.

    Args:
        d_model: width of the residual stream.
        nhead: number of attention heads; must divide ``d_model``.
        dropout: dropout probability after the attention projection and inside the MLP.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dropout):
        super().__init__()
        # Raise instead of assert so the check still runs under `python -O`.
        if d_model % nhead != 0:
            raise ValueError("d_model must be divisible by nhead")
        self.d_model, self.nhead = d_model, nhead
        self.ln1 = nn.LayerNorm(d_model)
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffwd = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, T, D); returns a tensor of the same shape."""
        B, T, D = x.shape
        head_dim = D // self.nhead
        q, k, v = self.qkv_proj(self.ln1(x)).chunk(3, dim=-1)
        # (B, T, D) -> (B, nhead, T, head_dim), the layout scaled_dot_product_attention expects.
        q, k, v = (t.view(B, T, self.nhead, head_dim).transpose(1, 2) for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).contiguous().view(B, T, D)
        x = x + self.dropout1(self.out_proj(attn))
        return x + self.ffwd(self.ln2(x))
class GPTTransformerLM(nn.Module):
    """Decoder-only transformer LM with learned positions and tied embedding / output weights.

    Args:
        config: object exposing VOCAB_SIZE, MODEL_DIM, NUM_HEADS, NUM_LAYERS, DROPOUT
            and SEQ_LENGTH.

    Attribute names (tok_emb, pos_emb, layers, ln_f, head, pos) are checkpoint keys;
    do not rename them.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.VOCAB_SIZE, config.MODEL_DIM)
        self.pos_emb = nn.Embedding(config.SEQ_LENGTH, config.MODEL_DIM)
        self.dropout = nn.Dropout(config.DROPOUT)
        self.layers = nn.ModuleList(
            [DecoderBlock(config.MODEL_DIM, config.NUM_HEADS, config.DROPOUT) for _ in range(config.NUM_LAYERS)]
        )
        self.ln_f = nn.LayerNorm(config.MODEL_DIM)
        self.head = nn.Linear(config.MODEL_DIM, config.VOCAB_SIZE, bias=False)
        # Tie weights: the token embedding and the LM head share one Parameter.
        self.tok_emb.weight = self.head.weight
        # Position indices 0..SEQ_LENGTH-1, sliced to the input length in forward().
        self.register_buffer('pos', torch.arange(config.SEQ_LENGTH, dtype=torch.long))

    def forward(self, x):
        """Map token ids ``x`` of shape (B, T) to logits of shape (B, T, VOCAB_SIZE).

        Raises:
            ValueError: if T exceeds ``config.SEQ_LENGTH`` (there is no positional
                embedding for those positions); callers must trim the context first.
        """
        _, T = x.size()
        if T > self.config.SEQ_LENGTH:
            raise ValueError(
                f"sequence length {T} exceeds SEQ_LENGTH={self.config.SEQ_LENGTH}; "
                "trim the input to its last SEQ_LENGTH tokens"
            )
        h = self.dropout(self.tok_emb(x) + self.pos_emb(self.pos[:T]))
        for layer in self.layers:
            h = layer(x=h)
        return self.head(self.ln_f(h))
# ==============================================================================
# 3. NANO-REASONING GENERATION LOGIC
# ==============================================================================
@torch.inference_mode()
def generate_full_response(model, tokenizer, prompt, max_new_tokens=256, temp=0.45, top_k=50, top_p=0.9, repetition_penalty=1.15, nano_samples=25):
model.eval()
prompt_ids = tokenizer.encode(prompt).ids
gen_ids = prompt_ids[:]
x = torch.tensor(prompt_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
# --- PROTECTED TOKENS ---
protected_tokens = set()
if getattr(tokenizer, 'token_to_id', None):
eos = tokenizer.token_to_id('[EOS]')
if eos is not None: protected_tokens.add(eos)
for char in ['\n', '\n\n', '.', ',', '!', '?', ' ', ' \n', '"', "'", ':', ';']:
protected_tokens.update(tokenizer.encode(char).ids)
bf16_supported = (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
autocast_dtype = torch.bfloat16 if bf16_supported else torch.float16
autocast_kwargs = {'device_type': DEVICE.type, 'dtype': autocast_dtype, 'enabled': (DEVICE.type == 'cuda')}
hard_limit = max_new_tokens + 100
for step in range(hard_limit):
x_cond = x if x.size(1) <= model.config.SEQ_LENGTH else x[:, -model.config.SEQ_LENGTH:]
with torch.autocast(**autocast_kwargs):
logits = model(x_cond)[:, -1, :]
# 1. DYNAMIC Frequency Penalty
if repetition_penalty != 1.0:
penalty_window = gen_ids[-128:]
for token_id in set(penalty_window):
if token_id not in protected_tokens:
count = penalty_window.count(token_id)
dynamic_penalty = repetition_penalty + (count * 0.03)
if logits[0, token_id] < 0:
logits[0, token_id] *= dynamic_penalty
else:
logits[0, token_id] /= dynamic_penalty
# 2. Temperature & Sampling
if temp > 0.0:
logits = logits / temp
if top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
if top_p > 0.0 and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
probs = F.softmax(logits, dim=-1)
# --- NANO-LOOKAHEAD REASONING ---
simulations = torch.multinomial(probs, num_samples=nano_samples, replacement=True)
unique_ids, counts = torch.unique(simulations[0], return_counts=True)
if len(unique_ids) == 1:
next_id = unique_ids[0].item()
else:
sorted_indices = torch.argsort(counts, descending=True)
top_2_ids = unique_ids[sorted_indices[:2]]
best_score = -1.0
best_id = top_2_ids[0].item()
for candidate_id in top_2_ids:
candidate_prob = probs[0, candidate_id].item()
if x.size(1) < model.config.SEQ_LENGTH:
temp_x = torch.cat((x, torch.tensor([[candidate_id]], device=DEVICE)), dim=1)
temp_x_cond = temp_x if temp_x.size(1) <= model.config.SEQ_LENGTH else temp_x[:, -model.config.SEQ_LENGTH:]
with torch.autocast(**autocast_kwargs):
next_logits = model(temp_x_cond)[:, -1, :]
next_max_prob = torch.max(F.softmax(next_logits / temp, dim=-1)).item()
score = candidate_prob * next_max_prob
if score > best_score:
best_score = score
best_id = candidate_id.item()
next_id = best_id
# ---------------------------
else:
next_id = torch.argmax(logits, dim=-1).item()
gen_ids.append(next_id)
x = torch.cat((x, torch.tensor([[next_id]], device=DEVICE)), dim=1)
# --- STOP CONDITIONS ---
# Stop on double newline
tail_text = tokenizer.decode(gen_ids[-8:])
if "\n\n" in tail_text.replace(" ", "").replace("\r", ""):
break
# Stop on User turn hallucination
if "\nUser:" in tail_text or "User:" in tail_text.replace("\n", "")[-6:]:
break
# Stop on EOS token
if next_id == getattr(tokenizer, 'token_to_id', lambda x: -1)('[EOS]'):
break
# Graceful Sentence Cutoff
if step >= max_new_tokens:
new_text = tokenizer.decode([next_id])
if any(punct in new_text for punct in ['.', '!', '?', '\n']):
break
# 🚀 FIXED: Return the full, unbroken string (Prompt + Continuation)
return tokenizer.decode(gen_ids)
# ==============================================================================
# 4. LOAD MODEL & TOKENIZER ON STARTUP
# ==============================================================================
def _extract_state_dict(checkpoint):
    """Return the parameter dict stored in ``checkpoint``, with torch.compile prefixes removed.

    Accepts a bare state dict, a training checkpoint that nests it under
    'model_state_dict' or 'state_dict', or a pickled ``nn.Module``.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    # Modules wrapped by torch.compile save their parameters under an '_orig_mod.' prefix.
    return {key.replace('_orig_mod.', ''): value for key, value in checkpoint.items()}


try:
    print("INFO: Loading tokenizer...")
    tokenizer = Tokenizer.from_file(TOKEN_PATH)
    vocab_size = tokenizer.get_vocab_size()
    MODEL_CONFIG_PARAMS['VOCAB_SIZE'] = vocab_size
    print(f"INFO: Tokenizer loaded. Vocab size: {vocab_size}")

    TuneConfig = make_dataclass("TuneConfig", [(k, type(v)) for k, v in MODEL_CONFIG_PARAMS.items()])
    # Publish TuneConfig as __main__.TuneConfig. Presumably this is so unpickling a checkpoint
    # saved by a training script run as __main__ can resolve the class (NOTE(review): confirm).
    setattr(sys.modules['__main__'], 'TuneConfig', TuneConfig)
    model_config_instance = TuneConfig(**MODEL_CONFIG_PARAMS)
    model = GPTTransformerLM(model_config_instance).to(DEVICE)

    print(f"INFO: Loading checkpoint dictionary from: {MODEL_PATH}")
    # SECURITY: weights_only=False uses the full pickle loader, which can execute arbitrary
    # code. MODEL_PATH must only ever point at a trusted file.
    # The loaded checkpoint is not bound to a module-level name, so everything in it
    # (e.g. any optimizer state) can be freed as soon as loading finishes.
    load_result = model.load_state_dict(
        _extract_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)),
        strict=False,
    )
    # strict=False would otherwise silently serve a randomly initialised model.
    if set(load_result.missing_keys) == set(model.state_dict()):
        raise RuntimeError("checkpoint contains none of the model's parameters")
    if load_result.missing_keys:
        print(f"WARNING: keys missing from checkpoint (left at their initial values): {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"WARNING: checkpoint keys not used by the model: {load_result.unexpected_keys}")
    model.eval()
    print("✅ INFO: Backend Nano-Reasoning Model loaded and ready for API requests!")
except Exception as e:
    print(f"❌ FATAL ERROR during model loading: {e}", file=sys.stderr)
    # Non-zero exit status so process supervisors see the failure (bare exit() reports success).
    sys.exit(1)
# ==============================================================================
# 5. FLASK SERVER APPLICATION
# ==============================================================================
app = Flask(__name__)
# Enable CORS on every route with flask-cors defaults, so browser front-ends can call the API.
CORS(app)


@app.route("/generate", methods=["POST"])
def handle_generation():
    """Generate a continuation for the JSON body ``{"prompt": "<text>"}``.

    Returns:
        200 ``{"response": <prompt + continuation>}`` on success;
        400 ``{"error": ...}`` when the body is not a JSON object with a non-empty string prompt;
        500 ``{"error": ...}`` when generation raises.
    """
    # silent=True: a malformed or non-JSON body yields None instead of an HTML error page.
    data = request.get_json(silent=True)
    prompt = data.get("prompt") if isinstance(data, dict) else None
    if not prompt:
        return jsonify({"error": "Missing 'prompt'"}), 400
    if not isinstance(prompt, str):
        return jsonify({"error": "'prompt' must be a string"}), 400
    try:
        print(f"INFO: Received prompt: '{prompt}'")
        # Generate clean response using hidden Nano-Reasoning math
        generated_text = generate_full_response(model, tokenizer, prompt)
        print(f"INFO: Generated response: '{generated_text}'")
        return jsonify({"response": generated_text})
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"❌ ERROR during generation: {e}")
        return jsonify({"error": "Failed to generate response"}), 500
@app.route("/")
def health_check():
    """Liveness probe: plain-text banner confirming the server is up."""
    banner = (
        "Nano-Reasoning Backend server is running. "
        "Use the /generate endpoint to interact with the AI."
    )
    return banner