# soprano/backends/transformers.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

from .base import BaseModel


class TransformersModel(BaseModel):
    def __init__(self,
                 device='cuda',
                 model_path=None,
                 **kwargs):
        self.device = device
        # Use the local checkpoint if a path is provided, otherwise pull from HuggingFace.
        model_name_or_path = model_path if model_path else 'ekwek/Soprano-1.1-80M'
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
            device_map=device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Some causal LM tokenizers ship without a pad token; fall back to EOS
        # so that padded batching in infer() works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()

    def infer(self,
              prompts,
              top_p=0.95,
              temperature=0.3,
              repetition_penalty=1.2):
        inputs = self.tokenizer(
            prompts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=512,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_hidden_states=True,
            )
        res = []
        eos_token_id = self.model.config.eos_token_id
        for i in range(len(prompts)):
            seq = outputs.sequences[i]
            hidden_states = []
            # generate() returns one hidden-state tuple per decoding step.
            num_output_tokens = len(outputs.hidden_states)
            for j in range(num_output_tokens):
                # Map decoding step j to its token in the full sequence
                # (prompt + generated tokens).
                token = seq[j + seq.size(0) - num_output_tokens]
                if token != eos_token_id:
                    # Last layer, last position of step j, batch element i.
                    hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
            last_hidden_state = torch.stack(hidden_states).squeeze()
            finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
            res.append({
                'finish_reason': finish_reason,
                'hidden_state': last_hidden_state
            })
        return res

    def stream_infer(self,
                     prompt,
                     top_p=0.95,
                     temperature=0.3,
                     repetition_penalty=1.2):
        # Tokenize the input prompt.
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']
        # Prepare logits processors for sampling.
        logits_processor = LogitsProcessorList()
        if repetition_penalty != 1.0:
            logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        logits_warper = LogitsProcessorList()
        if temperature != 1.0:
            logits_warper.append(TemperatureLogitsWarper(temperature=temperature))
        if top_p < 1.0:
            logits_warper.append(TopPLogitsWarper(top_p=top_p))

        # Helper to sample the next token from processed logits.
        def get_next_token(logits, input_seq):
            scores = logits_processor(input_seq, logits)
            scores = logits_warper(input_seq, scores)
            probs = torch.nn.functional.softmax(scores, dim=-1)
            # Sample from the filtered distribution.
            return torch.multinomial(probs, num_samples=1)

        with torch.no_grad():
            # Initial forward pass over the full prompt.
            outputs = self.model(
                input_ids,
                use_cache=True,
                output_hidden_states=True
            )
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]
            # Keep the full sequence around: the repetition penalty needs it.
            generated_ids = input_ids
            # Sample the first token.
            next_token = get_next_token(next_token_logits, generated_ids)
            max_new_tokens = 512
            eos_token_id = self.model.config.eos_token_id
            for i in range(max_new_tokens):
                # Append the generated token to the sequence history.
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                # Forward pass for the single new token, reusing the KV cache.
                outputs = self.model(
                    next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True
                )
                # Update the cache and grab the new token's hidden state.
                past_key_values = outputs.past_key_values
                current_hidden_state = outputs.hidden_states[-1][:, -1, :]  # Last layer, last token.
                finish_reason = None
                if next_token.item() == eos_token_id:
                    finish_reason = 'stop'
                elif i == max_new_tokens - 1:
                    finish_reason = 'length'
                # Yield a result matching the lmdeploy backend's format.
                yield {
                    'finish_reason': finish_reason,
                    'hidden_state': current_hidden_state
                }
                if finish_reason:
                    break
                # Prepare for the next iteration.
                next_token_logits = outputs.logits[:, -1, :]
                next_token = get_next_token(next_token_logits, generated_ids)
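

# Minimal usage sketch exercising both inference paths. Assumptions: the
# default 'ekwek/Soprano-1.1-80M' checkpoint is downloadable and a CUDA
# device is available. Because of the relative import of BaseModel, run it
# as a module, e.g.: python -m soprano.backends.transformers
if __name__ == '__main__':
    model = TransformersModel(device='cuda')
    # Batched path: infer() returns one result dict per prompt.
    for result in model.infer(['Hello world.']):
        print(result['finish_reason'], result['hidden_state'].shape)
    # Streaming path: stream_infer() yields one dict per generated token;
    # finish_reason stays None until the final step.
    for step in model.stream_infer('Hello world.'):
        if step['finish_reason']:
            print('finished:', step['finish_reason'])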