"""
Main KerdosAgent class that orchestrates the training and deployment process.
"""

from typing import Optional, Union, Dict, Any, List
from pathlib import Path
import torch
import logging
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import warnings

from .trainer import Trainer
from .deployer import Deployer
from .data_processor import DataProcessor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class KerdosAgent:
    """
    Main agent class for training and deploying LLMs with custom data.

    The agent loads a causal language model and its tokenizer, and wires
    them into a ``Trainer``, a ``Deployer`` and (when training data is
    given) a ``DataProcessor``.
    """

    def __init__(
        self,
        base_model: str,
        training_data: Optional[Union[str, Path]],
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the KerdosAgent.

        Args:
            base_model: Name or path of the base LLM model
            training_data: Path to the training data, or None when the agent
                is only used for generation/deployment
            device: Device to run the model on (cuda/cpu); defaults to cuda
                when available, otherwise cpu
            **kwargs: Additional configuration parameters. The keys read by
                this class are ``use_4bit``, ``use_8bit`` and
                ``trust_remote_code``.

        Raises:
            ValueError: If the configuration is invalid
            Exception: Any error raised while loading the model, tokenizer
                or components is logged and re-raised unchanged
        """
        self.base_model = base_model
        self.training_data = Path(training_data) if training_data else None
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = kwargs

        logger.info(f"Initializing KerdosAgent with base model: {base_model}")
        logger.info(f"Using device: {self.device}")

        # Validate configuration
        self._validate_config()

        # Initialize components
        try:
            quantization_config = self._get_quantization_config()
            trust_remote_code = kwargs.get('trust_remote_code', False)

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto",
                quantization_config=quantization_config,
                trust_remote_code=trust_remote_code
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model,
                trust_remote_code=trust_remote_code
            )

            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                # Use the tokenizer's id so the model pads with exactly the
                # token the tokenizer emits (model.config.eos_token_id may be
                # unset or differ from the tokenizer's).
                self.model.config.pad_token_id = self.tokenizer.pad_token_id

            # Initialize other components
            if self.training_data:
                self.data_processor = DataProcessor(self.training_data)
            else:
                self.data_processor = None

            self.trainer = Trainer(self.model, self.tokenizer, self.device)
            self.deployer = Deployer(self.model, self.tokenizer)

            logger.info("KerdosAgent initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing KerdosAgent: {str(e)}")
            raise

    def _input_device(self) -> Any:
        """
        Return the device that tokenized inputs should be moved to.

        The model is loaded with ``device_map="auto"``, so it may not live on
        ``self.device``. Prefer the model's own reported device and fall back
        to ``self.device`` when it is unavailable or is the ``meta`` device.
        """
        model_device = getattr(self.model, "device", None)
        if model_device is None or getattr(model_device, "type", None) == "meta":
            return self.device
        return model_device

    def train(
        self,
        epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 2e-5,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Train the model on the provided data.

        Args:
            epochs: Number of training epochs
            batch_size: Training batch size
            learning_rate: Learning rate for training
            **kwargs: Additional training parameters forwarded to the trainer

        Returns:
            Dictionary containing training metrics

        Raises:
            ValueError: If the agent was created without training data
        """
        if self.data_processor is None:
            raise ValueError(
                "No training data available: create the agent with "
                "training_data to enable training"
            )

        # Process training data
        train_dataset = self.data_processor.prepare_dataset()

        # Train the model
        training_args = {
            "epochs": epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            **kwargs
        }
        metrics = self.trainer.train(train_dataset, **training_args)
        return metrics

    def deploy(
        self,
        deployment_type: str = "rest",
        host: str = "0.0.0.0",
        port: int = 8000,
        **kwargs
    ) -> None:
        """
        Deploy the trained model.

        Args:
            deployment_type: Type of deployment (rest/docker/kubernetes)
            host: Host address for REST API
            port: Port number for REST API
            **kwargs: Additional deployment parameters forwarded to the deployer
        """
        deployment_args = {
            "deployment_type": deployment_type,
            "host": host,
            "port": port,
            **kwargs
        }
        self.deployer.deploy(**deployment_args)

    def save(self, output_dir: Union[str, Path]) -> None:
        """
        Save the trained model and tokenizer.

        Args:
            output_dir: Directory to save the model (created if missing)
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        num_return_sequences: int = 1,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            max_length: Maximum length of generated text
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            num_return_sequences: Number of sequences to generate
            **kwargs: Additional generation parameters. ``do_sample`` may be
                passed explicitly; by default sampling is enabled whenever
                ``temperature`` is positive.

        Returns:
            Generated text or list of generated texts

        Raises:
            Exception: Any error raised during generation is logged and
                re-raised unchanged
        """
        try:
            logger.info(f"Generating text from prompt: {prompt[:50]}...")

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self._input_device())

            # temperature/top_p/top_k only take effect in sampling mode, and
            # greedy decoding cannot return several sequences; enable sampling
            # unless the caller decided otherwise.
            kwargs.setdefault(
                "do_sample",
                temperature is not None and temperature > 0
            )

            # Set up generation config
            generation_config = GenerationConfig(
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=num_return_sequences,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **kwargs
            )

            # Generate
            self.model.eval()
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    generation_config=generation_config
                )

            # Decode outputs
            # NOTE(review): the full output sequences are decoded, so the
            # returned text presumably includes the prompt - confirm.
            generated_texts = [
                self.tokenizer.decode(output, skip_special_tokens=True)
                for output in outputs
            ]

            logger.info(f"Generated {len(generated_texts)} sequence(s)")
            return generated_texts[0] if num_return_sequences == 1 else generated_texts
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            raise

    def inference(
        self,
        texts: List[str],
        batch_size: int = 8,
        **kwargs
    ) -> List[str]:
        """
        Run batch inference on multiple texts.

        Args:
            texts: List of input texts
            batch_size: Batch size for inference (must be at least 1)
            **kwargs: Additional generation parameters forwarded to
                ``model.generate``

        Returns:
            List of generated texts

        Raises:
            ValueError: If ``batch_size`` is smaller than 1
            Exception: Any error raised during inference is logged and
                re-raised unchanged
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        # Decoder-only models must be left-padded for batched generation,
        # otherwise generation continues after pad tokens. The tokenizer is
        # shared with the trainer/deployer, so restore its setting afterwards.
        original_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            logger.info(f"Running inference on {len(texts)} texts")
            results = []
            input_device = self._input_device()

            self.model.eval()
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]

                # Tokenize batch
                inputs = self.tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True
                ).to(input_device)

                # Generate
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        pad_token_id=self.tokenizer.pad_token_id,
                        **kwargs
                    )

                # Decode
                results.extend(
                    self.tokenizer.decode(output, skip_special_tokens=True)
                    for output in outputs
                )

            logger.info(f"Inference completed for {len(results)} texts")
            return results
        except Exception as e:
            logger.error(f"Error during inference: {str(e)}")
            raise
        finally:
            self.tokenizer.padding_side = original_padding_side

    def prepare_for_training(
        self,
        use_lora: bool = True,
        lora_r: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        target_modules: Optional[List[str]] = None,
        use_4bit: bool = False,
        use_8bit: bool = False
    ) -> None:
        """
        Prepare the model for efficient training using LoRA and/or quantization.

        Args:
            use_lora: Whether to use LoRA (Low-Rank Adaptation)
            lora_r: LoRA rank
            lora_alpha: LoRA alpha parameter
            lora_dropout: LoRA dropout rate
            target_modules: List of module names to apply LoRA to; defaults
                to the attention projections q_proj/v_proj/k_proj/o_proj
            use_4bit: Whether to use 4-bit quantization
            use_8bit: Whether to use 8-bit quantization

        Raises:
            Exception: Any error raised while preparing the model is logged
                and re-raised unchanged
        """
        try:
            logger.info("Preparing model for training")

            # The model is quantized at construction time from self.config,
            # so honour those flags too; otherwise a quantized model would
            # skip k-bit preparation unless the caller repeated the flag here.
            quantized = bool(
                use_4bit
                or use_8bit
                or self.config.get('use_4bit')
                or self.config.get('use_8bit')
            )

            # Prepare model for k-bit training if quantization is used
            if quantized:
                logger.info("Preparing model for k-bit training")
                self.model = prepare_model_for_kbit_training(self.model)

            # Apply LoRA if requested
            if use_lora:
                logger.info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}")

                if target_modules is None:
                    # Default target modules for common architectures
                    target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]

                lora_config = LoraConfig(
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    target_modules=target_modules,
                    lora_dropout=lora_dropout,
                    bias="none",
                    task_type="CAUSAL_LM"
                )

                self.model = get_peft_model(self.model, lora_config)
                self.model.print_trainable_parameters()

            logger.info("Model prepared for training successfully")
        except Exception as e:
            logger.error(f"Error preparing model for training: {str(e)}")
            raise

    def _validate_config(self) -> None:
        """
        Validate the agent configuration.

        Raises:
            ValueError: If configuration is invalid
        """
        if not self.base_model:
            raise ValueError("base_model must be specified")

        if self.config.get('use_4bit') and self.config.get('use_8bit'):
            raise ValueError("Cannot use both 4-bit and 8-bit quantization")

        logger.debug("Configuration validated successfully")

    def _get_quantization_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Get quantization configuration if requested.

        Returns:
            BitsAndBytesConfig for 4-bit (checked first) or 8-bit loading,
            or None when neither ``use_4bit`` nor ``use_8bit`` is set
        """
        if self.config.get('use_4bit'):
            logger.info("Using 4-bit quantization")
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True
            )
        elif self.config.get('use_8bit'):
            logger.info("Using 8-bit quantization")
            return BitsAndBytesConfig(
                load_in_8bit=True
            )
        return None

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the current model.

        Returns:
            Dictionary containing model information: base model, device,
            total/trainable parameter counts, trainable percentage, the dtype
            of the first parameter (None if the model has no parameters) and
            the agent configuration
        """
        total_params = 0
        trainable_params = 0
        first_dtype = None

        # Single pass over the parameters instead of three separate ones.
        for param in self.model.parameters():
            if first_dtype is None:
                first_dtype = param.dtype
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count

        return {
            "base_model": self.base_model,
            "device": self.device,
            "total_parameters": total_params,
            "trainable_parameters": trainable_params,
            "trainable_percentage": (trainable_params / total_params) * 100 if total_params > 0 else 0,
            "model_dtype": str(first_dtype) if first_dtype is not None else None,
            "config": self.config
        }

    @classmethod
    def load(cls, model_dir: Union[str, Path], **kwargs) -> "KerdosAgent":
        """
        Load a trained model from disk.

        Args:
            model_dir: Directory containing the saved model
            **kwargs: Additional initialization parameters

        Returns:
            Loaded KerdosAgent instance (created without training data)

        Raises:
            FileNotFoundError: If ``model_dir`` does not exist
            Exception: Any error raised while loading is logged and
                re-raised unchanged
        """
        try:
            model_dir = Path(model_dir)
            if not model_dir.exists():
                raise FileNotFoundError(f"Model directory {model_dir} does not exist")

            logger.info(f"Loading model from {model_dir}")

            # Create agent with loaded model
            agent = cls(
                base_model=str(model_dir),
                training_data=None,  # Not needed for loading
                **kwargs
            )

            logger.info("Model loaded successfully")
            return agent
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise