import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

from .base import BaseModel


class TransformersModel(BaseModel):
    def __init__(self,
                 device='cuda',
                 model_path=None,
                 **kwargs):
        self.device = device
        # Use the local model if a path is provided, otherwise pull from the Hugging Face Hub
        model_name_or_path = model_path if model_path else 'ekwek/Soprano-1.1-80M'
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
            device_map=device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Batched inference below pads the inputs; fall back to the EOS token
        # if the tokenizer ships without a dedicated pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()

    def infer(self,
              prompts,
              top_p=0.95,
              temperature=0.3,
              repetition_penalty=1.2):
        inputs = self.tokenizer(
            prompts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=512,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_hidden_states=True,
            )
        res = []
        eos_token_id = self.model.config.eos_token_id
        for i in range(len(prompts)):
            seq = outputs.sequences[i]
            hidden_states = []
            # generate() returns one hidden-state tuple per decoding step,
            # so the step count equals the number of newly generated tokens.
            num_output_tokens = len(outputs.hidden_states)
            for j in range(num_output_tokens):
                # seq holds prompt + generated tokens; skip past the prompt
                # to reach the token produced at step j.
                token = seq[j + seq.size(0) - num_output_tokens]
                if token != eos_token_id:
                    # Last layer, last position of step j for sequence i.
                    hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
            last_hidden_state = torch.stack(hidden_states).squeeze()
            finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
            res.append({
                'finish_reason': finish_reason,
                'hidden_state': last_hidden_state
            })
        return res
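
    # A minimal usage sketch for infer(), assuming the default checkpoint
    # above is reachable; the prompt is illustrative only:
    #
    #     model = TransformersModel(device='cuda')
    #     out = model.infer(['Hello, world.'])[0]
    #     out['finish_reason']        # 'stop' or 'length'
    #     out['hidden_state'].shape   # (num_generated_tokens, hidden_size),
    #                                 # squeezed when only one token survives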

    def stream_infer(self,
                     prompt,
                     top_p=0.95,
                     temperature=0.3,
                     repetition_penalty=1.2):
        # Tokenize the input
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']

        # Prepare logits processors for sampling
        logits_processor = LogitsProcessorList()
        if repetition_penalty != 1.0:
            logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        logits_warper = LogitsProcessorList()
        if temperature != 1.0:
            logits_warper.append(TemperatureLogitsWarper(temperature=temperature))
        if top_p < 1.0:
            logits_warper.append(TopPLogitsWarper(top_p=top_p))

        # Helper to sample the next token: penalize repetitions, warp with
        # temperature and top-p, then draw from the resulting distribution
        def get_next_token(logits, input_seq):
            scores = logits_processor(input_seq, logits)
            scores = logits_warper(input_seq, scores)
            probs = torch.nn.functional.softmax(scores, dim=-1)
            return torch.multinomial(probs, num_samples=1)

        with torch.no_grad():
            # Initial forward pass over the full prompt
            outputs = self.model(
                input_ids,
                use_cache=True,
                output_hidden_states=True
            )
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            # Keep the full sequence around: the repetition penalty needs it
            generated_ids = input_ids
            # Sample the first token
            next_token = get_next_token(next_token_logits, generated_ids)

            max_new_tokens = 512
            eos_token_id = self.model.config.eos_token_id
            for i in range(max_new_tokens):
                # Append the generated token to the sequence history
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                # Run a forward pass for the single new token, reusing the KV cache
                outputs = self.model(
                    next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True
                )
                # Update the cache and grab the hidden state
                past_key_values = outputs.past_key_values
                current_hidden_state = outputs.hidden_states[-1][:, -1, :]  # Last layer, last token

                finish_reason = None
                if next_token.item() == eos_token_id:
                    finish_reason = 'stop'
                elif i == max_new_tokens - 1:
                    finish_reason = 'length'

                # Yield a result matching the lmdeploy format
                yield {
                    'finish_reason': finish_reason,
                    'hidden_state': current_hidden_state
                }
                if finish_reason:
                    break

                # Prepare for the next iteration
                next_token_logits = outputs.logits[:, -1, :]
                next_token = get_next_token(next_token_logits, generated_ids)
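

# A minimal smoke test, assuming this module lives in a package (the relative
# import of BaseModel means it must be run as `python -m <package>.<module>`);
# the prompt below is illustrative only.
if __name__ == '__main__':
    m = TransformersModel(device='cuda' if torch.cuda.is_available() else 'cpu')

    # Batch path: one dict per prompt.
    for out in m.infer(['The quick brown fox']):
        print(out['finish_reason'], tuple(out['hidden_state'].shape))

    # Streaming path: one dict per generated token.
    for step in m.stream_infer('The quick brown fox'):
        if step['finish_reason']:
            print(step['finish_reason'], tuple(step['hidden_state'].shape))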