# File size: 29,850 bytes (revision 6ac39a9)
"""
Comprehensive test script for Wildnerve TLM that tests both model functionality and weight loading.

Usage:
    # Test model inference with a custom prompt
    python test_model.py --prompt "Your test prompt here"

    # Test weight loading and math capability, and run diagnostics
    python test_model.py --check-weights --check-math --diagnostics

    # Verify repository access and list available weights
    python test_model.py --verify-repos --list-weights

    # Test everything
    python test_model.py --all

    # Test just the weight loading
    python test_model.py --check-weights

    # Test model inference with another custom prompt
    python test_model.py --prompt "What is quantum computing?"
"""
import os
import sys
import time
import logging
import argparse
import importlib.util
from typing import Dict, Any, Optional, List, Tuple
from pathlib import Path
# Configure logging once at import time: INFO level, with the timestamp,
# logger name and severity prepended to every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by the functions in this script.
logger = logging.getLogger(__name__)
def test_model_loading(prompt: str, verbose: bool = False) -> Dict[str, Any]:
    """
    Send one prompt through the model adapter and classify the reply.

    The reply is treated as a "fallback" when it contains any of the known
    canned phrases; otherwise it is attributed to the neural model.

    Args:
        prompt: Text prompt passed to ``WildnerveModelAdapter.generate``.
        verbose: When true, log a preview of the reply, its origin and the
            generation time.

    Returns:
        Dictionary with the keys ``success``, ``model_loaded``, ``response``,
        ``response_type``, ``elapsed_time`` and ``error``. Any exception is
        logged and reported through ``error`` rather than raised.
    """
    outcome = dict(
        success=False,
        model_loaded=False,
        response=None,
        response_type=None,
        elapsed_time=0,
        error=None,
    )
    # Substrings that mark a canned, non-neural reply.
    fallback_markers = (
        "I've received your input about",
        "Processing:",
        "The model couldn't be properly initialized",
        "No language model available",
    )
    try:
        # Imported lazily so a missing adapter is reported as a test failure.
        from adapter_layer import WildnerveModelAdapter
        adapter = WildnerveModelAdapter("")
        logger.info("Model adapter initialized")

        # Only the generation call is timed, not adapter construction.
        began = time.time()
        logger.info(f"Generating response for: {prompt}")
        reply = adapter.generate(prompt)
        duration = time.time() - began

        outcome["elapsed_time"] = duration
        outcome["response"] = reply

        used_fallback = False
        for marker in fallback_markers:
            if marker in reply:
                used_fallback = True
                break

        outcome["response_type"] = "fallback" if used_fallback else "model"
        outcome["model_loaded"] = not used_fallback
        outcome["success"] = True

        if verbose:
            origin = 'fallback' if used_fallback else 'neural model'
            logger.info(f"Response ({len(reply)} chars): {reply[:100]}...")
            logger.info(f"Response appears to be from: {origin}")
            logger.info(f"Generation took: {duration:.2f} seconds")
    except Exception as e:
        logger.error(f"Error testing model: {e}", exc_info=True)
        outcome["error"] = str(e)
    return outcome
def _contains_number(response: Any, expected: str) -> bool:
    """Return True if *expected* occurs in *response* as a standalone number.

    A plain substring test would accept "17" or "70" for an expected "7".
    The pattern rejects matches that are part of a longer number, while
    still accepting a trailing ".0" (e.g. "3.0") and sentence punctuation
    (e.g. "7.").
    """
    import re
    pattern = rf"(?<![\d.]){re.escape(expected)}(?:\.0+)?(?!\.?\d)"
    return re.search(pattern, str(response)) is not None


