WildnerveAI's picture
Upload 20 files
0861a59 verified
import os
import sys
import time
import psutil
import traceback
import logging
import threading
from typing import Dict, Any, Optional, List
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class DebugHelper:
    """
    Helper class for debugging hanging issues in STDP training.

    The class is a namespace of static helpers that provide:
    1. Process monitoring and status reporting (``get_process_info``)
    2. Resource usage snapshots (``check_resource_leaks``)
    3. Timeout management (``register_timeout``)
    4. Detailed diagnostics written to a file (``dump_debug_info``)
    """

    # Upper-case substrings that mark an environment variable name as likely
    # holding a credential; matching values are masked in debug dumps.
    _SECRET_ENV_MARKERS = (
        "KEY", "TOKEN", "SECRET", "PASSWORD", "PASSWD", "CREDENTIAL",
    )

    @staticmethod
    def _format_all_thread_stacks() -> str:
        """Return the formatted stack of every thread in this interpreter."""
        names = {t.ident: t.name for t in threading.enumerate()}
        chunks = []
        for ident, frame in sys._current_frames().items():
            chunks.append(
                f"--- Thread {names.get(ident, 'unknown')} ({ident}) ---\n"
            )
            chunks.extend(traceback.format_stack(frame))
        return ''.join(chunks)

    @staticmethod
    def _mask_env_value(name: str, value: str) -> str:
        """Return ``value``, or a placeholder if ``name`` looks like a secret."""
        upper_name = name.upper()
        if any(m in upper_name for m in DebugHelper._SECRET_ENV_MARKERS):
            return "<redacted>"
        return value

    @staticmethod
    def get_process_info(pid: Optional[int] = None) -> Dict[str, Any]:
        """Get detailed information about a process.

        Args:
            pid: Process id to inspect. Defaults to the current process.

        Returns:
            A dict of psutil-derived fields plus ``active_threads`` and
            ``current_stack``. Those two always describe the *calling*
            interpreter/thread, even when ``pid`` names another process.
            On any failure the dict is ``{'error': <message>}``.
        """
        # Compare against None explicitly: 0 is a valid pid on some platforms
        # and must not be silently replaced by the current pid.
        if pid is None:
            pid = os.getpid()
        try:
            process = psutil.Process(pid)
            created = process.create_time()
            # psutil 6.0 deprecated Process.connections() in favour of
            # net_connections(); fall back for older psutil versions.
            list_connections = (
                getattr(process, 'net_connections', None)
                or process.connections
            )
            # Get basic process info
            info = {
                'pid': pid,
                'name': process.name(),
                'status': process.status(),
                'cpu_percent': process.cpu_percent(),
                'memory_percent': process.memory_percent(),
                'memory_info': dict(process.memory_info()._asdict()),
                'create_time': created,
                'runtime': time.time() - created,
                'num_threads': process.num_threads(),
                'open_files': len(process.open_files()),
                'connections': len(list_connections()),
            }
            # Get thread details (best effort; never let this abort the report)
            try:
                info['active_threads'] = [
                    t.name for t in threading.enumerate()
                ]
            except Exception:
                info['active_threads'] = (
                    "Could not retrieve thread information"
                )
            # Get current stack trace of the calling thread
            info['current_stack'] = traceback.format_stack()
            return info
        except Exception as e:
            logger.error("Error getting process info: %s", e)
            return {'error': str(e)}

    @staticmethod
    def check_resource_leaks() -> Dict[str, Any]:
        """Check for potential resource leaks.

        Returns:
            A dict with garbage-collector counters and, when torch is
            importable and CUDA is available, CUDA memory statistics. If
            torch is present but querying it fails, the message is stored
            under ``torch_error`` instead of raising.
        """
        import gc

        leaks: Dict[str, Any] = {
            'gc_counts': gc.get_count(),
            'gc_objects': len(gc.get_objects()),
        }
        # Check for torch memory usage if available
        try:
            import torch
            if torch.cuda.is_available():
                leaks['torch_memory_allocated'] = torch.cuda.memory_allocated()
                leaks['torch_memory_reserved'] = torch.cuda.memory_reserved()
                leaks['torch_max_memory_allocated'] = (
                    torch.cuda.max_memory_allocated()
                )
        except ImportError:
            pass
        except Exception as e:
            # A broken torch/CUDA install must not crash a diagnostic helper.
            leaks['torch_error'] = str(e)
        return leaks

    @staticmethod
    def register_timeout(seconds: int, callback=None):
        """Register a timeout that calls the callback after specified seconds.

        Args:
            seconds: Delay before the timeout fires.
            callback: Zero-argument callable invoked on timeout. When omitted,
                a timeout message, process info and the stacks of all threads
                are printed instead.

        Returns:
            The started daemon ``threading.Timer`` (a ``Thread`` subclass).
            Call ``.cancel()`` on it once the guarded operation finishes so
            the timeout does not fire afterwards.
        """
        def _timeout_handler():
            if callback:
                callback()
            else:
                print(f"TIMEOUT: Operation took longer than {seconds} seconds")
                info = DebugHelper.get_process_info()
                print(f"Process info: {info}")
                # This runs in the watchdog thread, so its own stack is
                # useless for finding the hang; show every thread instead.
                print(DebugHelper._format_all_thread_stacks(), file=sys.stderr)

        thread = threading.Timer(seconds, _timeout_handler)
        thread.daemon = True
        thread.start()
        return thread

    @staticmethod
    def dump_debug_info(filename: str, redact_secrets: bool = True):
        """Dump debug information to a file.

        Args:
            filename: Path of the file to (over)write.
            redact_secrets: When true (default), values of environment
                variables whose names look like credentials are replaced by
                ``<redacted>``. Pass ``False`` to write every value verbatim.
        """
        process_info = DebugHelper.get_process_info()
        leak_info = DebugHelper.check_resource_leaks()
        # Explicit UTF-8 with backslashreplace: environment values may hold
        # characters the locale encoding (or strict UTF-8) cannot encode.
        with open(filename, 'w', encoding='utf-8',
                  errors='backslashreplace') as f:
            f.write("===== PROCESS INFORMATION =====\n")
            for key, value in process_info.items():
                f.write(f"{key}: {value}\n")
            f.write("\n===== RESOURCE LEAK INFORMATION =====\n")
            for key, value in leak_info.items():
                f.write(f"{key}: {value}\n")
            f.write("\n===== ENVIRONMENT VARIABLES =====\n")
            for key, value in os.environ.items():
                if redact_secrets:
                    value = DebugHelper._mask_env_value(key, value)
                f.write(f"{key}: {value}\n")
            f.write("\n===== STACK TRACE =====\n")
            f.write(''.join(traceback.format_stack()))
            f.write("\n===== ALL THREAD STACKS =====\n")
            f.write(DebugHelper._format_all_thread_stacks())
        logger.info("Debug info dumped to %s", filename)