# kerdosai/agent.py — added locally (commit 3df89a1, bhaskarvilles)
"""
Main KerdosAgent class that orchestrates the training and deployment process.
"""
from typing import Optional, Union, Dict, Any, List
from pathlib import Path
import torch
import logging
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import warnings
from .trainer import Trainer
from .deployer import Deployer
from .data_processor import DataProcessor
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class KerdosAgent:
"""
Main agent class for training and deploying LLMs with custom data.
"""
    def __init__(
        self,
        base_model: str,
        training_data: Union[str, Path],
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the KerdosAgent.

        Loads the base model (optionally quantized via ``use_4bit``/``use_8bit``
        in ``**kwargs``) and its tokenizer, then wires up the data processor,
        trainer, and deployer components.

        Args:
            base_model: Name or path of the base LLM model
            training_data: Path to the training data; may be None/empty when
                the agent is used for inference or loading only
            device: Device to run the model on (cuda/cpu); auto-detected
                from CUDA availability when omitted
            **kwargs: Additional configuration parameters (e.g. ``use_4bit``,
                ``use_8bit``, ``trust_remote_code``) kept on ``self.config``

        Raises:
            ValueError: If the configuration is invalid (see _validate_config).
        """
        self.base_model = base_model
        # training_data may be falsy (e.g. None from load()); keep None then.
        self.training_data = Path(training_data) if training_data else None
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = kwargs
        logger.info(f"Initializing KerdosAgent with base model: {base_model}")
        logger.info(f"Using device: {self.device}")
        # Validate configuration before any expensive model download/load.
        self._validate_config()
        # Initialize components
        try:
            # Quantization config must be built before from_pretrained so the
            # weights are loaded directly in the requested precision.
            quantization_config = self._get_quantization_config()
            # fp16 on GPU, fp32 on CPU; device_map="auto" lets the loader
            # place layers across the available devices.
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto",
                quantization_config=quantization_config,
                trust_remote_code=kwargs.get('trust_remote_code', False)
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model,
                trust_remote_code=kwargs.get('trust_remote_code', False)
            )
            # Set pad token if not present (many causal LMs ship without one);
            # mirror it on the model config so generate() can pad correctly.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.model.config.eos_token_id
            # Initialize other components
            if self.training_data:
                self.data_processor = DataProcessor(self.training_data)
            else:
                # No training data: inference/deployment-only usage.
                self.data_processor = None
            self.trainer = Trainer(self.model, self.tokenizer, self.device)
            self.deployer = Deployer(self.model, self.tokenizer)
            logger.info("KerdosAgent initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing KerdosAgent: {str(e)}")
            raise
def train(
self,
epochs: int = 3,
batch_size: int = 4,
learning_rate: float = 2e-5,
**kwargs
) -> Dict[str, Any]:
"""
Train the model on the provided data.
Args:
epochs: Number of training epochs
batch_size: Training batch size
learning_rate: Learning rate for training
**kwargs: Additional training parameters
Returns:
Dictionary containing training metrics
"""
# Process training data
train_dataset = self.data_processor.prepare_dataset()
# Train the model
training_args = {
"epochs": epochs,
"batch_size": batch_size,
"learning_rate": learning_rate,
**kwargs
}
metrics = self.trainer.train(train_dataset, **training_args)
return metrics
def deploy(
self,
deployment_type: str = "rest",
host: str = "0.0.0.0",
port: int = 8000,
**kwargs
) -> None:
"""
Deploy the trained model.
Args:
deployment_type: Type of deployment (rest/docker/kubernetes)
host: Host address for REST API
port: Port number for REST API
**kwargs: Additional deployment parameters
"""
deployment_args = {
"deployment_type": deployment_type,
"host": host,
"port": port,
**kwargs
}
self.deployer.deploy(**deployment_args)
def save(self, output_dir: Union[str, Path]) -> None:
"""
Save the trained model and tokenizer.
Args:
output_dir: Directory to save the model
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
self.model.save_pretrained(output_dir)
self.tokenizer.save_pretrained(output_dir)
def generate(
self,
prompt: str,
max_length: int = 100,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
num_return_sequences: int = 1,
**kwargs
) -> Union[str, List[str]]:
"""
Generate text from a prompt.
Args:
prompt: Input text prompt
max_length: Maximum length of generated text
temperature: Sampling temperature
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
num_return_sequences: Number of sequences to generate
**kwargs: Additional generation parameters
Returns:
Generated text or list of generated texts
"""
try:
logger.info(f"Generating text from prompt: {prompt[:50]}...")
# Tokenize input
inputs = self.tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True
).to(self.device)
# Set up generation config
generation_config = GenerationConfig(
max_length=max_length,
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_return_sequences=num_return_sequences,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
**kwargs
)
# Generate
self.model.eval()
with torch.no_grad():
outputs = self.model.generate(
**inputs,
generation_config=generation_config
)
# Decode outputs
generated_texts = [
self.tokenizer.decode(output, skip_special_tokens=True)
for output in outputs
]
logger.info(f"Generated {len(generated_texts)} sequence(s)")
return generated_texts[0] if num_return_sequences == 1 else generated_texts
except Exception as e:
logger.error(f"Error generating text: {str(e)}")
raise
def inference(
self,
texts: List[str],
batch_size: int = 8,
**kwargs
) -> List[str]:
"""
Run batch inference on multiple texts.
Args:
texts: List of input texts
batch_size: Batch size for inference
**kwargs: Additional generation parameters
Returns:
List of generated texts
"""
try:
logger.info(f"Running inference on {len(texts)} texts")
results = []
self.model.eval()
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
# Tokenize batch
inputs = self.tokenizer(
batch,
return_tensors="pt",
padding=True,
truncation=True
).to(self.device)
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
pad_token_id=self.tokenizer.pad_token_id,
**kwargs
)
# Decode
batch_results = [
self.tokenizer.decode(output, skip_special_tokens=True)
for output in outputs
]
results.extend(batch_results)
logger.info(f"Inference completed for {len(results)} texts")
return results
except Exception as e:
logger.error(f"Error during inference: {str(e)}")
raise
def prepare_for_training(
self,
use_lora: bool = True,
lora_r: int = 8,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
target_modules: Optional[List[str]] = None,
use_4bit: bool = False,
use_8bit: bool = False
) -> None:
"""
Prepare the model for efficient training using LoRA and/or quantization.
Args:
use_lora: Whether to use LoRA (Low-Rank Adaptation)
lora_r: LoRA rank
lora_alpha: LoRA alpha parameter
lora_dropout: LoRA dropout rate
target_modules: List of module names to apply LoRA to
use_4bit: Whether to use 4-bit quantization
use_8bit: Whether to use 8-bit quantization
"""
try:
logger.info("Preparing model for training")
# Prepare model for k-bit training if quantization is used
if use_4bit or use_8bit:
logger.info("Preparing model for k-bit training")
self.model = prepare_model_for_kbit_training(self.model)
# Apply LoRA if requested
if use_lora:
logger.info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}")
if target_modules is None:
# Default target modules for common architectures
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
self.model = get_peft_model(self.model, lora_config)
self.model.print_trainable_parameters()
logger.info("Model prepared for training successfully")
except Exception as e:
logger.error(f"Error preparing model for training: {str(e)}")
raise
def _validate_config(self) -> None:
"""
Validate the agent configuration.
Raises:
ValueError: If configuration is invalid
"""
if not self.base_model:
raise ValueError("base_model must be specified")
if self.config.get('use_4bit') and self.config.get('use_8bit'):
raise ValueError("Cannot use both 4-bit and 8-bit quantization")
logger.debug("Configuration validated successfully")
def _get_quantization_config(self) -> Optional[BitsAndBytesConfig]:
"""
Get quantization configuration if requested.
Returns:
BitsAndBytesConfig or None
"""
if self.config.get('use_4bit'):
logger.info("Using 4-bit quantization")
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
elif self.config.get('use_8bit'):
logger.info("Using 8-bit quantization")
return BitsAndBytesConfig(
load_in_8bit=True
)
return None
def get_model_info(self) -> Dict[str, Any]:
"""
Get information about the current model.
Returns:
Dictionary containing model information
"""
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
return {
"base_model": self.base_model,
"device": self.device,
"total_parameters": total_params,
"trainable_parameters": trainable_params,
"trainable_percentage": (trainable_params / total_params) * 100 if total_params > 0 else 0,
"model_dtype": str(next(self.model.parameters()).dtype),
"config": self.config
}
@classmethod
def load(cls, model_dir: Union[str, Path], **kwargs) -> "KerdosAgent":
"""
Load a trained model from disk.
Args:
model_dir: Directory containing the saved model
**kwargs: Additional initialization parameters
Returns:
Loaded KerdosAgent instance
"""
try:
model_dir = Path(model_dir)
if not model_dir.exists():
raise FileNotFoundError(f"Model directory {model_dir} does not exist")
logger.info(f"Loading model from {model_dir}")
# Create agent with loaded model
agent = cls(
base_model=str(model_dir),
training_data=None, # Not needed for loading
**kwargs
)
logger.info("Model loaded successfully")
return agent
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
raise