def test_math_capability() -> Dict[str, Any]:
    """
    Test the model's math capabilities with various arithmetic expressions.

    Each prompt is sent through ``WildnerveModelAdapter.generate`` and passes
    when the expected number appears in the reply as a standalone number.

    Returns:
        Dictionary with ``success`` (the suite ran, regardless of how many
        cases passed), ``tests_passed``, ``tests_total``, ``details`` (one
        entry per case) and ``error`` (set only when the suite could not run).
    """
    results = {
        "success": False,
        "tests_passed": 0,
        "tests_total": 0,
        "details": [],
        "error": None
    }
    # Test cases: (prompt, expected number as text)
    math_tests = [
        ("What is 3 + 4?", "7"),
        ("What is 12 * 5?", "60"),
        ("Calculate 18 / 6", "3"),
        ("What is four multiplied by three?", "12"),
        ("What is seven plus nine?", "16"),
        ("Compute 25 - 13", "12")
    ]
    try:
        from adapter_layer import WildnerveModelAdapter
        adapter = WildnerveModelAdapter("")
        logger.info("Testing math capabilities...")
        results["tests_total"] = len(math_tests)
        for i, (prompt, expected) in enumerate(math_tests):
            logger.info(f"Math test {i+1}/{len(math_tests)}: {prompt}")
            # A failure in one case is recorded and must not stop the others.
            try:
                response = adapter.generate(prompt)
                passes = _contains_number(response, expected)
                results["details"].append({
                    "prompt": prompt,
                    "response": response,
                    "expected": expected,
                    "passed": passes
                })
                if passes:
                    results["tests_passed"] += 1
                    logger.info(f"✓ Test passed: found '{expected}' in response")
                else:
                    logger.info(f"✗ Test failed: '{expected}' not found in response")
                    logger.info(f"Response: {str(response)[:100]}...")
            except Exception as e:
                logger.error(f"Error in math test: {e}")
                results["details"].append({
                    "prompt": prompt,
                    "error": str(e),
                    "passed": False
                })
        results["success"] = True
        return results
    except Exception as e:
        logger.error(f"Failed to run math tests: {e}")
        results["error"] = str(e)
        return results
def test_weight_loading() -> Dict[str, Any]:
    """Check that model weights can be located locally or fetched.

    Uses the helpers from ``load_model_weights``: the token is verified
    first, then local weights are looked up, and only when none are found is
    a download attempted.

    Returns:
        Dictionary with ``success``, ``local_weights_found``,
        ``downloaded_weights``, ``weight_files``, ``errors`` and
        ``elapsed_time``; ``token_verified`` is added once the token check
        has run.
    """
    report = {
        "success": False,
        "local_weights_found": False,
        "downloaded_weights": False,
        "weight_files": {},
        "errors": [],
        "elapsed_time": 0
    }
    try:
        began = time.time()
        try:
            from load_model_weights import load_model_weights, check_for_local_weights, verify_token

            report["token_verified"] = verify_token()

            have_local = check_for_local_weights()
            report["local_weights_found"] = have_local
            if not have_local:
                # Nothing on disk, so fall back to downloading.
                logger.info("No local weights found, downloading from HF Hub...")
                fetched = load_model_weights()
                if not fetched:
                    logger.warning("Failed to download weights")
                    report["errors"].append("Failed to download weights")
                else:
                    report["downloaded_weights"] = True
                    report["weight_files"] = fetched
                    logger.info(f"Downloaded weights: {list(fetched.keys())}")
            else:
                # Local weight paths are read from the environment.
                report["weight_files"] = {
                    "transformer": os.environ.get("TLM_TRANSFORMER_WEIGHTS"),
                    "snn": os.environ.get("TLM_SNN_WEIGHTS")
                }
                logger.info("Found local weights")
        except ImportError as e:
            logger.error(f"Could not import load_model_weights: {e}")
            report["errors"].append(f"ImportError: {str(e)}")

        # Either source of weights counts as success.
        if report["local_weights_found"] or report["downloaded_weights"]:
            report["success"] = True
        report["elapsed_time"] = time.time() - began
    except Exception as e:
        logger.error(f"Error testing weight loading: {e}", exc_info=True)
        report["errors"].append(str(e))
        report["elapsed_time"] = time.time() - began
    return report
