# VivekanandaAI/utils.py
"""
Vivekananda AI - Core Utilities
Pure logic, config-driven implementation with no hardcoding
Handles configuration, logging, device management, and prompt building
"""
import yaml
import logging
import logging.config
import colorlog
from pathlib import Path
from typing import Dict, Any, Optional, List
# Make PyTorch optional so GGUF runner can work without Torch
try:
import torch # type: ignore
TORCH_AVAILABLE = True
except Exception:
TORCH_AVAILABLE = False
torch = None # sentinel for guards
import json
import os
from datetime import datetime
import gc
import requests
from huggingface_hub import HfApi, Repository, hf_hub_download
import shutil
# ============================================================================
# CONFIGURATION MANAGER
# ============================================================================
class Config:
    """Centralized configuration manager with no hardcoding.

    Loads a YAML file (default: ``config.yaml`` next to this module) and
    exposes nested values through dot-notation key paths, e.g.
    ``get('paths.data.root')``.
    """
    def __init__(self, config_path: Optional[Path] = None):
        # Default to a config.yaml sitting beside this module.
        self.config_path = config_path or Path(__file__).parent / "config.yaml"
        self.config_data: Dict[str, Any] = {}
        self.load_config()

    def load_config(self):
        """Load configuration from the YAML file.

        Raises:
            FileNotFoundError: if the config file does not exist.
            RuntimeError: if the file exists but cannot be read or parsed.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document; normalize to
                # {} so get()/set() keep working on an empty config file.
                self.config_data = yaml.safe_load(f) or {}
            print(f"✅ Configuration loaded from: {self.config_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load config: {e}") from e

    def get(self, key_path: str, default: Any = None) -> Any:
        """Get a nested configuration value using dot notation.

        Returns ``default`` when any segment of the path is missing or an
        intermediate value is not a mapping.
        """
        value: Any = self.config_data
        for key in key_path.split('.'):
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return default
        return value

    def get_path(self, *path_keys: str) -> Optional[Path]:
        """Get a resolved Path object from configuration, or None if unset."""
        path_str = self.get('.'.join(path_keys))
        if path_str:
            return Path(path_str).resolve()
        return None

    def set(self, key_path: str, value: Any):
        """Set a configuration value using dot notation.

        Intermediate dictionaries are created as needed; a non-dict value
        sitting in the middle of the path is replaced with a dict (the
        previous implementation raised TypeError in that case).
        """
        keys = key_path.split('.')
        target = self.config_data
        for key in keys[:-1]:
            if not isinstance(target.get(key), dict):
                target[key] = {}
            target = target[key]
        target[keys[-1]] = value

    def update(self, updates: Dict[str, Any]):
        """Apply multiple dot-notation updates, e.g. ``{'a.b': 1}``."""
        for key_path, value in updates.items():
            self.set(key_path, value)

    def save(self, path: Optional[Path] = None):
        """Save the current configuration to ``path`` (default: the file the
        config was loaded from).

        Raises:
            RuntimeError: if the file cannot be written.
        """
        save_path = path or self.config_path
        try:
            with open(save_path, 'w', encoding='utf-8') as f:
                yaml.dump(self.config_data, f, default_flow_style=False, allow_unicode=True)
            print(f"💾 Configuration saved to: {save_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to save config: {e}") from e
# ============================================================================
# DEVICE MANAGER (MPS OPTIMIZED FOR APPLE SILICON)
# ============================================================================
class DeviceManager:
    """Handle device detection and optimization for Apple Silicon MPS.

    Resolves the compute device from config ('hardware.device', default
    'auto'), falling back when the requested backend is unavailable, and
    records a torch dtype (or a plain dtype string when Torch is absent).
    """
    def __init__(self, config: Config):
        self.config = config
        # Resolved device name: 'mps', 'cuda', 'cpu', or whatever the config forces.
        self.device = None
        # torch.dtype when Torch is installed; otherwise the configured dtype string.
        self.torch_dtype = None
        self.detect_device()

    def detect_device(self):
        """Auto-detect best available device with MPS priority."""
        # Get preferred device from config
        preferred_device = self.config.get('hardware.device', 'auto').lower()
        if preferred_device == 'auto':
            # Auto-detection order: MPS (Apple Silicon) > CUDA > CPU.
            # Each check is guarded by TORCH_AVAILABLE so the GGUF path
            # works without Torch installed.
            if TORCH_AVAILABLE and torch.backends.mps.is_available():
                self.device = 'mps'
                print("🍎 MPS (Apple Silicon) detected and enabled")
            elif TORCH_AVAILABLE and torch.cuda.is_available():
                self.device = 'cuda'
                print("⚡ CUDA GPU detected and enabled")
            else:
                self.device = 'cpu'
                print("💻 CPU mode (Torch unavailable or no GPU acceleration)")
        else:
            # Use configured device, but verify the backend actually exists;
            # otherwise fall back to 'hardware.fallback_device' (default 'cpu').
            if preferred_device == 'mps' and (not TORCH_AVAILABLE or not torch.backends.mps.is_available()):
                fallback = self.config.get('hardware.fallback_device', 'cpu')
                print(f"⚠️ MPS not available, falling back to {fallback}")
                self.device = fallback
            elif preferred_device == 'cuda' and (not TORCH_AVAILABLE or not torch.cuda.is_available()):
                fallback = self.config.get('hardware.fallback_device', 'cpu')
                print(f"⚠️ CUDA not available, falling back to {fallback}")
                self.device = fallback
            else:
                self.device = preferred_device
                print(f"✅ Using configured device: {self.device}")
        # Set torch dtype from 'hardware.torch_dtype' (default 'float32').
        dtype_str = self.config.get('hardware.torch_dtype', 'float32')
        if TORCH_AVAILABLE:
            dtype_map = {
                'float32': torch.float32,
                'float16': torch.float16,
                'bfloat16': torch.bfloat16
            }
            # Unknown dtype strings silently fall back to float32.
            self.torch_dtype = dtype_map.get(dtype_str, torch.float32)
        else:
            # Store dtype as a simple string when Torch is not present
            self.torch_dtype = dtype_str
        # Set PyTorch settings
        if TORCH_AVAILABLE and self.device == 'mps':
            # MPS-specific optimizations: cap this process at 80% of the MPS
            # memory budget. NOTE(review): torch.mps.set_per_process_memory_fraction
            # only exists in recent torch builds — confirm the minimum version.
            torch.mps.set_per_process_memory_fraction(0.8)
            print("🎯 MPS memory optimization enabled")
        # Log device info
        self.log_device_info()

    def log_device_info(self):
        """Print a banner describing the resolved device and backend details."""
        print(f"\n{'='*50}")
        print(f"DEVICE INFORMATION")
        print(f"{'='*50}")
        print(f"Device: {self.device}")
        if TORCH_AVAILABLE:
            print(f"PyTorch dtype: {self.torch_dtype}")
            if self.device == 'mps':
                print(f"MPS available: {torch.backends.mps.is_available()}")
                print(f"MPS built: {torch.backends.mps.is_built()}")
            elif self.device == 'cuda':
                print(f"CUDA available: {torch.cuda.is_available()}")
                print(f"CUDA version: {torch.version.cuda}")
                print(f"GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("PyTorch not available; using CPU mode with generic dtype")
        print(f"{'='*50}\n")

    def get_torch_device(self):
        """Return a torch.device for the resolved device, or None without Torch."""
        if not TORCH_AVAILABLE:
            return None
        if self.device == 'mps':
            return torch.device('mps')
        elif self.device == 'cuda':
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    def optimize_memory(self):
        """Free cached backend memory (if Torch is present) and run the GC."""
        if TORCH_AVAILABLE:
            if self.device == 'mps':
                torch.mps.empty_cache()
            elif self.device == 'cuda':
                torch.cuda.empty_cache()
        gc.collect()
        print("🧹 Memory optimized")
# ============================================================================
# FILE HANDLER
# ============================================================================
class FileHandler:
    """Handle all file operations with proper error handling.

    Failures are logged through the injected logger and converted into
    benign return values (None / "" / []) instead of being raised.
    """

    def __init__(self, config: Config, logger: logging.Logger):
        self.config = config
        self.logger = logger
        self.ensure_directories()

    def ensure_directories(self):
        """Create every directory the pipeline expects to exist."""
        required = (
            'paths.data.root',
            'paths.data.raw',
            'paths.data.processed',
            'paths.data.extracted',
            'paths.vectorstore.root',
            'paths.models.root',
            'paths.models.base',
            'paths.models.fine_tuned',
            'paths.outputs.root',
            'paths.outputs.logs',
            'paths.outputs.results',
        )
        for dotted_key in required:
            target = self.config.get_path(*dotted_key.split('.'))
            if target:
                target.mkdir(parents=True, exist_ok=True)

    def load_json(self, file_path: Path) -> Any:
        """Parse a JSON file; returns None (and logs) on any failure."""
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                return json.load(handle)
        except Exception as exc:
            self.logger.error(f"Failed to load JSON {file_path}: {exc}")
            return None

    def save_json(self, data: Any, file_path: Path, indent: int = 2):
        """Serialize data as UTF-8 JSON; failures are logged, not raised."""
        try:
            with open(file_path, 'w', encoding='utf-8') as handle:
                json.dump(data, handle, indent=indent, ensure_ascii=False)
            self.logger.info(f"💾 Saved JSON to: {file_path}")
        except Exception as exc:
            self.logger.error(f"Failed to save JSON {file_path}: {exc}")

    def get_files_by_extension(self, directory: Path, extensions: List[str]) -> List[Path]:
        """Return files in `directory` matching any given extension, sorted."""
        if not directory.exists():
            self.logger.warning(f"Directory not found: {directory}")
            return []
        matches: List[Path] = []
        for suffix in extensions:
            matches.extend(directory.glob(f"*{suffix}"))
        return sorted(matches)

    def read_text_file(self, file_path: Path, encoding: str = 'utf-8') -> str:
        """Read an entire text file; returns "" (and logs) on failure."""
        try:
            with open(file_path, 'r', encoding=encoding) as handle:
                return handle.read()
        except Exception as exc:
            self.logger.error(f"Failed to read {file_path}: {exc}")
            return ""

    def write_text_file(self, content: str, file_path: Path, encoding: str = 'utf-8'):
        """Write text, creating parent directories; failures are logged."""
        try:
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, 'w', encoding=encoding) as handle:
                handle.write(content)
            self.logger.info(f"📝 Saved text to: {file_path}")
        except Exception as exc:
            self.logger.error(f"Failed to write {file_path}: {exc}")
# ============================================================================
# PROMPT BUILDER
# ============================================================================
class PromptBuilder:
    """Build prompts for Swami Vivekananda AI with no hardcoding.

    Templates come from the config ('prompts.*' keys); built-in fallbacks
    are used when a template is missing.
    """

    def __init__(self, config: Config):
        self.config = config
        self.system_prompt = self.load_system_prompt()
        self.rag_template = self.load_rag_template()
        self.direct_template = self.load_direct_template()

    def load_system_prompt(self) -> str:
        """Return the system prompt from config, or a built-in fallback."""
        configured = self.config.get('prompts.system', "")
        if configured:
            return configured.strip()
        # Fallback persona prompt used when the config defines none.
        fallback = """You are an AI embodying the wisdom and teachings of Swami Vivekananda.
