apollo-astralis-2 / example_usage.py
unmodeled-tyler's picture
Upload folder using huggingface_hub
8fc1812 verified
#!/usr/bin/env python3
"""
Example usage script for Apollo Astralis 2
Demonstrates loading and inference with the model
"""
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, Mistral3ForConditionalGeneration
from peft import PeftModel
def load_apollo_astralis_v2(
    model_path="vanta-research/apollo-astralis-2",
    base_model_path="Ministral-3-8B-Reasoning-2512",
):
    """
    Load Apollo Astralis 2 model with 4-bit quantization.

    The tokenizer and the LoRA adapter are both read from ``model_path``;
    the quantized base weights are read from ``base_model_path``.

    Args:
        model_path: Path to the model (HuggingFace repo or local path).
            Used for the tokenizer and for the LoRA adapter.
        base_model_path: Path or repo ID of the base model the adapter is
            applied to. The default is kept as the original hard-coded value.
            NOTE(review): the default has no Hub namespace, so it presumably
            only resolves as a local directory; an org-prefixed repo ID
            (e.g. ``mistralai/...``) may be what is intended -- confirm.

    Returns:
        model, tokenizer: Loaded model (base + adapter, in eval mode) and
        tokenizer.
    """
    print("Loading Apollo Astralis 2...")

    # Configure 4-bit quantization for memory efficiency: NF4 weights with
    # double quantization, computing in float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Load base model with quantization; device placement is delegated to
    # device_map="auto".
    base_model = Mistral3ForConditionalGeneration.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )

    # Load LoRA adapter on top of the quantized base model
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()

    print("Model loaded successfully!")
    return model, tokenizer
def generate_response(model, tokenizer, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
    """
    Generate a response from Apollo Astralis 2.

    The prompt is wrapped as a single user message, rendered with the
    tokenizer's chat template, and passed to ``model.generate``. Only the
    newly generated tokens (everything after the prompt) are decoded.

    Args:
        model: The loaded model
        tokenizer: The loaded tokenizer
        prompt: User prompt/question
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature. Values <= 0 disable sampling
            (``do_sample=False``); positive values enable it.
        top_p: Nucleus sampling parameter (only passed when sampling)

    Returns:
        str: Generated response
    """
    # Format prompt with chat template
    messages = [{"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered template text already contains the special tokens it
    # needs, so do not let the tokenizer add another set on top of it.
    inputs = tokenizer(
        input_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    sampling = temperature > 0

    # Fall back to EOS for padding when the tokenizer defines no pad token,
    # instead of handing generate() an explicit None.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=sampling,
            # Sampling knobs are passed as None when sampling is disabled.
            temperature=temperature if sampling else None,
            top_p=top_p if sampling else None,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode response (excluding the input prompt)
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
def main():
    """
    Example usage demonstrating various capabilities of Apollo Astralis 2.

    Loads the model once, then runs a fixed list of example prompts through
    ``generate_response`` and prints each prompt with its response. The
    text printed as "Prompt:" is exactly the text sent to the model.
    """
    # Load model
    model, tokenizer = load_apollo_astralis_v2()

    # (section title, prompt) pairs. Prompts are written without leading or
    # trailing whitespace so the displayed prompt and the prompt passed to
    # the model are the same string.
    examples = [
        (
            "EXAMPLE 1: Logical Reasoning",
            "Analyze this argument: If it rains, the streets get wet. "
            "The streets are wet. Therefore, it must have rained. "
            "Is this reasoning valid?",
        ),
        (
            "EXAMPLE 2: Mathematical Problem Solving",
            "A train travels at 60 mph for 2 hours, then 80 mph for 3 hours.\n"
            "What is the average speed for the entire journey?",
        ),
        (
            "EXAMPLE 3: Commonsense Reasoning",
            "You need to keep food cold but your refrigerator is broken.\n"
            "What are some practical solutions?",
        ),
        (
            "EXAMPLE 4: Physical Commonsense",
            "You have a jar with a tight lid that won't open.\n"
            "What are effective ways to open it?",
        ),
    ]

    banner = "=" * 80
    for title, prompt in examples:
        print("\n" + banner)
        print(title)
        print(banner)
        print(f"\nPrompt: {prompt}")
        print(f"\nResponse:\n{generate_response(model, tokenizer, prompt)}")

    print("\n" + banner)
    print("Examples completed!")
    print(banner)
# Run the examples only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()