def verify_repositories() -> Dict[str, Any]:
    """Verify access to model repositories.

    The repository list comes from ``model_repo_config.get_repo_config()``
    when that module can be imported, otherwise from a built-in fallback
    list. Each repository is checked independently: an exception raised for
    one repository is recorded in its ``details`` entry and the remaining
    repositories are still verified.

    Returns:
        Dictionary with ``repositories_checked``, ``repositories_accessible``,
        ``token_verified`` and ``details`` (per repository: ``accessible``,
        ``num_files``, ``model_files`` and, on failure, ``error``). An
        ``error`` key is added when verification could not run at all.
    """
    results = {
        "repositories_checked": 0,
        "repositories_accessible": 0,
        "details": {}
    }
    try:
        # Try to import verification functions
        from load_model_weights import verify_repository, verify_token
        # Use `or` so an empty HF_TOKEN does not shadow HF_API_TOKEN.
        token = os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
        results["token_verified"] = verify_token()
        # First try to get repositories from model_repo_config
        try:
            from model_repo_config import get_repo_config
            config = get_repo_config()
            repos_to_check = [config.repo_id] + list(config.alternative_paths)
        except ImportError:
            # Fallback repositories
            repos_to_check = [
                "EvolphTech/Weights",
                "Wildnerve/tlm-0.05Bx12",
                "Wildnerve/tlm",
                "EvolphTech/Checkpoints",
                "bert-base-uncased"  # Fallback public model
            ]
        # Check each repository
        for repo in repos_to_check:
            logger.info(f"Verifying repository: {repo}")
            results["repositories_checked"] += 1
            try:
                success, files = verify_repository(repo, token)
            except Exception as e:
                # Keep going: one unreachable repository must not hide the
                # status of the others.
                logger.warning(f"Error verifying repository {repo}: {e}")
                results["details"][repo] = {
                    "accessible": False,
                    "num_files": 0,
                    "model_files": [],
                    "error": str(e)
                }
                continue
            # Guard against a missing file list.
            files = (files or []) if success else []
            if success:
                results["repositories_accessible"] += 1
            results["details"][repo] = {
                "accessible": success,
                "num_files": len(files),
                "model_files": [f for f in files if f.endswith(('.bin', '.pt'))]
            }
        return results
    except Exception as e:
        logger.error(f"Error verifying repositories: {e}", exc_info=True)
        results["error"] = str(e)
        return results
def list_weight_files(local_paths: Optional[List[str]] = None) -> Dict[str, Any]:
    """List available weight files in repositories and locally.

    Args:
        local_paths: Directories to scan recursively for ``.bin``, ``.pt``
            and ``.pth`` files. Defaults to the standard locations
            (``/app/Weights/``, ``./Weights/``, ``/tmp/hf_cache/``,
            ``/tmp/tlm_cache/``).

    Returns:
        Dictionary with ``local_weights`` (entries holding ``path``,
        ``relative_path`` and ``size_mb``), ``repository_weights`` (per
        repository: a file list, or an ``"Error: ..."`` string) and
        ``error``. A file reachable through several search roots is listed
        once.
    """
    results = {
        "local_weights": [],
        "repository_weights": {},
        "error": None
    }
    try:
        # Check local weight files
        if local_paths is None:
            local_paths = [
                "/app/Weights/",
                "./Weights/",
                "/tmp/hf_cache/",
                "/tmp/tlm_cache/"
            ]
        # Search roots can overlap (e.g. "./Weights/" when cwd is "/app"),
        # so remember the resolved path of every file already reported.
        seen = set()
        for base_path in local_paths:
            if not os.path.exists(base_path):
                continue
            for root, _, files in os.walk(base_path):
                for file in files:
                    if not file.endswith(('.bin', '.pt', '.pth')):
                        continue
                    full_path = os.path.join(root, file)
                    resolved = os.path.realpath(full_path)
                    if resolved in seen:
                        continue
                    try:
                        size_bytes = os.path.getsize(full_path)
                    except OSError as e:
                        # Broken symlink, or the file vanished mid-scan:
                        # skip it rather than abort the whole listing.
                        logger.warning(f"Could not read size of {full_path}: {e}")
                        continue
                    seen.add(resolved)
                    results["local_weights"].append({
                        "path": full_path,
                        "relative_path": os.path.relpath(full_path, base_path),
                        "size_mb": size_bytes / (1024 * 1024)
                    })
        # List repository weight files
        try:
            from load_model_weights import list_model_files, verify_token
            # Use `or` so an empty HF_TOKEN does not shadow HF_API_TOKEN.
            token = os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")
            # Called as in the original flow; its result is not needed here.
            verify_token()
            # First try to get repositories from model_repo_config
            try:
                from model_repo_config import get_repo_config
                config = get_repo_config()
                # Only check the first few alternative repositories
                repos_to_check = [config.repo_id] + list(config.alternative_paths[:2])
            except ImportError:
                # Fallback repositories
                repos_to_check = ["EvolphTech/Weights", "Wildnerve/tlm-0.05Bx12"]
            # Check each repository
            for repo in repos_to_check:
                try:
                    logger.info(f"Listing files in repository: {repo}")
                    results["repository_weights"][repo] = list_model_files(repo, token)
                except Exception as e:
                    logger.warning(f"Error listing files in {repo}: {e}")
                    results["repository_weights"][repo] = f"Error: {str(e)}"
        except ImportError as e:
            results["error"] = f"Could not import functions to list repository files: {e}"
        return results
    except Exception as e:
        logger.error(f"Error listing weight files: {e}", exc_info=True)
        results["error"] = str(e)
        return results