Speak with clarity, strength, compassion, and spiritual insight.
Draw from Vedanta philosophy and emphasize universal truths."""
        return fallback.strip()

    def load_rag_template(self) -> dict:
        """Return the RAG template dict ('header'/'footer') from config."""
        configured = self.config.get('prompts.rag_template', {})
        if configured:
            return configured
        # Fallback RAG template.
        return {
            "header": "Context from Swami Vivekananda's works:\n{context}\n\nQuestion: {question}",
            "footer": "Based on the provided context and Swami Vivekananda's teachings, \nplease provide a thoughtful response that reflects his wisdom and philosophy."
        }

    def load_direct_template(self) -> dict:
        """Return the no-context template dict ('template') from config."""
        configured = self.config.get('prompts.direct_template', {})
        if configured:
            return configured
        # Fallback direct template.
        return {
            "template": "Question: {question}\n\nPlease provide a response in the spirit of Swami Vivekananda's teachings."
        }

    def build_rag_prompt(self, question: str, context: str) -> str:
        """Render the RAG template with the given context and question."""
        rendered_header = self.rag_template.get('header', '').format(
            context=context, question=question)
        rendered_footer = self.rag_template.get('footer', '')
        return "\n\n".join((rendered_header, rendered_footer))

    def build_direct_prompt(self, question: str) -> str:
        """Render the no-context template with the given question."""
        return self.direct_template.get('template', '').format(question=question)

    def get_full_prompt(self, question: str, context: Optional[str] = None) -> str:
        """Compose the system prompt with a RAG (if context given) or direct prompt."""
        if context:
            body = self.build_rag_prompt(question, context)
        else:
            body = self.build_direct_prompt(question)
        return f"System: {self.system_prompt}\n\n{body}"

    def get_system_prompt(self) -> str:
        """Return the resolved system prompt."""
        return self.system_prompt
# ============================================================================
# LOGGER SETUP
# ============================================================================
class LoggerSetup:
    """Setup colored logging with config-driven settings."""

    def __init__(self, config: Config):
        self.config = config
        self.setup_logging()

    def setup_logging(self):
        """Configure the root logger with colored output.

        Reads 'logging.level' and 'logging.format' from config. An
        unrecognized level name now falls back to INFO instead of raising
        AttributeError (the previous getattr(logging, name) crashed on
        e.g. 'VERBOSE').
        """
        log_level = self.config.get('logging.level', 'INFO').upper()
        log_format = self.config.get('logging.format',
            '%(log_color)s%(asctime)s | %(levelname)-8s | %(name)s | %(message)s')
        # Resolve the level name defensively; logging exposes level constants
        # as module attributes (DEBUG/INFO/...), anything else -> INFO.
        level = getattr(logging, log_level, None)
        if not isinstance(level, int):
            level = logging.INFO
        # Color configuration
        log_colors = {
            'DEBUG': 'cyan',
            'INFO': 'green',
            'WARNING': 'yellow',
            'ERROR': 'red',
            'CRITICAL': 'red,bg_white',
        }
        # Configure colorlog. NOTE: basicConfig is a no-op when the root
        # logger already has handlers, so this should run once, early.
        colorlog.basicConfig(
            level=level,
            format=log_format,
            log_colors=log_colors,
            datefmt='%H:%M:%S'
        )
        # Reduce verbosity of chatty third-party libraries.
        logging.getLogger('urllib3').setLevel(logging.WARNING)
        logging.getLogger('requests').setLevel(logging.WARNING)
        logging.getLogger('transformers').setLevel(logging.WARNING)
        logging.getLogger('langchain').setLevel(logging.INFO)
# ============================================================================
# HUGGINGFACE DATASET EXTRACTOR
# ============================================================================
class HuggingFaceDatasetExtractor:
    """Extract datasets from Hugging Face including PDF files."""

    def __init__(self, config: Config, logger: Optional[logging.Logger] = None):
        self.config = config
        self.logger = logger or logging.getLogger(__name__)
        self.api = HfApi()

    def authenticate(self, token: str = None):
        """Authenticate with Hugging Face.

        Stores the token in the HF_TOKEN environment variable, which
        huggingface_hub picks up automatically; only warns when neither a
        token argument nor an existing HF_TOKEN is present.
        """
        if token:
            os.environ['HF_TOKEN'] = token
        elif 'HF_TOKEN' not in os.environ:
            self.logger.warning("No Hugging Face token provided. Some datasets may be restricted.")

    def download_dataset_files(self, dataset_id: str, target_dir: Path, file_extensions: List[str] = None) -> List[Path]:
        """Download specific file types from a Hugging Face dataset.

        Args:
            dataset_id: repo id, e.g. 'user/dataset'.
            target_dir: local directory; created if missing.
            file_extensions: lowercase suffixes to keep (default .pdf/.txt/.md).

        Returns:
            Paths of the files downloaded; per-file failures are logged and
            skipped, repo-level failures return [].
        """
        if file_extensions is None:
            file_extensions = ['.pdf', '.txt', '.md']
        target_dir.mkdir(parents=True, exist_ok=True)
        downloaded_files = []
        try:
            self.logger.info(f"Downloading files from dataset: {dataset_id}")
            # Get dataset info (lists every file in the repo as a sibling).
            dataset_info = self.api.dataset_info(dataset_id)
            for sibling in dataset_info.siblings:
                file_path = sibling.rfilename
                file_ext = Path(file_path).suffix.lower()
                if file_ext not in file_extensions:
                    continue
                try:
                    # Flatten nested repo paths into target_dir.
                    # NOTE(review): files sharing a basename in different
                    # repo folders would overwrite each other here.
                    local_path = target_dir / Path(file_path).name
                    downloaded_path = hf_hub_download(
                        repo_id=dataset_id,
                        filename=file_path,
                        repo_type="dataset",
                        local_dir=target_dir,
                        # Deprecated upstream; kept for older hub versions.
                        local_dir_use_symlinks=False
                    )
                    if Path(downloaded_path) != local_path:
                        shutil.move(downloaded_path, local_path)
                    downloaded_files.append(local_path)
                    self.logger.info(f"Downloaded: {file_path}")
                except Exception as e:
                    self.logger.error(f"Failed to download {file_path}: {e}")
            self.logger.info(f"Successfully downloaded {len(downloaded_files)} files to {target_dir}")
            return downloaded_files
        except Exception as e:
            self.logger.error(f"Error downloading dataset {dataset_id}: {e}")
            return []

    def search_vivekananda_datasets(self, query: str = "vivekananda") -> List[Dict[str, Any]]:
        """Search for Vivekananda-related datasets on Hugging Face."""
        try:
            # Use 'search' (free-text match on names/descriptions); the old
            # code passed the query as 'filter', which expects tags and so
            # found nothing for plain words like "vivekananda".
            datasets = self.api.list_datasets(
                search=query,
                sort="downloads",
                direction=-1,
                limit=20
            )
            results = []
            for dataset in datasets:
                # card_data is a DatasetCardData object, not a dict; the old
                # `.get(...)` call raised AttributeError and emptied results.
                card = getattr(dataset, 'card_data', None)
                if card is not None and not isinstance(card, dict):
                    card = getattr(card, 'to_dict', dict)()
                results.append({
                    'id': dataset.id,
                    'downloads': getattr(dataset, 'downloads', 0),
                    'tags': getattr(dataset, 'tags', []),
                    'description': (card or {}).get('description', 'No description')
                })
            return results
        except Exception as e:
            self.logger.error(f"Error searching datasets: {e}")
            return []

    def extract_pdfs_from_datasets(self, dataset_ids: List[str], target_dir: Path) -> Dict[str, List[Path]]:
        """Extract PDFs from multiple datasets; returns {dataset_id: [paths]}."""
        results = {}
        for dataset_id in dataset_ids:
            self.logger.info(f"Processing dataset: {dataset_id}")
            # '/' in repo ids is not filesystem-safe; use '_' for the folder.
            dataset_dir = target_dir / dataset_id.replace('/', '_')
            files = self.download_dataset_files(dataset_id, dataset_dir, ['.pdf'])
            results[dataset_id] = files
        return results
# ============================================================================
# MAIN UTILITIES CLASS
# ============================================================================
class Utils:
    """Main utilities class that coordinates all components."""

    def __init__(self, config_path: Optional[Path] = None):
        # Wire subsystems in dependency order: config first, then logging,
        # then every component that consumes both.
        self.config = Config(config_path)
        self.logger_setup = LoggerSetup(self.config)
        self.logger = logging.getLogger(self.__class__.__name__)
        self.device_manager = DeviceManager(self.config)
        self.file_handler = FileHandler(self.config, self.logger)
        self.prompt_builder = PromptBuilder(self.config)
        self.hf_extractor = HuggingFaceDatasetExtractor(self.config, self.logger)
        # Announce readiness with project metadata from the config.
        self.logger.info("🚀 Vivekananda AI Utils initialized")
        self.logger.info(f"Project: {self.config.get('project.name', 'Vivekananda AI')}")
        self.logger.info(f"Version: {self.config.get('project.version', '1.0.0')}")

    def get_config(self) -> Config:
        """Return the configuration manager."""
        return self.config

    def get_logger(self, name: Optional[str] = None) -> logging.Logger:
        """Return a named logger, or this instance's logger when no name given."""
        return logging.getLogger(name) if name else self.logger

    def get_device_manager(self) -> DeviceManager:
        """Return the device manager."""
        return self.device_manager

    def get_file_handler(self) -> FileHandler:
        """Return the file handler."""
        return self.file_handler

    def get_prompt_builder(self) -> PromptBuilder:
        """Return the prompt builder."""
        return self.prompt_builder

    def optimize_memory(self):
        """Delegate memory cleanup to the device manager."""
        self.device_manager.optimize_memory()

    def get_system_info(self) -> Dict[str, Any]:
        """Return a summary dict of device, dtype, config path and project metadata."""
        device_mgr = self.device_manager
        cfg = self.config
        return {
            'device': device_mgr.device,
            'torch_dtype': str(device_mgr.torch_dtype),
            'config_path': str(cfg.config_path),
            'project_name': cfg.get('project.name'),
            'project_version': cfg.get('project.version')
        }
