# apollo-astralis-4b / example_usage.py
# Author: Tyler Williams
# Initial release: Apollo-Astralis V1 4B with Apache 2.0
# Commit: 2c40ce7
"""
Apollo-Astralis V1 4B - Example Usage
This script demonstrates how to use Apollo-Astralis V1 4B with Transformers.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def load_model(model_name="VANTA-Research/apollo-astralis-v1-4b"):
    """Fetch the Apollo-Astralis tokenizer and model from the Hugging Face Hub.

    Args:
        model_name: Hub repository id to load from. Defaults to the
            official Apollo-Astralis V1 4B release.

    Returns:
        A ``(model, tokenizer)`` pair ready for generation.

    Note:
        ``trust_remote_code=True`` executes code shipped with the repo —
        only use with repositories you trust.
    """
    print(f"Loading {model_name}...")
    # Tokenizer first; it is small and fails fast on a bad repo id.
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # bfloat16 weights with device_map="auto" lets accelerate place the
    # model across the available GPU(s)/CPU automatically.
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    print("Model loaded successfully!")
    return mdl, tok
def generate_response(model, tokenizer, user_message, system_prompt=None,
                      max_new_tokens=512, temperature=0.7, top_p=0.9,
                      repetition_penalty=1.05):
    """Generate a single-turn response from Apollo.

    Generalized over the sampling hyperparameters: the former hard-coded
    values are now keyword arguments with the same defaults, so existing
    callers are unaffected.

    Args:
        model: A loaded causal LM (as returned by :func:`load_model`).
        tokenizer: The matching tokenizer with a chat template.
        user_message: The user's message for this turn.
        system_prompt: Optional system prompt; a default Apollo persona
            is used when ``None``.
        max_new_tokens: Cap on generated tokens (default 512).
        temperature: Sampling temperature (default 0.7).
        top_p: Nucleus-sampling threshold (default 0.9).
        repetition_penalty: Penalty applied to repeated tokens (default 1.05).

    Returns:
        The decoded assistant response text (prompt tokens stripped).
    """
    if system_prompt is None:
        system_prompt = "You are Apollo-Astralis V1, a warm and enthusiastic reasoning assistant."
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message}
    ]
    # Render the conversation with the model's chat template, appending
    # the generation prompt so the model continues as the assistant.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Tokenize and move tensors to the model's device.
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty
    )
    # Slice off the prompt tokens so only newly generated text is decoded.
    prompt_len = inputs['input_ids'].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True
    )
    return response
def main():
    """Run three demo prompts through Apollo-Astralis and print the replies."""
    model, tokenizer = load_model()
    # (banner title, user message) pairs — printed and answered in order.
    demos = [
        ("Example 1: Celebration Response",
         "I just got my first job as a software engineer!"),
        ("Example 2: Problem-Solving",
         "What's the best way to learn machine learning?"),
        ("Example 3: Mathematical Reasoning",
         "If a train travels 120 km in 1.5 hours, what's its average speed?"),
    ]
    for title, user_msg in demos:
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)
        print(f"\nUser: {user_msg}")
        response = generate_response(model, tokenizer, user_msg)
        print(f"\nApollo: {response}")


if __name__ == "__main__":
    main()