def _create_test_model() -> Tuple[Optional[Any], List[str]]:
    """Instantiate ``Wildnerve_tlm01`` from model_Custm, falling back to model_PrTr.

    Returns:
        ``(model, failures)``, where ``model`` is None if no candidate could
        be built and ``failures`` describes every failed attempt.
    """
    candidates = [
        ("model_Custm", {
            "vocab_size": 50257,  # GPT-2 vocab size
            "specialization": "general",
            "embedding_dim": 768,
            "num_heads": 12,
            "hidden_dim": 768,
            "num_layers": 2,
            "output_size": 50257,
            "dropout": 0.1,
            "max_seq_length": 128,
        }),
        ("model_PrTr", {"model_name": "gpt2"}),
    ]
    failures: List[str] = []
    for module_name, init_kwargs in candidates:
        try:
            module = importlib.import_module(module_name)
            model_class = getattr(module, "Wildnerve_tlm01", None)
            if model_class is None:
                # A module without the class counts as a failed attempt, so
                # the next candidate is still tried.
                raise AttributeError(f"{module_name} does not define Wildnerve_tlm01")
            logger.info(f"Creating Wildnerve_tlm01 from {module_name}")
            return model_class(**init_kwargs), failures
        except Exception as e:
            logger.warning(f"Error creating {module_name}: {e}")
            failures.append(f"{module_name}: {e}")
    return None, failures


def test_weight_loading_in_model() -> Dict[str, Any]:
    """Test loading weights into an actual model instance.

    Steps: locate or download weights via ``test_weight_loading``, build a
    ``Wildnerve_tlm01`` model, load the transformer weights into it with
    ``load_weights_into_model(..., strict=False)`` and, when the model
    exposes ``generate``, run one short inference as a smoke test.

    Returns:
        Dictionary with ``success`` (true once the weights are loaded, even
        if the smoke-test inference fails), ``model_created``,
        ``weights_loaded``, ``weight_path``, ``error`` and, when inference
        ran, ``test_inference``.
    """
    results = {
        "success": False,
        "model_created": False,
        "weights_loaded": False,
        "weight_path": None,
        "error": None
    }
    try:
        # Try to find or download weights
        weight_loading_results = test_weight_loading()
        if not (weight_loading_results["local_weights_found"] or weight_loading_results["downloaded_weights"]):
            results["error"] = "No weights available to test"
            return results

        # Get weight path
        weight_files = weight_loading_results.get("weight_files") or {}
        weight_path = weight_files.get("transformer") if isinstance(weight_files, dict) else None
        if not weight_path or not os.path.exists(weight_path):
            results["error"] = f"Weight file not found at {weight_path}"
            return results
        results["weight_path"] = weight_path

        # Try to create a model
        model, failures = _create_test_model()
        if model is None:
            results["error"] = "Could not create any model: " + ", ".join(failures)
            logger.error(results["error"])
            return results
        results["model_created"] = True

        # Load weights into model
        try:
            from load_model_weights import load_weights_into_model
        except ImportError as e:
            results["error"] = f"ImportError: {str(e)}"
            return results
        success = load_weights_into_model(model, weight_path, strict=False)
        results["weights_loaded"] = success
        if not success:
            results["error"] = "Failed to load weights into model"
            return results

        # Loading succeeded; the inference below is only a smoke test and
        # never downgrades the result.
        results["success"] = True
        if hasattr(model, "generate"):
            try:
                output = model.generate(prompt="This is a test.", max_length=20)
                logger.info(f"Test inference output: {output}")
                results["test_inference"] = output
            except Exception as inf_err:
                logger.warning(f"Test inference failed: {inf_err}")
        return results
    except Exception as e:
        logger.error(f"Error testing weight loading in model: {e}", exc_info=True)
        results["error"] = str(e)
        return results
