Spaces:
Sleeping
Sleeping
Daniel Machado Pedrozo
chore: increase max_new_tokens parameter to 4096 across chat model and inference functions for improved response generation
c81f16e
| """ChatModel class that encapsulates pipeline + conversation history.""" | |
| from typing import Iterator, Optional, Union, List | |
| from transformers import Pipeline | |
| from .chat import Conversation, _format_chat_prompt, Message | |
| from .inference import generate_streaming as _generate_streaming, generate_simple as _generate_simple | |
| class ChatModel: | |
| """ | |
| Encapsula modelo + histórico de conversa para facilitar uso. | |
| Exemplo: | |
| model = ChatModel(pipeline, tokenizer) | |
| model.add_user_message("Olá") | |
| response = model.generate_streaming() | |
| model.add_assistant_message(response) | |
| """ | |
| def __init__( | |
| self, | |
| pipeline: Pipeline, | |
| system_prompt: Optional[str] = None, | |
| ): | |
| """ | |
| Inicializa ChatModel. | |
| Args: | |
| pipeline: Pipeline do transformers (deve ter model e tokenizer) | |
| system_prompt: Prompt do sistema (opcional) | |
| """ | |
| self.pipeline = pipeline | |
| self.tokenizer = pipeline.tokenizer | |
| self.conversation = Conversation(system_prompt=system_prompt) | |
| def messages(self) -> List[Message]: | |
| """Retorna lista de mensagens do histórico.""" | |
| return self.conversation.messages | |
| def messages_dict(self) -> List[dict]: | |
| """Retorna mensagens como lista de dicionários (compatível com transformers).""" | |
| return self.conversation.model_dump_messages() | |
| def add_user_message(self, content: str) -> None: | |
| """Adiciona mensagem do usuário ao histórico.""" | |
| self.conversation.add_user_message(content) | |
| def add_assistant_message(self, content: str) -> None: | |
| """Adiciona mensagem do assistente ao histórico.""" | |
| self.conversation.add_assistant_message(content) | |
| def set_system_prompt(self, content: str) -> None: | |
| """Define ou atualiza o prompt do sistema.""" | |
| self.conversation.set_system_prompt(content) | |
| def clear_history(self, keep_system: bool = True) -> None: | |
| """ | |
| Limpa o histórico de conversa. | |
| Args: | |
| keep_system: Se True, mantém mensagens do sistema | |
| """ | |
| self.conversation.clear(keep_system=keep_system) | |
| def get_formatted_prompt(self, add_generation_prompt: bool = True) -> str: | |
| """ | |
| Retorna prompt formatado com histórico completo. | |
| Args: | |
| add_generation_prompt: Se True, adiciona prompt de geração | |
| Returns: | |
| String formatada pronta para o modelo | |
| """ | |
| return _format_chat_prompt( | |
| self.tokenizer, | |
| self.conversation.messages, | |
| add_generation_prompt=add_generation_prompt, | |
| ) | |
| def generate_streaming( | |
| self, | |
| max_new_tokens: int = 4096, | |
| temperature: Optional[float] = None, | |
| top_p: Optional[float] = None, | |
| top_k: Optional[int] = None, | |
| do_sample: bool = True, | |
| stop_sequences: Optional[list[str]] = None, | |
| ) -> Iterator[str]: | |
| """ | |
| Gera resposta com streaming usando o histórico completo. | |
| Args: | |
| max_new_tokens: Número máximo de tokens a gerar | |
| temperature: Temperatura para sampling (opcional) | |
| top_p: Nucleus sampling (opcional) | |
| top_k: Top-k sampling (opcional) | |
| do_sample: Se True, usa sampling | |
| stop_sequences: Lista de sequências para parar | |
| Yields: | |
| Tokens gerados um por vez | |
| """ | |
| return _generate_streaming( | |
| pipeline=self.pipeline, | |
| prompt=self.conversation.messages, # List[Message] funciona com _format_chat_prompt | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| do_sample=do_sample, | |
| stop_sequences=stop_sequences, | |
| ) | |
| def generate( | |
| self, | |
| max_new_tokens: int = 4096, | |
| temperature: Optional[float] = None, | |
| top_p: Optional[float] = None, | |
| top_k: Optional[int] = None, | |
| do_sample: bool = True, | |
| ) -> str: | |
| """ | |
| Gera resposta completa usando o histórico completo. | |
| Args: | |
| max_new_tokens: Número máximo de tokens a gerar | |
| temperature: Temperatura para sampling (opcional) | |
| top_p: Nucleus sampling (opcional) | |
| top_k: Top-k sampling (opcional) | |
| do_sample: Se True, usa sampling | |
| Returns: | |
| Texto gerado completo | |
| """ | |
| return _generate_simple( | |
| pipeline=self.pipeline, | |
| prompt=self.conversation.messages, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k, | |
| do_sample=do_sample, | |
| ) | |
| def chat( | |
| self, | |
| user_message: str, | |
| max_new_tokens: int = 4096, | |
| temperature: Optional[float] = None, | |
| streaming: bool = False, | |
| ) -> Union[str, Iterator[str]]: | |
| """ | |
| Método conveniente para chat completo (adiciona mensagem + gera + adiciona resposta). | |
| Args: | |
| user_message: Mensagem do usuário | |
| max_new_tokens: Número máximo de tokens a gerar | |
| temperature: Temperatura para sampling (opcional) | |
| streaming: Se True, retorna iterator; se False, retorna string completa | |
| Returns: | |
| Resposta do modelo (string ou iterator) | |
| """ | |
| # Adiciona mensagem do usuário | |
| self.add_user_message(user_message) | |
| # Gera resposta | |
| if streaming: | |
| return self.generate_streaming( | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| ) | |
| else: | |
| response = self.generate( | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| ) | |
| # Adiciona resposta ao histórico | |
| self.add_assistant_message(response) | |
| return response | |
| def __repr__(self) -> str: | |
| """Representação string do modelo.""" | |
| return f"ChatModel({len(self.conversation)} messages)" | |