public_chat / src /backend /inference.py
Daniel Machado Pedrozo
chore: increase max_new_tokens parameter to 4096 across chat model and inference functions for improved response generation
c81f16e
"""Inference utilities with streaming support."""
from typing import Iterator, Optional, Union, List
from transformers import Pipeline, TextIteratorStreamer
from threading import Thread
from .chat import _format_chat_prompt, Message
def _build_generation_kwargs(
max_new_tokens: int,
do_sample: bool,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
**extra_kwargs
) -> dict:
"""Constrói dicionário de kwargs para geração, incluindo apenas parâmetros fornecidos."""
kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
**extra_kwargs,
}
if temperature is not None:
kwargs["temperature"] = temperature
if top_p is not None:
kwargs["top_p"] = top_p
if top_k is not None:
kwargs["top_k"] = top_k
return kwargs
def generate_streaming(
pipeline: Pipeline,
prompt: Union[str, List[Message]],
max_new_tokens: int = 4096,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
do_sample: bool = True,
stop_sequences: Optional[list[str]] = None,
) -> Iterator[str]:
"""
Gera texto com streaming usando TextIteratorStreamer.
Args:
pipeline: Pipeline do transformers
prompt: Texto de entrada (str) ou lista de mensagens (List[Message])
max_new_tokens: Número máximo de tokens a gerar
temperature: Temperatura para sampling (opcional, usa padrão do modelo se None)
top_p: Nucleus sampling (opcional, usa padrão do modelo se None)
top_k: Top-k sampling (opcional, usa padrão do modelo se None)
do_sample: Se True, usa sampling; caso contrário, usa greedy decoding
stop_sequences: Lista de sequências para parar a geração
Yields:
Tokens gerados um por vez
"""
# Obtém o modelo e tokenizer do pipeline
model = pipeline.model
tokenizer = pipeline.tokenizer
# Formata prompt se for lista de mensagens
if isinstance(prompt, list):
formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True)
else:
formatted_prompt = prompt
# Tokeniza o prompt
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
# Cria streamer
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
# Configurações de geração (usa valores padrão do modelo se não especificados)
generation_kwargs = _build_generation_kwargs(
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
top_k=top_k,
streamer=streamer,
use_cache=True, # Usa cache de atenção para acelerar
)
generation_kwargs.update(inputs)
# Thread para geração
generation_thread = Thread(
target=model.generate,
kwargs=generation_kwargs,
)
generation_thread.start()
# Yield tokens conforme são gerados
for token in streamer:
if stop_sequences:
# Verifica se algum stop_sequence foi encontrado
for stop_seq in stop_sequences:
if stop_seq in token:
generation_thread.join(timeout=1.0)
return
yield token
generation_thread.join()
def generate_simple(
pipeline: Pipeline,
prompt: Union[str, List[Message]],
max_new_tokens: int = 4096,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
do_sample: bool = True,
num_return_sequences: int = 1,
) -> str:
"""
Gera texto sem streaming (mais simples, útil para testes).
Args:
pipeline: Pipeline do transformers
prompt: Texto de entrada (str) ou lista de mensagens (List[Message])
max_new_tokens: Número máximo de tokens a gerar
temperature: Temperatura para sampling (opcional, usa padrão do modelo se None)
top_p: Nucleus sampling (opcional, usa padrão do modelo se None)
top_k: Top-k sampling (opcional, usa padrão do modelo se None)
do_sample: Se True, usa sampling; caso contrário, usa greedy decoding
num_return_sequences: Número de sequências a retornar
Returns:
Texto gerado
"""
# Formata prompt se for lista de mensagens
tokenizer = pipeline.tokenizer
if isinstance(prompt, list):
formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True)
else:
formatted_prompt = prompt
# Prepara parâmetros do pipeline (usa valores padrão do modelo se não especificados)
pipeline_kwargs = _build_generation_kwargs(
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_return_sequences=num_return_sequences,
return_full_text=False,
)
outputs = pipeline(formatted_prompt, **pipeline_kwargs)
if num_return_sequences == 1:
return outputs[0]["generated_text"]
else:
return [output["generated_text"] for output in outputs]