"""Inference utilities with streaming support.""" from typing import Iterator, Optional, Union, List from transformers import Pipeline, TextIteratorStreamer from threading import Thread from .chat import _format_chat_prompt, Message def _build_generation_kwargs( max_new_tokens: int, do_sample: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, **extra_kwargs ) -> dict: """Constrói dicionário de kwargs para geração, incluindo apenas parâmetros fornecidos.""" kwargs = { "max_new_tokens": max_new_tokens, "do_sample": do_sample, **extra_kwargs, } if temperature is not None: kwargs["temperature"] = temperature if top_p is not None: kwargs["top_p"] = top_p if top_k is not None: kwargs["top_k"] = top_k return kwargs def generate_streaming( pipeline: Pipeline, prompt: Union[str, List[Message]], max_new_tokens: int = 4096, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, do_sample: bool = True, stop_sequences: Optional[list[str]] = None, ) -> Iterator[str]: """ Gera texto com streaming usando TextIteratorStreamer. Args: pipeline: Pipeline do transformers prompt: Texto de entrada (str) ou lista de mensagens (List[Message]) max_new_tokens: Número máximo de tokens a gerar temperature: Temperatura para sampling (opcional, usa padrão do modelo se None) top_p: Nucleus sampling (opcional, usa padrão do modelo se None) top_k: Top-k sampling (opcional, usa padrão do modelo se None) do_sample: Se True, usa sampling; caso contrário, usa greedy decoding stop_sequences: Lista de sequências para parar a geração Yields: Tokens gerados um por vez """ # Obtém o modelo e tokenizer do pipeline model = pipeline.model tokenizer = pipeline.tokenizer # Formata prompt se for lista de mensagens if isinstance(prompt, list): formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True) else: formatted_prompt = prompt # Tokeniza o prompt inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device) # Cria streamer streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, ) # Configurações de geração (usa valores padrão do modelo se não especificados) generation_kwargs = _build_generation_kwargs( max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k, streamer=streamer, use_cache=True, # Usa cache de atenção para acelerar ) generation_kwargs.update(inputs) # Thread para geração generation_thread = Thread( target=model.generate, kwargs=generation_kwargs, ) generation_thread.start() # Yield tokens conforme são gerados for token in streamer: if stop_sequences: # Verifica se algum stop_sequence foi encontrado for stop_seq in stop_sequences: if stop_seq in token: generation_thread.join(timeout=1.0) return yield token generation_thread.join() def generate_simple( pipeline: Pipeline, prompt: Union[str, List[Message]], max_new_tokens: int = 4096, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, do_sample: bool = True, num_return_sequences: int = 1, ) -> str: """ Gera texto sem streaming (mais simples, útil para testes). Args: pipeline: Pipeline do transformers prompt: Texto de entrada (str) ou lista de mensagens (List[Message]) max_new_tokens: Número máximo de tokens a gerar temperature: Temperatura para sampling (opcional, usa padrão do modelo se None) top_p: Nucleus sampling (opcional, usa padrão do modelo se None) top_k: Top-k sampling (opcional, usa padrão do modelo se None) do_sample: Se True, usa sampling; caso contrário, usa greedy decoding num_return_sequences: Número de sequências a retornar Returns: Texto gerado """ # Formata prompt se for lista de mensagens tokenizer = pipeline.tokenizer if isinstance(prompt, list): formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True) else: formatted_prompt = prompt # Prepara parâmetros do pipeline (usa valores padrão do modelo se não especificados) pipeline_kwargs = _build_generation_kwargs( max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, top_p=top_p, top_k=top_k, num_return_sequences=num_return_sequences, return_full_text=False, ) outputs = pipeline(formatted_prompt, **pipeline_kwargs) if num_return_sequences == 1: return outputs[0]["generated_text"] else: return [output["generated_text"] for output in outputs]