HNTAI / scripts /verify_cache.py
sachinchandrankallar's picture
changes for publishing the latest including generate_generic api
4156c57
#!/usr/bin/env python3
"""
Verify that models are properly cached and accessible.
Run this after deployment to ensure everything is working.
"""
import os
import sys
import logging
from pathlib import Path
# One shared "timestamp - LEVEL - message" layout for everything this script logs.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every check below.
logger = logging.getLogger(__name__)
def check_directory(path, name):
    """Check that a directory exists and contains at least one file.

    Args:
        path: Directory to inspect (string or path-like).
        name: Human-readable label used in the log messages.

    Returns:
        True when ``path`` is a directory with at least one regular file
        anywhere beneath it; False when it is missing, is not a directory,
        or holds no files.
    """
    root = Path(path)
    # A regular file is not a usable cache directory, so test is_dir()
    # instead of mere existence.
    if not root.is_dir():
        logger.error(f"❌ {name} directory not found: {path}")
        return False

    # Walk the tree once, accumulating the file count and byte total together.
    file_count = 0
    total_size = 0
    for entry in root.rglob('*'):
        if entry.is_file():
            file_count += 1
            total_size += entry.stat().st_size

    if file_count == 0:
        logger.warning(f"⚠️ {name} directory is empty: {path}")
        return False

    size_gb = total_size / (1024**3)
    logger.info(f"✅ {name}: {file_count} files, {size_gb:.2f} GB")
    return True
def verify_transformers_cache():
    """Verify that Transformers model weights exist in the Hugging Face cache.

    Recursively searches ``HF_HOME`` (default ``/app/.cache/huggingface``)
    for ``*.bin`` and ``*.safetensors`` files and logs the repositories the
    files belong to.

    Returns:
        True when at least one weight file is found, otherwise False.
    """
    hf_home = Path(os.environ.get('HF_HOME', '/app/.cache/huggingface'))
    logger.info("\n🔍 Checking Transformers cache...")

    model_files = [
        *hf_home.rglob('*.bin'),
        *hf_home.rglob('*.safetensors'),
    ]
    if not model_files:
        logger.error("❌ No model files found in HF cache")
        return False
    logger.info(f"✅ Found {len(model_files)} model weight files")

    # Cache folders are named "models--<org>--<name>": "--" stands in for the
    # "/" of the repo id, whereas "_" is a legitimate repo-name character and
    # must be left alone. Every weight file is inspected so that models whose
    # files sort late are still listed.
    prefix = 'models--'
    model_names = set()
    for weight_file in model_files:
        repo_dir = next(
            (part for part in weight_file.parts if part.startswith(prefix)),
            None,
        )
        if repo_dir is not None:
            model_names.add(repo_dir[len(prefix):].replace('--', '/'))

    if model_names:
        logger.info("📦 Cached models:")
        for model in sorted(model_names):
            logger.info(f" - {model}")
    return True
def verify_gguf_cache():
    """Verify that GGUF model files are cached.

    Recursively searches both ``MODEL_CACHE_DIR`` (default ``/app/models``)
    and ``HF_HOME`` (default ``/app/.cache/huggingface``) for ``*.gguf``
    files and logs each file with its size.

    Returns:
        True when at least one GGUF file is found, otherwise False.
    """
    model_cache = Path(os.environ.get('MODEL_CACHE_DIR', '/app/models'))
    hf_home = Path(os.environ.get('HF_HOME', '/app/.cache/huggingface'))
    logger.info("\n🔍 Checking GGUF cache...")

    # The two roots may coincide or be nested inside each other;
    # dict.fromkeys drops repeated paths while keeping discovery order so a
    # file is never counted twice.
    gguf_files = list(dict.fromkeys(
        [*model_cache.rglob('*.gguf'), *hf_home.rglob('*.gguf')]
    ))
    if not gguf_files:
        logger.warning("⚠️ No GGUF files found")
        return False

    logger.info(f"✅ Found {len(gguf_files)} GGUF files:")
    for gguf_file in gguf_files:
        size_mb = gguf_file.stat().st_size / (1024**2)
        logger.info(f" - {gguf_file.name} ({size_mb:.1f} MB)")
    return True
def verify_whisper_cache():
    """Verify that Whisper model checkpoints are cached.

    Looks for ``*.pt`` files anywhere under ``WHISPER_CACHE`` (default
    ``/app/.cache/whisper``) and logs the name of each one found.

    Returns:
        True when at least one ``.pt`` file is found; False when the cache
        directory is missing, is not a directory, or holds no ``.pt`` files.
    """
    whisper_cache = os.environ.get('WHISPER_CACHE', '/app/.cache/whisper')
    logger.info("\n🔍 Checking Whisper cache...")

    cache_dir = Path(whisper_cache)
    # is_dir() also rejects a stray regular file sitting at the cache path.
    if not cache_dir.is_dir():
        logger.warning(f"⚠️ Whisper cache directory not found: {whisper_cache}")
        return False

    whisper_files = list(cache_dir.rglob('*.pt'))
    if not whisper_files:
        logger.warning("⚠️ No Whisper model files found")
        return False

    logger.info(f"✅ Found {len(whisper_files)} Whisper models:")
    for checkpoint in whisper_files:
        logger.info(f" - {checkpoint.name}")
    return True