# ============================================================================
# GLOBAL INSTANCE
# ============================================================================
# Global utils instance for easy access
# Lazily-created process-wide Utils singleton.
_global_utils = None

def get_utils(config_path: Optional[Path] = None) -> Utils:
    """Return the global Utils instance, creating it on first call.

    The config_path argument is only honored by the very first call; later
    calls return the already-built singleton.
    """
    global _global_utils
    if _global_utils is not None:
        return _global_utils
    _global_utils = Utils(config_path)
    return _global_utils
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def format_timestamp() -> str:
    """Return the current local time as a compact YYYYMMDD_HHMMSS stamp."""
    now = datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")
def safe_filename(filename: str) -> str:
    """Return a filesystem-safe variant of *filename*.

    Windows-reserved characters become underscores, leading/trailing dots
    and spaces are stripped, and an empty result becomes "unnamed".
    """
    # Map every reserved character to '_' in a single translate pass.
    reserved_to_underscore = str.maketrans({ch: '_' for ch in '<>:"/\\|?*'})
    cleaned = filename.translate(reserved_to_underscore).strip('. ')
    return cleaned or "unnamed"
def estimate_tokens(text: str) -> int:
    """Roughly estimate token count via the ~4-characters-per-token heuristic."""
    chars_per_token = 4
    return len(text) // chars_per_token
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
    """Split text into overlapping chunks.

    Args:
        text: source text.
        chunk_size: maximum characters per chunk (must be positive).
        overlap: characters shared between consecutive chunks; clamped to
            the range [0, chunk_size - 1] so the window always advances
            (the previous implementation looped forever whenever
            overlap >= chunk_size).

    Returns:
        List of chunks covering the whole text; text no longer than
        chunk_size is returned as a single chunk.

    Raises:
        ValueError: if chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if len(text) <= chunk_size:
        return [text]
    # Clamp so the start position strictly advances each iteration.
    overlap = max(0, min(overlap, chunk_size - 1))
    step = chunk_size - overlap
    chunks: List[str] = []
    start = 0
    text_len = len(text)
    while start < text_len:
        chunks.append(text[start:start + chunk_size])
        start += step
        # Stop once the previous chunk already reached the end of the text;
        # otherwise an overlap-sized tail would be emitted a second time.
        if start >= text_len - overlap:
            break
    return chunks
# ============================================================================
# ERROR HANDLING
# ============================================================================
class VivekanandaAIError(Exception):
    """Base exception for Vivekananda AI; catch this to handle any
    error raised by this package's own code."""
    pass

