import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

from .base import BaseModel
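

# Hugging Face Transformers backend. Both infer() and stream_infer() return
# per-token hidden states alongside a finish reason ('stop' or 'length'),
# mirroring the lmdeploy output format referenced in stream_infer() below.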
class TransformersModel(BaseModel):
    def __init__(self,
                 device='cuda',
                 model_path=None,
                 **kwargs):
        self.device = device
        # Use local model if path provided, otherwise use HuggingFace
        model_name_or_path = model_path if model_path else 'ekwek/Soprano-1.1-80M'
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
            device_map=device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Causal LM tokenizers often ship without a pad token; fall back to EOS
        # so that batched tokenization with padding=True (see infer) does not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()
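
    # Batched inference: returns one dict per prompt with the finish reason and
    # the stacked last-layer hidden states of the generated (non-EOS) tokens.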
    def infer(self,
              prompts,
              top_p=0.95,
              temperature=0.3,
              repetition_penalty=1.2):
        inputs = self.tokenizer(
            prompts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=512,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_hidden_states=True,
            )
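
        # With return_dict_in_generate=True and output_hidden_states=True,
        # outputs.hidden_states holds one entry per generation step; each entry
        # is a tuple over layers of tensors shaped (batch, step_seq_len, hidden)
        # (the full prompt length at step 0, a single position afterwards).
        # The loop below keeps the last layer's final position from each step,
        # skipping steps whose corresponding token is EOS.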
        res = []
        eos_token_id = self.model.config.eos_token_id
        for i in range(len(prompts)):
            seq = outputs.sequences[i]
            hidden_states = []
            num_output_tokens = len(outputs.hidden_states)
            for j in range(num_output_tokens):
                token = seq[j + seq.size(0) - num_output_tokens]
                if token != eos_token_id:
                    hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
            last_hidden_state = torch.stack(hidden_states).squeeze()
            finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
            res.append({
                'finish_reason': finish_reason,
                'hidden_state': last_hidden_state
            })
        return res
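
    # Streaming variant: decodes one token at a time with the KV cache and
    # yields each token's last-layer hidden state as soon as it is computed.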
    def stream_infer(self,
                     prompt,
                     top_p=0.95,
                     temperature=0.3,
                     repetition_penalty=1.2):
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']

        # Prepare Logits Processors for sampling
        logits_processor = LogitsProcessorList()
        if repetition_penalty != 1.0:
            logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        logits_warper = LogitsProcessorList()
        if temperature != 1.0:
            logits_warper.append(TemperatureLogitsWarper(temperature=temperature))
        if top_p < 1.0:
            logits_warper.append(TopPLogitsWarper(top_p=top_p))

        # Helper to sample next token
        def get_next_token(logits, input_seq):
            scores = logits_processor(input_seq, logits)
            scores = logits_warper(input_seq, scores)
            probs = torch.nn.functional.softmax(scores, dim=-1)
            # Sample from the distribution
            return torch.multinomial(probs, num_samples=1)
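
        # The repetition penalty is applied to the raw logits first, then
        # temperature scaling and top-p filtering, and only then do we sample;
        # this mirrors the processor-then-warper order used by generate().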
        with torch.no_grad():
            # Initial forward pass with the prompt
            outputs = self.model(
                input_ids,
                use_cache=True,
                output_hidden_states=True
            )
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            # We need to maintain the full sequence for repetition penalty
            generated_ids = input_ids
            # Sample the first token
            next_token = get_next_token(next_token_logits, generated_ids)

            max_new_tokens = 512
            eos_token_id = self.model.config.eos_token_id
            for i in range(max_new_tokens):
                # Append generated token to sequence history
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                # Run forward pass for the single new token
                outputs = self.model(
                    next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True
                )
                # Update cache and get hidden state
                past_key_values = outputs.past_key_values
                current_hidden_state = outputs.hidden_states[-1][:, -1, :]  # Last layer, last token

                finish_reason = None
                if next_token.item() == eos_token_id:
                    finish_reason = 'stop'
                elif i == max_new_tokens - 1:
                    finish_reason = 'length'

                # Yield result matching lmdeploy format
                yield {
                    'finish_reason': finish_reason,
                    'hidden_state': current_hidden_state
                }
                if finish_reason:
                    break

                # Prepare for next iteration
                next_token_logits = outputs.logits[:, -1, :]
                next_token = get_next_token(next_token_logits, generated_ids)
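

# Minimal usage sketch (assumptions: the module is imported from its package so
# the relative BaseModel import resolves; prompt text and device are illustrative).
#
#   model = TransformersModel(device='cuda')
#   for chunk in model.stream_infer('Once upon a time'):
#       hidden = chunk['hidden_state']   # (1, hidden_size) last-layer state
#       if chunk['finish_reason']:       # 'stop' (EOS) or 'length' (budget hit)
#           break
#
#   batch = model.infer(['Prompt one', 'Prompt two'])
#   # each entry: {'finish_reason': ..., 'hidden_state': stacked per-token states}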