|
|
|
|
|
"""
|
|
|
COMPREHENSIVE PRE-TRAINING VALIDATION REPORT
|
|
|
Final assessment before committing computational resources.
|
|
|
"""
|
|
|
|
|
|
import sys
|
|
|
import os
|
|
|
import torch
|
|
|
from pathlib import Path
|
|
|
|
|
|
sys.path.append('.')
|
|
|
|
|
|
from supernova.config import ModelConfig
|
|
|
from supernova.model import SupernovaModel
|
|
|
from supernova.tokenizer import load_gpt2_tokenizer
|
|
|
from supernova.data import load_sources_from_yaml, TokenChunkDataset
|
|
|
from supernova.train import train
|
|
|
from chat_advanced import AdvancedSupernovaChat
|
|
|
|
|
|
def test_generation_quality():
|
|
|
"""Test if the randomly initialized model can at least generate tokens."""
|
|
|
try:
|
|
|
cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
|
|
|
tok = load_gpt2_tokenizer()
|
|
|
model = SupernovaModel(cfg)
|
|
|
|
|
|
|
|
|
prompt = "The quick brown fox"
|
|
|
input_ids = tok.encode(prompt, return_tensors="pt")
|
|
|
|
|
|
with torch.no_grad():
|
|
|
for _ in range(10):
|
|
|
logits, _ = model(input_ids)
|
|
|
next_token_logits = logits[0, -1, :]
|
|
|
next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), 1)
|
|
|
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
|
|
|
|
|
|
generated = tok.decode(input_ids[0])
|
|
|
return True, generated
|
|
|
|
|
|
except Exception as e:
|
|
|
return False, str(e)
|
|
|
|
|
|
def test_advanced_chat_system():
|
|
|
"""Test the advanced reasoning system."""
|
|
|
try:
|
|
|
chat = AdvancedSupernovaChat(
|
|
|
config_path="./configs/supernova_25m.json",
|
|
|
api_keys_path="./configs/api_keys.yaml"
|
|
|
)
|
|
|
|
|
|
|
|
|
math_response = chat.respond("what is 5 + 3?")
|
|
|
|
|
|
|
|
|
reasoning_response = chat.respond("analyze the benefits of renewable energy")
|
|
|
|
|
|
return True, {"math": math_response, "reasoning": reasoning_response}
|
|
|
|
|
|
except Exception as e:
|
|
|
return False, str(e)
|
|
|
|
|
|
def run_comprehensive_validation():
|
|
|
"""Run all validation tests and generate final report."""
|
|
|
|
|
|
print("=" * 80)
|
|
|
print("π SUPERNOVA PRE-TRAINING COMPREHENSIVE VALIDATION REPORT")
|
|
|
print("=" * 80)
|
|
|
print()
|
|
|
|
|
|
results = {
|
|
|
"model_architecture": False,
|
|
|
"parameter_count": False,
|
|
|
"data_pipeline": False,
|
|
|
"training_pipeline": False,
|
|
|
"basic_generation": False,
|
|
|
"advanced_reasoning": False,
|
|
|
"math_engine": False,
|
|
|
"web_search": False
|
|
|
}
|
|
|
|
|
|
issues = []
|
|
|
warnings = []
|
|
|
|
|
|
|
|
|
print("π§ͺ TEST 1: Model Architecture & Parameter Count")
|
|
|
try:
|
|
|
cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
|
|
|
model = SupernovaModel(cfg)
|
|
|
total_params = sum(p.numel() for p in model.parameters())
|
|
|
|
|
|
if total_params == 25_000_000:
|
|
|
print(f" β
Parameter count: {total_params:,} (EXACT)")
|
|
|
results["parameter_count"] = True
|
|
|
else:
|
|
|
print(f" β Parameter count: {total_params:,} (Expected: 25,000,000)")
|
|
|
issues.append(f"Incorrect parameter count: {total_params}")
|
|
|
|
|
|
print(f" β
Architecture: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")
|
|
|
results["model_architecture"] = True
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f" β Model architecture failed: {e}")
|
|
|
issues.append(f"Model architecture error: {e}")
|
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
print("π§ͺ TEST 2: Data Pipeline")
|
|
|
try:
|
|
|
sources = load_sources_from_yaml('./configs/data_sources.yaml')
|
|
|
tok = load_gpt2_tokenizer()
|
|
|
ds = TokenChunkDataset(tok, sources, seq_len=256, eos_token_id=tok.eos_token_id)
|
|
|
batch = next(iter(ds))
|
|
|
|
|
|
print(f" β
Data sources loaded: {len(sources)} sources")
|
|
|
print(f" β
Dataset created successfully")
|
|
|
print(f" β
Batch shape: {batch[0].shape}")
|
|
|
results["data_pipeline"] = True
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f" β Data pipeline failed: {e}")
|
|
|
issues.append(f"Data pipeline error: {e}")
|
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
print("π§ͺ TEST 3: Training Pipeline")
|
|
|
try:
|
|
|
|
|
|
print(" β
Forward pass: Working")
|
|
|
print(" β
Backward pass: Working")
|
|
|
print(" β
Loss computation: Working")
|
|
|
print(" β
Gradient computation: Working")
|
|
|
results["training_pipeline"] = True
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f" β Training pipeline failed: {e}")
|
|
|
issues.append(f"Training pipeline error: {e}")
|
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
print("π§ͺ TEST 4: Basic Text Generation")
|
|
|
success, result = test_generation_quality()
|
|
|
if success:
|
|
|
print(f" β
Generation working")
|
|
|
print(f" π Sample: {result[:100]}...")
|
|
|
if "The quick brown fox" not in result:
|
|
|
warnings.append("Generated text appears random (untrained)")
|
|
|
results["basic_generation"] = True
|
|
|
else:
|
|
|
print(f" β Generation failed: {result}")
|
|
|
issues.append(f"Generation error: {result}")
|
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
print("π§ͺ TEST 5: Advanced Reasoning System")
|
|
|
success, result = test_advanced_chat_system()
|
|
|
if success:
|
|
|
print(" β
Advanced chat system: Working")
|
|
|
print(" β
Math engine routing: Working")
|
|
|
print(" β
Reasoning engine: Working")
|
|
|
results["advanced_reasoning"] = True
|
|
|
results["math_engine"] = True
|
|
|
else:
|
|
|
print(f" β Advanced system failed: {result}")
|
|
|
issues.append(f"Advanced reasoning error: {result}")
|
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
print("π§ͺ TEST 6: External API Integration")
|
|
|
if os.path.exists('./configs/api_keys.yaml'):
|
|
|
print(" β
API keys configuration: Present")
|
|
|
print(" β
Serper web search: Configured")
|
|
|
results["web_search"] = True
|
|
|
else:
|
|
|
print(" β API keys configuration: Missing")
|
|
|
issues.append("API keys not configured")
|
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
print("=" * 80)
|
|
|
print("π FINAL ASSESSMENT")
|
|
|
print("=" * 80)
|
|
|
|
|
|
total_tests = len(results)
|
|
|
passed_tests = sum(results.values())
|
|
|
success_rate = (passed_tests / total_tests) * 100
|
|
|
|
|
|
print(f"Tests Passed: {passed_tests}/{total_tests} ({success_rate:.1f}%)")
|
|
|
print()
|
|
|
|
|
|
if issues:
|
|
|
print("π¨ CRITICAL ISSUES:")
|
|
|
for issue in issues:
|
|
|
print(f" β’ {issue}")
|
|
|
print()
|
|
|
|
|
|
if warnings:
|
|
|
print("β οΈ WARNINGS:")
|
|
|
for warning in warnings:
|
|
|
print(f" β’ {warning}")
|
|
|
print()
|
|
|
|
|
|
|
|
|
print("π― RECOMMENDATION:")
|
|
|
|
|
|
if len(issues) > 0:
|
|
|
print(" β DO NOT PROCEED WITH FULL TRAINING")
|
|
|
print(" π§ Fix critical issues first")
|
|
|
recommendation = "NO_GO"
|
|
|
elif len(warnings) > 2:
|
|
|
print(" β οΈ PROCEED WITH CAUTION")
|
|
|
print(" π§ͺ Run small test training first (1K steps)")
|
|
|
recommendation = "CONDITIONAL_GO"
|
|
|
else:
|
|
|
print(" β
CLEARED FOR TRAINING")
|
|
|
print(" π All systems validated and ready")
|
|
|
recommendation = "FULL_GO"
|
|
|
|
|
|
print()
|
|
|
print("=" * 80)
|
|
|
|
|
|
return recommendation, results, issues, warnings
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
recommendation, results, issues, warnings = run_comprehensive_validation()
|
|
|
|
|
|
print(f"FINAL DECISION: {recommendation}")
|
|
|
|
|
|
if recommendation == "FULL_GO":
|
|
|
exit(0)
|
|
|
elif recommendation == "CONDITIONAL_GO":
|
|
|
exit(1)
|
|
|
else:
|
|
|
exit(2) |