def run_diagnostics() -> Dict[str, Any]:
    """
    Run diagnostics on the model environment and dependencies.

    Collects the Python version, selected environment variables, the import
    status of critical modules, GPU availability, the presence of required
    source files next to this script, and a summary of the repository
    configuration.

    The value of ``HF_TOKEN`` is never copied into the result, because the
    result may be logged and written to a JSON file; only whether it is set
    is reported.

    Returns:
        Dictionary with ``python_version``, ``environment``, ``modules``,
        ``gpu_available``, ``files_present`` and ``model_repo_config``, plus
        ``gpu_info``, ``gpu_error`` or ``model_repo_config_error`` when
        applicable.
    """
    diagnostics: Dict[str, Any] = {
        "python_version": sys.version,
        "environment": {},
        "modules": {},
        "gpu_available": False,
        "files_present": {},
        "model_repo_config": None
    }
    # Check environment variables; secrets are reported as set/unset only.
    secret_vars = {"HF_TOKEN"}
    for var in ["MODEL_REPO", "HF_TOKEN", "TLM_TRANSFORMER_WEIGHTS", "TLM_SNN_WEIGHTS",
                "LOW_MEMORY_MODE", "CUDA_VISIBLE_DEVICES"]:
        value = os.environ.get(var)
        if value is None:
            diagnostics["environment"][var] = "Not set"
        elif var in secret_vars:
            diagnostics["environment"][var] = "Set (hidden)"
        else:
            diagnostics["environment"][var] = value
    # Check critical modules
    for module_name in ["torch", "transformers", "adapter_layer", "model_Custm", "model_PrTr",
                        "load_model_weights", "model_repo_config"]:
        try:
            module_spec = importlib.util.find_spec(module_name)
        except (ImportError, ValueError):
            # find_spec raises ValueError for a loaded module whose __spec__
            # is None.
            diagnostics["modules"][module_name] = "Not available"
            continue
        if module_spec is None:
            diagnostics["modules"][module_name] = "Not found"
            continue
        try:
            module = importlib.import_module(module_name)
            diagnostics["modules"][module_name] = getattr(module, "__version__", "Available (no version)")
        except Exception as e:
            diagnostics["modules"][module_name] = f"Import error: {e}"
    # Check for GPU
    try:
        import torch
        diagnostics["gpu_available"] = torch.cuda.is_available()
        if diagnostics["gpu_available"]:
            diagnostics["gpu_info"] = torch.cuda.get_device_name(0)
    except ImportError:
        # Already reported under "modules".
        pass
    except Exception as e:
        diagnostics["gpu_error"] = str(e)
    # Check critical files (expected next to this script)
    required_files = [
        "adapter_layer.py",
        "model_Custm.py",
        "model_PrTr.py",
        "model_repo_config.py",
        "load_model_weights.py",
        "service_registry.py"
    ]
    script_dir = os.path.dirname(os.path.abspath(__file__))
    for filename in required_files:
        diagnostics["files_present"][filename] = os.path.exists(os.path.join(script_dir, filename))
    # Check model repo config
    try:
        from model_repo_config import get_repo_config
        repo_config = get_repo_config()
        diagnostics["model_repo_config"] = {
            "repo_id": repo_config.repo_id,
            "weight_locations": repo_config.weight_locations[:3] + ["..."],  # First few for brevity
            "has_auth_token": repo_config.has_auth_token(),
            "cache_dir": repo_config.cache_dir
        }
    except Exception as e:
        diagnostics["model_repo_config_error"] = str(e)
    return diagnostics