def verify_python_imports(packages=None):
    """Verify that critical Python packages can be imported.

    Args:
        packages: Optional iterable of ``(module_name, display_name)`` pairs
            to try. When omitted, the default set (PyTorch, Transformers,
            Whisper, spaCy, NLTK, FastAPI) is checked.

    Returns:
        True when every listed package imports successfully, otherwise False.
        Every package is attempted even after an earlier one has failed.
    """
    import importlib

    if packages is None:
        packages = [
            ('torch', 'PyTorch'),
            ('transformers', 'Transformers'),
            ('whisper', 'Whisper'),
            ('spacy', 'spaCy'),
            ('nltk', 'NLTK'),
            ('fastapi', 'FastAPI'),
        ]

    logger.info("\n🔍 Checking Python imports...")
    all_ok = True
    for package, name in packages:
        try:
            importlib.import_module(package)
        except Exception as e:
            # Packages with native dependencies can fail with OSError or
            # RuntimeError rather than ImportError; catching broadly keeps
            # one broken package from hiding the status of the rest.
            logger.error(f"❌ {name} import failed: {e}")
            all_ok = False
        else:
            logger.info(f"✅ {name} import OK")
    return all_ok
def check_gpu():
    """Check GPU availability through PyTorch.

    Logs the name and total memory of every CUDA device PyTorch can see, or
    a warning that the CPU will be used when CUDA is unavailable.

    Returns:
        True when the probe completes (with or without a GPU); False when
        importing or querying torch raises.
    """
    logger.info("\n🔍 Checking GPU...")
    try:
        import torch

        if torch.cuda.is_available():
            # Report every visible device, not just device 0, so multi-GPU
            # hosts are described fully.
            for index in range(torch.cuda.device_count()):
                gpu_name = torch.cuda.get_device_name(index)
                total_bytes = torch.cuda.get_device_properties(index).total_memory
                gpu_memory = total_bytes / (1024**3)
                logger.info(f"✅ GPU available: {gpu_name}")
                logger.info(f" GPU Memory: {gpu_memory:.1f} GB")
        else:
            logger.warning("⚠️ No GPU available, will use CPU")
        return True
    except Exception as e:
        # Deliberately broad: a missing torch install or a broken CUDA driver
        # should fail this check rather than crash the script.
        logger.error(f"❌ Error checking GPU: {e}")
        return False
def main():
    """Run every cache and environment check and log a summary.

    Logs the relevant environment variables, executes each check in turn (a
    check that raises is recorded as failed), then logs a PASS/FAIL line per
    check and the overall tally.

    Returns:
        0 when every check passed, otherwise 1; the value is used as the
        process exit status.
    """
    logger.info("="*80)
    logger.info("MODEL CACHE VERIFICATION")
    logger.info("="*80)

    # Check environment variables
    logger.info("\n📋 Environment variables:")
    env_vars = ['HF_HOME', 'MODEL_CACHE_DIR', 'TORCH_HOME', 'WHISPER_CACHE', 'SPACE_ID']
    for var in env_vars:
        value = os.environ.get(var, 'NOT SET')
        logger.info(f" {var}: {value}")

    # Run checks. The directory checks are wrapped in lambdas so the
    # environment is read when the check runs, inside the try block below.
    checks = [
        ("HF Cache", lambda: check_directory(
            os.environ.get('HF_HOME', '/app/.cache/huggingface'),
            "Hugging Face Cache"
        )),
        ("Model Cache", lambda: check_directory(
            os.environ.get('MODEL_CACHE_DIR', '/app/models'),
            "Model Cache"
        )),
        ("Transformers Models", verify_transformers_cache),
        ("GGUF Models", verify_gguf_cache),
        ("Whisper Models", verify_whisper_cache),
        ("Python Imports", verify_python_imports),
        ("GPU", check_gpu),
    ]

    results = {}
    for name, check_func in checks:
        try:
            results[name] = check_func()
        except Exception as e:
            # One crashing check must not prevent the others from running.
            logger.error(f"❌ {name} check failed: {e}")
            results[name] = False

    # Summary
    logger.info("\n" + "="*80)
    logger.info("SUMMARY")
    logger.info("="*80)
    passed = sum(1 for ok in results.values() if ok)
    total = len(results)
    for name, result in results.items():
        status = "✅ PASS" if result else "❌ FAIL"
        logger.info(f"{status}: {name}")
    logger.info(f"\nTotal: {passed}/{total} checks passed")

    if passed == total:
        logger.info("\n🎉 All checks passed! Models are properly cached and ready.")
        return 0

    logger.warning(f"\n⚠️ {total - passed} checks failed. Review the errors above.")
    return 1
if __name__ == "__main__":
    # Hand main()'s 0/1 result to the shell as the process exit status.
    raise SystemExit(main())