class ConfigurationError(VivekanandaAIError):
    """Configuration-related errors (loading, validation, missing keys)."""
    pass

class ModelError(VivekanandaAIError):
    """Model-related errors (loading, inference)."""
    pass

class DeviceError(VivekanandaAIError):
    """Device-related errors (detection, initialization)."""
    pass
# ============================================================================
# MAIN EXECUTION (for testing)
# ============================================================================
if __name__ == "__main__":
    # Manual smoke test: exercises each utility component end-to-end.
    # Requires a valid config.yaml beside this module and writes a small
    # file under the configured logs directory.
    print("🧪 Testing Vivekananda AI Utilities...")
    # Initialize utils (builds config, logging, device manager, etc.)
    utils = get_utils()
    # Test configuration
    print(f"\n📋 Configuration test:")
    print(f"Project: {utils.config.get('project.name')}")
    print(f"Device: {utils.device_manager.device}")
    # Test prompt builder
    print(f"\n💬 Prompt test:")
    test_prompt = utils.prompt_builder.get_full_prompt("What is the meaning of life?")
    print(f"Sample prompt length: {len(test_prompt)} characters")
    # Test file handler. NOTE(review): this raises TypeError if
    # 'paths.outputs.logs' is missing (get_path returns None) — confirm the
    # key is always present in config.yaml.
    print(f"\n📁 File handler test:")
    test_file = utils.config.get_path('paths.outputs.logs') / "test.log"
    utils.file_handler.write_text_file("Test log entry", test_file)
    # System info
    print(f"\n🔍 System info:")
    info = utils.get_system_info()
    for key, value in info.items():
        print(f" {key}: {value}")
    print("\n✅ Utilities test completed!")