def main():
    """Parse command-line arguments, run the selected tests and report results.

    The model-loading test with ``--prompt`` always runs; every other test is
    opt-in through its flag (or ``--all``). A failure in one optional test is
    logged and recorded under a ``*_error`` key and does not stop the
    remaining tests. With ``--output`` the collected results are written as
    JSON; values that are not JSON-serialisable are stored as strings.
    """
    parser = argparse.ArgumentParser(description="Comprehensive Wildnerve TLM Model Test Suite")
    parser.add_argument("--prompt", type=str, default="Tell me about Malaysia's culture",
                        help="Prompt text to test (default is non-math to force model loading)")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("--check-math", action="store_true", help="Run math capability tests")
    parser.add_argument("--check-weights", action="store_true", help="Test model weight loading")
    parser.add_argument("--verify-repos", action="store_true", help="Verify repository access")
    parser.add_argument("--list-weights", action="store_true", help="List available weight files")
    parser.add_argument("--test-load", action="store_true", help="Test loading weights into model")
    parser.add_argument("--diagnostics", action="store_true", help="Run system diagnostics")
    parser.add_argument("--all", action="store_true", help="Run all tests")
    parser.add_argument("--output", type=str, help="Save results to JSON file")
    args = parser.parse_args()

    # If --all specified, enable all tests
    if args.all:
        args.check_math = True
        args.check_weights = True
        args.verify_repos = True
        args.list_weights = True
        args.test_load = True
        args.diagnostics = True

    # Track overall execution time
    start_time = time.time()
    results = {}

    # Run diagnostics if requested
    if args.diagnostics:
        logger.info("Running system diagnostics...")
        try:
            diagnostics = run_diagnostics()
            results["diagnostics"] = diagnostics
            if args.verbose:
                logger.info("Diagnostic results:")
                for category, data in diagnostics.items():
                    logger.info(f" {category}:")
                    if isinstance(data, dict):
                        for key, value in data.items():
                            logger.info(f" - {key}: {value}")
                    else:
                        logger.info(f" - {data}")
        except Exception as e:
            logger.error(f"Error in diagnostics: {e}")
            results["diagnostics_error"] = str(e)

    # Verify repository access if requested
    if args.verify_repos:
        logger.info("Verifying model repository access...")
        try:
            repo_results = verify_repositories()
            results["repository_verification"] = repo_results
            # Log summary
            logger.info(f"Repositories checked: {repo_results['repositories_checked']}")
            logger.info(f"Repositories accessible: {repo_results['repositories_accessible']}")
            if args.verbose:
                for repo, details in repo_results["details"].items():
                    status = "✓" if details["accessible"] else "✗"
                    logger.info(f" {status} {repo}: {details['num_files']} files")
        except Exception as e:
            logger.error(f"Error verifying repositories: {e}")
            results["repository_verification_error"] = str(e)

    # List weight files if requested
    if args.list_weights:
        logger.info("Listing available weight files...")
        try:
            weight_files = list_weight_files()
            results["weight_files"] = weight_files
            # Log summary
            logger.info(f"Local weight files found: {len(weight_files['local_weights'])}")
            logger.info(f"Repositories with weights: {len(weight_files['repository_weights'])}")
            if args.verbose:
                # Show local weights
                if weight_files["local_weights"]:
                    logger.info("Local weight files:")
                    for weight in weight_files["local_weights"]:
                        logger.info(f" - {weight['relative_path']} ({weight['size_mb']:.1f} MB)")
                # Show repository weights (an entry is a string when listing failed)
                for repo, files in weight_files["repository_weights"].items():
                    if isinstance(files, list):
                        logger.info(f"Weights in {repo}: {len(files)} files")
                        for file in files[:5]:  # Show first 5
                            logger.info(f" - {file}")
                        if len(files) > 5:
                            logger.info(f" - ... ({len(files)-5} more)")
        except Exception as e:
            logger.error(f"Error listing weight files: {e}")
            results["weight_files_error"] = str(e)

    # Test weight loading if requested
    if args.check_weights:
        logger.info("Testing model weight loading...")
        try:
            weight_loading = test_weight_loading()
            results["weight_loading"] = weight_loading
            # Log summary
            if weight_loading["local_weights_found"]:
                logger.info("✓ Local weights found")
                for key, path in weight_loading["weight_files"].items():
                    if path:
                        logger.info(f" - {key}: {path}")
            elif weight_loading["downloaded_weights"]:
                logger.info("✓ Weights downloaded successfully")
                for key, path in weight_loading["weight_files"].items():
                    if path:
                        logger.info(f" - {key}: {path}")
            else:
                logger.warning("✗ No weights found or downloaded")
                if weight_loading["errors"]:
                    for error in weight_loading["errors"]:
                        logger.warning(f" - Error: {error}")
        except Exception as e:
            logger.error(f"Error testing weight loading: {e}")
            results["weight_loading_error"] = str(e)

    # Test loading weights into model if requested
    if args.test_load:
        logger.info("Testing loading weights into model...")
        try:
            weight_in_model = test_weight_loading_in_model()
            results["weight_in_model"] = weight_in_model
            # Log summary
            if weight_in_model["success"]:
                logger.info("✓ Successfully loaded weights into model")
                logger.info(f" - Weight path: {weight_in_model['weight_path']}")
                if "test_inference" in weight_in_model:
                    # The inference output is not guaranteed to be a string,
                    # so convert before slicing.
                    preview = str(weight_in_model["test_inference"])[:50]
                    logger.info(f" - Test inference: {preview}...")
            else:
                logger.warning("✗ Failed to load weights into model")
                if weight_in_model["error"]:
                    logger.warning(f" - Error: {weight_in_model['error']}")
        except Exception as e:
            logger.error(f"Error testing weights in model: {e}")
            results["weight_in_model_error"] = str(e)

    # Test model loading with the provided prompt (always runs)
    logger.info(f"Testing model loading with prompt: {args.prompt}")
    loading_results = test_model_loading(args.prompt, args.verbose)
    results["model_loading"] = loading_results
    # Summary of model loading test
    if loading_results["success"]:
        if loading_results["model_loaded"]:
            logger.info("✅ SUCCESS: Model loaded and generated response")
            logger.info(f" - Response: {loading_results['response'][:50]}...")
            logger.info(f" - Time: {loading_results['elapsed_time']:.2f} seconds")
        else:
            logger.warning("⚠️ PARTIAL: Model adapter works but uses fallback (not neural network)")
            logger.warning(f" - Fallback response: {loading_results['response'][:50]}...")
    else:
        logger.error("❌ FAILED: Could not load the model")
        if loading_results["error"]:
            logger.error(f" - Error: {loading_results['error']}")

    # Run math tests if requested
    if args.check_math:
        logger.info("Running math capability tests...")
        math_results = test_math_capability()
        results["math_tests"] = math_results
        # Summary of math tests
        if math_results["success"]:
            logger.info(f"Math tests: {math_results['tests_passed']}/{math_results['tests_total']} passed")
            if args.verbose:
                for i, test in enumerate(math_results["details"]):
                    status = "✓" if test.get("passed") else "✗"
                    logger.info(f" {status} Test {i+1}: {test['prompt']}")
                    if not test.get("passed"):
                        logger.info(f" Expected: {test.get('expected')}")
                        logger.info(f" Got: {str(test.get('response', ''))[:50]}...")
        else:
            logger.error("Failed to run math tests")
            if math_results.get("error"):
                logger.error(f" - Error: {math_results['error']}")

    # Log total execution time
    elapsed = time.time() - start_time
    logger.info(f"All tests completed in {elapsed:.2f} seconds")

    # Save results if requested
    if args.output:
        try:
            import json
            with open(args.output, 'w', encoding='utf-8') as f:
                # default=str keeps the dump from failing on values such as
                # model outputs that json cannot encode natively.
                json.dump(results, f, indent=2, default=str)
            logger.info(f"Results saved to {args.output}")
        except Exception as e:
            logger.error(f"Failed to save results: {e}")
# Run the test suite only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|