|
|
"""
|
|
|
Example usage of KerdosAgent for training and inference.
|
|
|
"""
|
|
|
|
|
|
from kerdosai.agent import KerdosAgent
|
|
|
import logging
|
|
|
|
|
|
|
|
|
# Configure the root logger to show INFO-level records and above while the
# examples run (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
def example_basic_usage():
    """Load a base model and generate text from a single prompt.

    Builds a ``KerdosAgent`` on ``gpt2`` with no training data, prints the
    parameter counts and device reported by ``get_model_info()``, then
    generates a continuation of a fixed prompt (``max_length=50``,
    ``temperature=0.7``) and prints it.
    """
    print("=" * 50)
    print("Example 1: Basic Initialization and Generation")
    print("=" * 50)

    agent = KerdosAgent(
        base_model="gpt2",
        training_data=None,
    )

    # get_model_info() is read as a mapping with the keys used below.
    info = agent.get_model_info()
    # Plain string literal: this header has nothing to interpolate (was an
    # f-string without placeholders).
    print("\nModel Info:")
    # ":," formats the integer counts with thousands separators.
    print(f" Total parameters: {info['total_parameters']:,}")
    print(f" Trainable parameters: {info['trainable_parameters']:,}")
    print(f" Device: {info['device']}")

    prompt = "The future of artificial intelligence is"
    print(f"\nPrompt: {prompt}")
    generated = agent.generate(
        prompt=prompt,
        max_length=50,
        temperature=0.7,
    )
    print(f"Generated: {generated}")
|
|
|
|
|
|
|
|
|
def example_with_quantization():
    """Load a base model with 4-bit quantization enabled.

    Builds a ``KerdosAgent`` on ``gpt2`` with ``use_4bit=True`` and no
    training data, then prints the model dtype and device reported by
    ``get_model_info()``.

    NOTE(review): 4-bit loading usually depends on optional backends (e.g.
    bitsandbytes / a CUDA device) — confirm KerdosAgent's requirements.
    """
    print("\n" + "=" * 50)
    print("Example 2: Initialization with Quantization")
    print("=" * 50)

    agent = KerdosAgent(
        base_model="gpt2",
        training_data=None,
        use_4bit=True,
    )

    info = agent.get_model_info()
    # Plain string literal: nothing to interpolate (was an f-string without
    # placeholders).
    print("\nModel loaded with quantization")
    print(f" Model dtype: {info['model_dtype']}")
    print(f" Device: {info['device']}")
|
|
|
|
|
|
|
|
|
def example_batch_inference():
    """Run several prompts through ``KerdosAgent.inference`` in batches.

    Three fixed prompts are processed with ``batch_size=2`` and
    ``max_length=30``; each prompt is printed next to its result.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("Example 3: Batch Inference")
    print(rule)

    agent = KerdosAgent(base_model="gpt2", training_data=None)

    sample_prompts = [
        "Once upon a time",
        "In the year 2050",
        "The secret to happiness is",
    ]

    print("\nRunning batch inference...")
    outputs = agent.inference(texts=sample_prompts, batch_size=2, max_length=30)

    print("\nResults:")
    # Pair each prompt with the output produced for it.
    for text, output in zip(sample_prompts, outputs):
        print(f" Prompt: {text}")
        print(f" Result: {output}\n")
|
|
|
|
|
|
|
|
|
def example_training_preparation():
    """Show how applying LoRA changes the trainable-parameter count.

    Builds a ``KerdosAgent`` on ``gpt2`` pointed at ``data/sample_data.csv``
    (a relative path — presumably resolved against the working directory),
    reports trainable parameters, calls ``prepare_for_training`` with LoRA
    (r=8, alpha=32, dropout=0.1), and reports them again.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("Example 4: Prepare Model for Training with LoRA")
    print(rule)

    agent = KerdosAgent(base_model="gpt2", training_data="data/sample_data.csv")

    def report(heading):
        # Print a heading, then the trainable-parameter stats as of now.
        print(heading)
        stats = agent.get_model_info()
        print(f" Trainable parameters: {stats['trainable_parameters']:,}")
        print(f" Trainable %: {stats['trainable_percentage']:.2f}%")

    report("\nBefore LoRA:")

    print("\nApplying LoRA...")
    agent.prepare_for_training(
        use_lora=True, lora_r=8, lora_alpha=32, lora_dropout=0.1
    )

    report("\nAfter LoRA:")
|
|
|
|
|
|
|
|
|
def example_save_and_load():
    """Round-trip a model through ``save`` / ``KerdosAgent.load``.

    Saves a freshly built ``gpt2`` agent to ``models/example_model``
    (relative to the working directory), loads it back, and runs a short
    generation on the reloaded agent as a smoke test.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("Example 5: Save and Load Model")
    print(rule)

    print("\nInitializing model...")
    original = KerdosAgent(base_model="gpt2", training_data=None)

    save_path = "models/example_model"

    print(f"\nSaving model to {save_path}...")
    original.save(save_path)
    print("Model saved successfully!")

    print(f"\nLoading model from {save_path}...")
    restored = KerdosAgent.load(save_path)
    print("Model loaded successfully!")

    sample = restored.generate("Hello", max_length=20)
    print(f"\nGeneration test: {sample}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Usage: python <this script> [basic|quantization|batch|training|save_load|all ...]
    import sys
    import traceback

    examples = {
        "basic": example_basic_usage,
        "quantization": example_with_quantization,
        "batch": example_batch_inference,
        "training": example_training_preparation,
        "save_load": example_save_and_load,
    }

    # Only the basic example runs by default (the previous behavior). The
    # others load a 4-bit model, read data/sample_data.csv, or write to
    # models/example_model, so they are opt-in via the command line.
    selected = sys.argv[1:] or ["basic"]
    if "all" in selected:
        selected = list(examples)
    unknown = [name for name in selected if name not in examples]
    if unknown:
        sys.exit(
            f"Unknown example(s): {', '.join(unknown)}. "
            f"Choose from: {', '.join(examples)}, all"
        )

    print("\n" + "=" * 50)
    print("KerdosAgent Usage Examples")
    print("=" * 50)

    try:
        for name in selected:
            examples[name]()
    except Exception as e:
        # Top-level boundary: report the failure, then exit non-zero so
        # shells/CI can detect it (previously the script exited with 0).
        print(f"\nError running examples: {e}")
        traceback.print_exc()
        sys.exit(1)
    else:
        print("\n" + "=" * 50)
        print("Examples completed successfully!")
        print("=" * 50)
|
|
|
|