import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopPLogitsWarper
from .base import BaseModel


class TransformersModel(BaseModel):
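    """HuggingFace transformers backend for BaseModel.

    Both infer() and stream_infer() return the last-layer hidden states of
    the generated tokens alongside a finish reason, matching the output
    format of the lmdeploy backend.
    """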
    def __init__(self, device='cuda', model_path=None, **kwargs):
        self.device = device
        
        # Use local model if path provided, otherwise use HuggingFace
        model_name_or_path = model_path if model_path else 'ekwek/Soprano-1.1-80M'
        
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
            device_map=device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Causal LM tokenizers often ship without a pad token; fall back to
        # EOS so the batched padding in infer() works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()

    def infer(self, prompts, top_p=0.95, temperature=0.3,
            repetition_penalty=1.2):
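        """Batched generation. Returns one dict per prompt with a
        'finish_reason' ('stop' or 'length') and a 'hidden_state' tensor
        stacking the last-layer states of the generated non-EOS tokens."""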
        inputs = self.tokenizer(
            prompts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=512,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_hidden_states=True,
            )
        res = []
        eos_token_id = self.model.config.eos_token_id
        for i in range(len(prompts)):
            seq = outputs.sequences[i]
            hidden_states = []
            # outputs.hidden_states has one entry per generation step; the
            # generated tokens occupy the last num_output_tokens positions of
            # the sequence, and the state at step j is the one that produced
            # generated token j.
            num_output_tokens = len(outputs.hidden_states)
            for j in range(num_output_tokens):
                token = seq[j + seq.size(0) - num_output_tokens]
                # Skip the EOS token's hidden state; keep the last layer at
                # the last position for sequence i.
                if token != eos_token_id:
                    hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
            last_hidden_state = torch.stack(hidden_states).squeeze()
            finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
            res.append({
                'finish_reason': finish_reason,
                'hidden_state': last_hidden_state
            })
        return res

    def stream_infer(self, prompt, top_p=0.95, temperature=0.3,
            repetition_penalty=1.2):
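        """Token-by-token generation over a manually managed KV cache.
        Yields one dict per generated token with its last-layer
        'hidden_state' and a 'finish_reason' (None until generation ends)."""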
        
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']
        
        # Prepare Logits Processors for sampling
        logits_processor = LogitsProcessorList()
        if repetition_penalty != 1.0:
            logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        
        logits_warper = LogitsProcessorList()
        if temperature != 1.0:
            logits_warper.append(TemperatureLogitsWarper(temperature=temperature))
        if top_p < 1.0:
            logits_warper.append(TopPLogitsWarper(top_p=top_p))

        # Helper to sample next token
        def get_next_token(logits, input_seq):
            scores = logits_processor(input_seq, logits)
            scores = logits_warper(input_seq, scores)
            probs = torch.nn.functional.softmax(scores, dim=-1)
            # Sample from the distribution
            return torch.multinomial(probs, num_samples=1)

        with torch.no_grad():
            # Initial forward pass with the prompt
            outputs = self.model(
                input_ids,
                use_cache=True,
                output_hidden_states=True
            )
            
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            
            # We need to maintain the full sequence for repetition penalty
            generated_ids = input_ids
            
            # Sample the first token
            next_token = get_next_token(next_token_logits, generated_ids)
            
            max_new_tokens = 512
            eos_token_id = self.model.config.eos_token_id
            
            for i in range(max_new_tokens):
                # Append generated token to sequence history
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                
                # Run forward pass for the single new token
                outputs = self.model(
                    next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True
                )
                
                # Update cache and get hidden state
                past_key_values = outputs.past_key_values
                current_hidden_state = outputs.hidden_states[-1][:, -1, :] # Last layer, last token
                
                finish_reason = None
                if next_token.item() == eos_token_id:
                    finish_reason = 'stop'
                elif i == max_new_tokens - 1:
                    finish_reason = 'length'

                # Yield result matching lmdeploy format
                yield {
                    'finish_reason': finish_reason,
                    'hidden_state': current_hidden_state
                }
                
                if finish_reason:
                    break
                
                # Prepare for next iteration
                next_token_logits = outputs.logits[:, -1, :]
                next_token = get_next_token(next_token_logits, generated_ids)
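

# Minimal usage sketch (an illustration, not part of the original module).
# It assumes this file sits in a package next to base.py (see the relative
# import above), so it must be run as a module from the package root, e.g.
# `python -m <package>.<this_module>` with hypothetical names, and that the
# default checkpoint 'ekwek/Soprano-1.1-80M' is reachable on HuggingFace
# unless a local model_path is passed.
if __name__ == '__main__':
    model = TransformersModel(
        device='cuda' if torch.cuda.is_available() else 'cpu')

    # Batched inference: one result dict per prompt.
    results = model.infer(['Hello, world!'])
    print(results[0]['finish_reason'], results[0]['hidden_state'].shape)

    # Streaming inference: one dict per generated token.
    for step in model.stream_infer('Hello, world!'):
        print(step['hidden_state'].shape)
        if step['finish_reason'] is not None:
            print('finished:', step['finish_reason'])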