|
|
"""
|
|
|
Main KerdosAgent class that orchestrates the training and deployment process.
|
|
|
"""
|
|
|
|
|
|
from typing import Optional, Union, Dict, Any, List
|
|
|
from pathlib import Path
|
|
|
import torch
|
|
|
import logging
|
|
|
from transformers import (
|
|
|
AutoModelForCausalLM,
|
|
|
AutoTokenizer,
|
|
|
BitsAndBytesConfig,
|
|
|
GenerationConfig
|
|
|
)
|
|
|
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
|
|
import warnings
|
|
|
|
|
|
from .trainer import Trainer
|
|
|
from .deployer import Deployer
|
|
|
from .data_processor import DataProcessor
|
|
|
|
|
|
|
|
|
# Configure logging once at import time so agent activity is visible by default.
# NOTE(review): calling basicConfig in a library module mutates the host
# application's root logger configuration -- confirm this is intended rather
# than configuring only this package's logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class KerdosAgent:
    """
    Main agent class for training and deploying LLMs with custom data.

    Wraps a Hugging Face causal LM and its tokenizer, and delegates work to
    three helpers: DataProcessor (dataset preparation), Trainer (fine-tuning)
    and Deployer (serving).
    """

    def __init__(
        self,
        base_model: str,
        training_data: Optional[Union[str, Path]],
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the KerdosAgent.

        Args:
            base_model: Name or path of the base LLM model
            training_data: Path to the training data, or None when no
                training is planned (e.g. when loading a saved model
                via ``KerdosAgent.load``)
            device: Device to run the model on (cuda/cpu); auto-detected
                from CUDA availability when omitted
            **kwargs: Additional configuration parameters; recognized keys
                include ``use_4bit``, ``use_8bit`` and ``trust_remote_code``

        Raises:
            ValueError: If the configuration is invalid (see _validate_config).
        """
        self.base_model = base_model
        # Normalize to Path; both None and "" disable data processing below.
        self.training_data = Path(training_data) if training_data else None
        # Explicit device wins; otherwise prefer CUDA when available.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = kwargs

        logger.info(f"Initializing KerdosAgent with base model: {base_model}")
        logger.info(f"Using device: {self.device}")

        # Fail fast on contradictory options before downloading any weights.
        self._validate_config()

        try:
            # Optional 4-bit/8-bit quantization, driven by self.config.
            quantization_config = self._get_quantization_config()
            self.model = AutoModelForCausalLM.from_pretrained(
                base_model,
                # fp16 only makes sense on GPU; CPU stays in fp32.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto",
                quantization_config=quantization_config,
                trust_remote_code=kwargs.get('trust_remote_code', False)
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model,
                trust_remote_code=kwargs.get('trust_remote_code', False)
            )

            # Many causal LMs ship without a pad token; reuse EOS so padded
            # batching and generation work.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.model.config.eos_token_id

            # Data processing is optional: only set up when data was given.
            if self.training_data:
                self.data_processor = DataProcessor(self.training_data)
            else:
                self.data_processor = None

            self.trainer = Trainer(self.model, self.tokenizer, self.device)
            self.deployer = Deployer(self.model, self.tokenizer)

            logger.info("KerdosAgent initialized successfully")

        except Exception as e:
            # Log at this boundary, then propagate to the caller unchanged.
            logger.error(f"Error initializing KerdosAgent: {str(e)}")
            raise
|
|
|
|
|
|
def train(
|
|
|
self,
|
|
|
epochs: int = 3,
|
|
|
batch_size: int = 4,
|
|
|
learning_rate: float = 2e-5,
|
|
|
**kwargs
|
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
|
Train the model on the provided data.
|
|
|
|
|
|
Args:
|
|
|
epochs: Number of training epochs
|
|
|
batch_size: Training batch size
|
|
|
learning_rate: Learning rate for training
|
|
|
**kwargs: Additional training parameters
|
|
|
|
|
|
Returns:
|
|
|
Dictionary containing training metrics
|
|
|
"""
|
|
|
|
|
|
train_dataset = self.data_processor.prepare_dataset()
|
|
|
|
|
|
|
|
|
training_args = {
|
|
|
"epochs": epochs,
|
|
|
"batch_size": batch_size,
|
|
|
"learning_rate": learning_rate,
|
|
|
**kwargs
|
|
|
}
|
|
|
|
|
|
metrics = self.trainer.train(train_dataset, **training_args)
|
|
|
return metrics
|
|
|
|
|
|
def deploy(
|
|
|
self,
|
|
|
deployment_type: str = "rest",
|
|
|
host: str = "0.0.0.0",
|
|
|
port: int = 8000,
|
|
|
**kwargs
|
|
|
) -> None:
|
|
|
"""
|
|
|
Deploy the trained model.
|
|
|
|
|
|
Args:
|
|
|
deployment_type: Type of deployment (rest/docker/kubernetes)
|
|
|
host: Host address for REST API
|
|
|
port: Port number for REST API
|
|
|
**kwargs: Additional deployment parameters
|
|
|
"""
|
|
|
deployment_args = {
|
|
|
"deployment_type": deployment_type,
|
|
|
"host": host,
|
|
|
"port": port,
|
|
|
**kwargs
|
|
|
}
|
|
|
|
|
|
self.deployer.deploy(**deployment_args)
|
|
|
|
|
|
def save(self, output_dir: Union[str, Path]) -> None:
|
|
|
"""
|
|
|
Save the trained model and tokenizer.
|
|
|
|
|
|
Args:
|
|
|
output_dir: Directory to save the model
|
|
|
"""
|
|
|
output_dir = Path(output_dir)
|
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
self.model.save_pretrained(output_dir)
|
|
|
self.tokenizer.save_pretrained(output_dir)
|
|
|
|
|
|
def generate(
|
|
|
self,
|
|
|
prompt: str,
|
|
|
max_length: int = 100,
|
|
|
temperature: float = 0.7,
|
|
|
top_p: float = 0.9,
|
|
|
top_k: int = 50,
|
|
|
num_return_sequences: int = 1,
|
|
|
**kwargs
|
|
|
) -> Union[str, List[str]]:
|
|
|
"""
|
|
|
Generate text from a prompt.
|
|
|
|
|
|
Args:
|
|
|
prompt: Input text prompt
|
|
|
max_length: Maximum length of generated text
|
|
|
temperature: Sampling temperature
|
|
|
top_p: Nucleus sampling parameter
|
|
|
top_k: Top-k sampling parameter
|
|
|
num_return_sequences: Number of sequences to generate
|
|
|
**kwargs: Additional generation parameters
|
|
|
|
|
|
Returns:
|
|
|
Generated text or list of generated texts
|
|
|
"""
|
|
|
try:
|
|
|
logger.info(f"Generating text from prompt: {prompt[:50]}...")
|
|
|
|
|
|
|
|
|
inputs = self.tokenizer(
|
|
|
prompt,
|
|
|
return_tensors="pt",
|
|
|
padding=True,
|
|
|
truncation=True
|
|
|
).to(self.device)
|
|
|
|
|
|
|
|
|
generation_config = GenerationConfig(
|
|
|
max_length=max_length,
|
|
|
temperature=temperature,
|
|
|
top_p=top_p,
|
|
|
top_k=top_k,
|
|
|
num_return_sequences=num_return_sequences,
|
|
|
pad_token_id=self.tokenizer.pad_token_id,
|
|
|
eos_token_id=self.tokenizer.eos_token_id,
|
|
|
**kwargs
|
|
|
)
|
|
|
|
|
|
|
|
|
self.model.eval()
|
|
|
with torch.no_grad():
|
|
|
outputs = self.model.generate(
|
|
|
**inputs,
|
|
|
generation_config=generation_config
|
|
|
)
|
|
|
|
|
|
|
|
|
generated_texts = [
|
|
|
self.tokenizer.decode(output, skip_special_tokens=True)
|
|
|
for output in outputs
|
|
|
]
|
|
|
|
|
|
logger.info(f"Generated {len(generated_texts)} sequence(s)")
|
|
|
|
|
|
return generated_texts[0] if num_return_sequences == 1 else generated_texts
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error generating text: {str(e)}")
|
|
|
raise
|
|
|
|
|
|
def inference(
|
|
|
self,
|
|
|
texts: List[str],
|
|
|
batch_size: int = 8,
|
|
|
**kwargs
|
|
|
) -> List[str]:
|
|
|
"""
|
|
|
Run batch inference on multiple texts.
|
|
|
|
|
|
Args:
|
|
|
texts: List of input texts
|
|
|
batch_size: Batch size for inference
|
|
|
**kwargs: Additional generation parameters
|
|
|
|
|
|
Returns:
|
|
|
List of generated texts
|
|
|
"""
|
|
|
try:
|
|
|
logger.info(f"Running inference on {len(texts)} texts")
|
|
|
|
|
|
results = []
|
|
|
self.model.eval()
|
|
|
|
|
|
for i in range(0, len(texts), batch_size):
|
|
|
batch = texts[i:i + batch_size]
|
|
|
|
|
|
|
|
|
inputs = self.tokenizer(
|
|
|
batch,
|
|
|
return_tensors="pt",
|
|
|
padding=True,
|
|
|
truncation=True
|
|
|
).to(self.device)
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
outputs = self.model.generate(
|
|
|
**inputs,
|
|
|
pad_token_id=self.tokenizer.pad_token_id,
|
|
|
**kwargs
|
|
|
)
|
|
|
|
|
|
|
|
|
batch_results = [
|
|
|
self.tokenizer.decode(output, skip_special_tokens=True)
|
|
|
for output in outputs
|
|
|
]
|
|
|
results.extend(batch_results)
|
|
|
|
|
|
logger.info(f"Inference completed for {len(results)} texts")
|
|
|
return results
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error during inference: {str(e)}")
|
|
|
raise
|
|
|
|
|
|
def prepare_for_training(
|
|
|
self,
|
|
|
use_lora: bool = True,
|
|
|
lora_r: int = 8,
|
|
|
lora_alpha: int = 32,
|
|
|
lora_dropout: float = 0.1,
|
|
|
target_modules: Optional[List[str]] = None,
|
|
|
use_4bit: bool = False,
|
|
|
use_8bit: bool = False
|
|
|
) -> None:
|
|
|
"""
|
|
|
Prepare the model for efficient training using LoRA and/or quantization.
|
|
|
|
|
|
Args:
|
|
|
use_lora: Whether to use LoRA (Low-Rank Adaptation)
|
|
|
lora_r: LoRA rank
|
|
|
lora_alpha: LoRA alpha parameter
|
|
|
lora_dropout: LoRA dropout rate
|
|
|
target_modules: List of module names to apply LoRA to
|
|
|
use_4bit: Whether to use 4-bit quantization
|
|
|
use_8bit: Whether to use 8-bit quantization
|
|
|
"""
|
|
|
try:
|
|
|
logger.info("Preparing model for training")
|
|
|
|
|
|
|
|
|
if use_4bit or use_8bit:
|
|
|
logger.info("Preparing model for k-bit training")
|
|
|
self.model = prepare_model_for_kbit_training(self.model)
|
|
|
|
|
|
|
|
|
if use_lora:
|
|
|
logger.info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}")
|
|
|
|
|
|
if target_modules is None:
|
|
|
|
|
|
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
|
|
|
|
|
|
lora_config = LoraConfig(
|
|
|
r=lora_r,
|
|
|
lora_alpha=lora_alpha,
|
|
|
target_modules=target_modules,
|
|
|
lora_dropout=lora_dropout,
|
|
|
bias="none",
|
|
|
task_type="CAUSAL_LM"
|
|
|
)
|
|
|
|
|
|
self.model = get_peft_model(self.model, lora_config)
|
|
|
self.model.print_trainable_parameters()
|
|
|
|
|
|
logger.info("Model prepared for training successfully")
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error preparing model for training: {str(e)}")
|
|
|
raise
|
|
|
|
|
|
def _validate_config(self) -> None:
|
|
|
"""
|
|
|
Validate the agent configuration.
|
|
|
|
|
|
Raises:
|
|
|
ValueError: If configuration is invalid
|
|
|
"""
|
|
|
if not self.base_model:
|
|
|
raise ValueError("base_model must be specified")
|
|
|
|
|
|
if self.config.get('use_4bit') and self.config.get('use_8bit'):
|
|
|
raise ValueError("Cannot use both 4-bit and 8-bit quantization")
|
|
|
|
|
|
logger.debug("Configuration validated successfully")
|
|
|
|
|
|
def _get_quantization_config(self) -> Optional[BitsAndBytesConfig]:
|
|
|
"""
|
|
|
Get quantization configuration if requested.
|
|
|
|
|
|
Returns:
|
|
|
BitsAndBytesConfig or None
|
|
|
"""
|
|
|
if self.config.get('use_4bit'):
|
|
|
logger.info("Using 4-bit quantization")
|
|
|
return BitsAndBytesConfig(
|
|
|
load_in_4bit=True,
|
|
|
bnb_4bit_compute_dtype=torch.float16,
|
|
|
bnb_4bit_quant_type="nf4",
|
|
|
bnb_4bit_use_double_quant=True
|
|
|
)
|
|
|
elif self.config.get('use_8bit'):
|
|
|
logger.info("Using 8-bit quantization")
|
|
|
return BitsAndBytesConfig(
|
|
|
load_in_8bit=True
|
|
|
)
|
|
|
return None
|
|
|
|
|
|
def get_model_info(self) -> Dict[str, Any]:
|
|
|
"""
|
|
|
Get information about the current model.
|
|
|
|
|
|
Returns:
|
|
|
Dictionary containing model information
|
|
|
"""
|
|
|
total_params = sum(p.numel() for p in self.model.parameters())
|
|
|
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
|
|
|
|
|
return {
|
|
|
"base_model": self.base_model,
|
|
|
"device": self.device,
|
|
|
"total_parameters": total_params,
|
|
|
"trainable_parameters": trainable_params,
|
|
|
"trainable_percentage": (trainable_params / total_params) * 100 if total_params > 0 else 0,
|
|
|
"model_dtype": str(next(self.model.parameters()).dtype),
|
|
|
"config": self.config
|
|
|
}
|
|
|
|
|
|
@classmethod
|
|
|
def load(cls, model_dir: Union[str, Path], **kwargs) -> "KerdosAgent":
|
|
|
"""
|
|
|
Load a trained model from disk.
|
|
|
|
|
|
Args:
|
|
|
model_dir: Directory containing the saved model
|
|
|
**kwargs: Additional initialization parameters
|
|
|
|
|
|
Returns:
|
|
|
Loaded KerdosAgent instance
|
|
|
"""
|
|
|
try:
|
|
|
model_dir = Path(model_dir)
|
|
|
|
|
|
if not model_dir.exists():
|
|
|
raise FileNotFoundError(f"Model directory {model_dir} does not exist")
|
|
|
|
|
|
logger.info(f"Loading model from {model_dir}")
|
|
|
|
|
|
|
|
|
agent = cls(
|
|
|
base_model=str(model_dir),
|
|
|
training_data=None,
|
|
|
**kwargs
|
|
|
)
|
|
|
|
|
|
logger.info("Model loaded successfully")
|
|
|
return agent
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error loading model: {str(e)}")
|
|
|
raise |