File size: 4,711 Bytes
8fc1812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
"""
Example usage script for Apollo Astralis 2
Demonstrates loading and inference with the model
"""

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, Mistral3ForConditionalGeneration
from peft import PeftModel

def load_apollo_astralis_v2(
    model_path="vanta-research/apollo-astralis-2",
    base_model_path="Ministral-3-8B-Reasoning-2512",
):
    """
    Load Apollo Astralis 2 (base model + LoRA adapter) with 4-bit quantization.

    Args:
        model_path: Location of the adapter and tokenizer (HuggingFace repo
            or local path).
        base_model_path: Location of the base model the adapter is applied to
            (HuggingFace repo or local path). Defaults to the identifier this
            script has always used, so existing callers are unaffected.

    Returns:
        model, tokenizer: The adapter-wrapped model in eval mode, and the
        tokenizer loaded from ``model_path``.
    """
    # NOTE(review): the default base_model_path has no organisation prefix, so
    # it only resolves if a local directory of that name exists or the Hub
    # accepts the bare id -- confirm, or pass the fully qualified id explicitly.
    print("Loading Apollo Astralis 2...")

    # NF4 4-bit weights with double quantization keep the memory footprint
    # small; matrix multiplications are computed in float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    # The tokenizer is taken from the adapter location, not the base model.
    # NOTE(review): trust_remote_code=True allows Python code shipped in the
    # repository to run at load time -- confirm it is actually required.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Load the quantized base model; device_map="auto" lets the library decide
    # where to place the weights.
    base_model = Mistral3ForConditionalGeneration.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )

    # Attach the LoRA adapter on top of the base model and switch to eval mode
    # for inference.
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()

    print("Model loaded successfully!")
    return model, tokenizer

def generate_response(model, tokenizer, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
    """
    Generate a response from Apollo Astralis 2.

    Args:
        model: The loaded model
        tokenizer: The loaded tokenizer
        prompt: User prompt/question, sent as a single user turn
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature. Values <= 0 disable sampling
            (deterministic decoding); larger values are more random.
        top_p: Nucleus sampling parameter (only used when sampling)

    Returns:
        str: Generated response, without the prompt and without special tokens
    """
    # Format prompt with chat template
    messages = [{"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered template already contains whatever special tokens it
    # defines, so the tokenizer must not add its own on top (otherwise e.g. a
    # beginning-of-sequence token can end up duplicated).
    inputs = tokenizer(
        input_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    # Some tokenizers define no pad token; fall back to EOS instead of
    # handing None to generate().
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    use_sampling = temperature > 0

    # Generate response; no gradients are needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=use_sampling,
            # Sampling knobs are explicitly unset when decoding greedily.
            temperature=temperature if use_sampling else None,
            top_p=top_p if use_sampling else None,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The output sequence starts with the prompt tokens; decode only what
    # follows them.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

def main():
    """
    Example usage demonstrating various capabilities of Apollo Astralis 2.

    Loads the model once, then runs a fixed list of example prompts and prints
    each prompt together with the generated response.
    """
    # Load model
    model, tokenizer = load_apollo_astralis_v2()

    # (title, prompt) pairs. Prompts are written without source indentation or
    # surrounding blank lines so that the text shown to the user is exactly
    # the text sent to the model.
    examples = [
        (
            "Logical Reasoning",
            "Analyze this argument: If it rains, the streets get wet. "
            "The streets are wet. Therefore, it must have rained. "
            "Is this reasoning valid?",
        ),
        (
            "Mathematical Problem Solving",
            "A train travels at 60 mph for 2 hours, then 80 mph for 3 hours.\n"
            "What is the average speed for the entire journey?",
        ),
        (
            "Commonsense Reasoning",
            "You need to keep food cold but your refrigerator is broken.\n"
            "What are some practical solutions?",
        ),
        (
            "Physical Commonsense",
            "You have a jar with a tight lid that won't open.\n"
            "What are effective ways to open it?",
        ),
    ]

    rule = "=" * 80
    for number, (title, prompt) in enumerate(examples, start=1):
        print("\n" + rule)
        print(f"EXAMPLE {number}: {title}")
        print(rule)
        print(f"\nPrompt: {prompt}")
        print(f"\nResponse:\n{generate_response(model, tokenizer, prompt)}")

    print("\n" + rule)
    print("Examples completed!")
    print(rule)

# Run the demo only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()