File size: 5,114 Bytes
1df0e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn.functional as F
from typing import Optional, List, Generator
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
from aetheris.data import get_tokenizer
from aetheris.utils import load_latest_checkpoint

class InferenceEngine:
    def __init__(self, config_path: str = "configs/default.yaml", checkpoint_dir: str = "checkpoints", checkpoint_name: str = "checkpoint_current.pth", device: str = None):
        self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
        self.config = AetherisConfig.from_yaml(config_path)
        self.tokenizer = get_tokenizer()
        
        self.model = HybridMambaMoE(self.config).to(self.device).to(self.config.torch_dtype)
        
        # Load checkpoint
        # Note: load_latest_checkpoint expects optimizer and scaler, but for inference we can pass None
        load_latest_checkpoint(self.model, None, None, self.device, checkpoint_dir, checkpoint_name)
        self.model.eval()

    def generate(self, 
                 prompt: str, 
                 max_new_tokens: int = 100, 
                 temperature: float = 0.8, 
                 top_k: int = 0, 
                 top_p: float = 0.9, 
                 repetition_penalty: float = 1.0,
                 stream: bool = False) -> Generator[str, None, None] | str:
        
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        generated_ids = input_ids.clone()
        history_ids = set(input_ids[0].tolist())

        def token_generator():
            nonlocal generated_ids
            for _ in range(max_new_tokens):
                 # Check if we should use autocast (skip if model uses float32)
                use_autocast = True
                if self.config.torch_dtype == torch.float32:
                    use_autocast = False
                
                if use_autocast:
                    with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=self.model.config.torch_dtype):
                        outputs = self.model(generated_ids)
                        logits = outputs['logits']
                        next_token_logits = logits[:, -1, :]
                else:
                    outputs = self.model(generated_ids)
                    logits = outputs['logits']
                    next_token_logits = logits[:, -1, :]

                # Repetition penalty
                for token_id in history_ids:
                    if token_id < next_token_logits.size(-1):
                        logit = next_token_logits[0, token_id].item()
                        if logit > 0:
                            next_token_logits[0, token_id] = logit / repetition_penalty
                        else:
                            next_token_logits[0, token_id] = logit * repetition_penalty

                # Temperature
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                # Top-p / Top-k
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits.scatter_(1, indices_to_remove.unsqueeze(0), float('-inf'))
                elif top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(1, top_k_indices, top_k_logits)

                # Sample
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                next_token_item = next_token.item()

                if next_token_item == self.tokenizer.eos_token_id:
                    break

                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                history_ids.add(next_token_item)

                new_token_text = self.tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True)
                yield new_token_text
        
        if stream:
            return token_generator()
        else:
            return "".join(list(token_generator()))

    def generate_full(self, 
                 prompt: str, 
                 max_new_tokens: int = 100, 
                 temperature: float = 0.8, 
                 top_k: int = 0, 
                 top_p: float = 0.9, 
                 repetition_penalty: float = 1.0) -> str:
        return self.generate(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty, stream=False)