Spaces:
Sleeping
Sleeping
Daniel Machado Pedrozo
chore: increase max_new_tokens parameter to 4096 across chat model and inference functions for improved response generation
c81f16e | """Inference utilities with streaming support.""" | |
| from typing import Iterator, Optional, Union, List | |
| from transformers import Pipeline, TextIteratorStreamer | |
| from threading import Thread | |
| from .chat import _format_chat_prompt, Message | |
| def _build_generation_kwargs( | |
| max_new_tokens: int, | |
| do_sample: bool, | |
| temperature: Optional[float] = None, | |
| top_p: Optional[float] = None, | |
| top_k: Optional[int] = None, | |
| **extra_kwargs | |
| ) -> dict: | |
| """Constrói dicionário de kwargs para geração, incluindo apenas parâmetros fornecidos.""" | |
| kwargs = { | |
| "max_new_tokens": max_new_tokens, | |
| "do_sample": do_sample, | |
| **extra_kwargs, | |
| } | |
| if temperature is not None: | |
| kwargs["temperature"] = temperature | |
| if top_p is not None: | |
| kwargs["top_p"] = top_p | |
| if top_k is not None: | |
| kwargs["top_k"] = top_k | |
| return kwargs | |
| def generate_streaming( | |
| pipeline: Pipeline, | |
| prompt: Union[str, List[Message]], | |
| max_new_tokens: int = 4096, | |
| temperature: Optional[float] = None, | |
| top_p: Optional[float] = None, | |
| top_k: Optional[int] = None, | |
| do_sample: bool = True, | |
| stop_sequences: Optional[list[str]] = None, | |
| ) -> Iterator[str]: | |
| """ | |
| Gera texto com streaming usando TextIteratorStreamer. | |
| Args: | |
| pipeline: Pipeline do transformers | |
| prompt: Texto de entrada (str) ou lista de mensagens (List[Message]) | |
| max_new_tokens: Número máximo de tokens a gerar | |
| temperature: Temperatura para sampling (opcional, usa padrão do modelo se None) | |
| top_p: Nucleus sampling (opcional, usa padrão do modelo se None) | |
| top_k: Top-k sampling (opcional, usa padrão do modelo se None) | |
| do_sample: Se True, usa sampling; caso contrário, usa greedy decoding | |
| stop_sequences: Lista de sequências para parar a geração | |
| Yields: | |
| Tokens gerados um por vez | |
| """ | |
| # Obtém o modelo e tokenizer do pipeline | |
| model = pipeline.model | |
| tokenizer = pipeline.tokenizer | |
| # Formata prompt se for lista de mensagens | |
| if isinstance(prompt, list): | |
| formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True) | |
| else: | |
| formatted_prompt = prompt | |
| # Tokeniza o prompt | |
| inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device) | |
| # Cria streamer | |
| streamer = TextIteratorStreamer( | |
| tokenizer, | |
| skip_prompt=True, | |
| skip_special_tokens=True, | |
| ) | |
| # Configurações de geração (usa valores padrão do modelo se não especificados) | |
| generation_kwargs = _build_generation_kwargs( | |
| max_new_tokens=max_new_tokens, | |
| do_sample=do_sample, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| streamer=streamer, | |
| use_cache=True, # Usa cache de atenção para acelerar | |
| ) | |
| generation_kwargs.update(inputs) | |
| # Thread para geração | |
| generation_thread = Thread( | |
| target=model.generate, | |
| kwargs=generation_kwargs, | |
| ) | |
| generation_thread.start() | |
| # Yield tokens conforme são gerados | |
| for token in streamer: | |
| if stop_sequences: | |
| # Verifica se algum stop_sequence foi encontrado | |
| for stop_seq in stop_sequences: | |
| if stop_seq in token: | |
| generation_thread.join(timeout=1.0) | |
| return | |
| yield token | |
| generation_thread.join() | |
| def generate_simple( | |
| pipeline: Pipeline, | |
| prompt: Union[str, List[Message]], | |
| max_new_tokens: int = 4096, | |
| temperature: Optional[float] = None, | |
| top_p: Optional[float] = None, | |
| top_k: Optional[int] = None, | |
| do_sample: bool = True, | |
| num_return_sequences: int = 1, | |
| ) -> str: | |
| """ | |
| Gera texto sem streaming (mais simples, útil para testes). | |
| Args: | |
| pipeline: Pipeline do transformers | |
| prompt: Texto de entrada (str) ou lista de mensagens (List[Message]) | |
| max_new_tokens: Número máximo de tokens a gerar | |
| temperature: Temperatura para sampling (opcional, usa padrão do modelo se None) | |
| top_p: Nucleus sampling (opcional, usa padrão do modelo se None) | |
| top_k: Top-k sampling (opcional, usa padrão do modelo se None) | |
| do_sample: Se True, usa sampling; caso contrário, usa greedy decoding | |
| num_return_sequences: Número de sequências a retornar | |
| Returns: | |
| Texto gerado | |
| """ | |
| # Formata prompt se for lista de mensagens | |
| tokenizer = pipeline.tokenizer | |
| if isinstance(prompt, list): | |
| formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True) | |
| else: | |
| formatted_prompt = prompt | |
| # Prepara parâmetros do pipeline (usa valores padrão do modelo se não especificados) | |
| pipeline_kwargs = _build_generation_kwargs( | |
| max_new_tokens=max_new_tokens, | |
| do_sample=do_sample, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| num_return_sequences=num_return_sequences, | |
| return_full_text=False, | |
| ) | |
| outputs = pipeline(formatted_prompt, **pipeline_kwargs) | |
| if num_return_sequences == 1: | |
| return outputs[0]["generated_text"] | |
| else: | |
| return [output["generated_text"] for output in outputs] | |