"""ChatModel class that encapsulates pipeline + conversation history.""" from typing import Iterator, Optional, Union, List from transformers import Pipeline from .chat import Conversation, _format_chat_prompt, Message from .inference import generate_streaming as _generate_streaming, generate_simple as _generate_simple class ChatModel: """ Encapsula modelo + histórico de conversa para facilitar uso. Exemplo: model = ChatModel(pipeline, tokenizer) model.add_user_message("Olá") response = model.generate_streaming() model.add_assistant_message(response) """ def __init__( self, pipeline: Pipeline, system_prompt: Optional[str] = None, ): """ Inicializa ChatModel. Args: pipeline: Pipeline do transformers (deve ter model e tokenizer) system_prompt: Prompt do sistema (opcional) """ self.pipeline = pipeline self.tokenizer = pipeline.tokenizer self.conversation = Conversation(system_prompt=system_prompt) @property def messages(self) -> List[Message]: """Retorna lista de mensagens do histórico.""" return self.conversation.messages @property def messages_dict(self) -> List[dict]: """Retorna mensagens como lista de dicionários (compatível com transformers).""" return self.conversation.model_dump_messages() def add_user_message(self, content: str) -> None: """Adiciona mensagem do usuário ao histórico.""" self.conversation.add_user_message(content) def add_assistant_message(self, content: str) -> None: """Adiciona mensagem do assistente ao histórico.""" self.conversation.add_assistant_message(content) def set_system_prompt(self, content: str) -> None: """Define ou atualiza o prompt do sistema.""" self.conversation.set_system_prompt(content) def clear_history(self, keep_system: bool = True) -> None: """ Limpa o histórico de conversa. Args: keep_system: Se True, mantém mensagens do sistema """ self.conversation.clear(keep_system=keep_system) def get_formatted_prompt(self, add_generation_prompt: bool = True) -> str: """ Retorna prompt formatado com histórico completo. Args: add_generation_prompt: Se True, adiciona prompt de geração Returns: String formatada pronta para o modelo """ return _format_chat_prompt( self.tokenizer, self.conversation.messages, add_generation_prompt=add_generation_prompt, ) def generate_streaming( self, max_new_tokens: int = 4096, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, do_sample: bool = True, stop_sequences: Optional[list[str]] = None, ) -> Iterator[str]: """ Gera resposta com streaming usando o histórico completo. Args: max_new_tokens: Número máximo de tokens a gerar temperature: Temperatura para sampling (opcional) top_p: Nucleus sampling (opcional) top_k: Top-k sampling (opcional) do_sample: Se True, usa sampling stop_sequences: Lista de sequências para parar Yields: Tokens gerados um por vez """ return _generate_streaming( pipeline=self.pipeline, prompt=self.conversation.messages, # List[Message] funciona com _format_chat_prompt max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, do_sample=do_sample, stop_sequences=stop_sequences, ) def generate( self, max_new_tokens: int = 4096, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, do_sample: bool = True, ) -> str: """ Gera resposta completa usando o histórico completo. Args: max_new_tokens: Número máximo de tokens a gerar temperature: Temperatura para sampling (opcional) top_p: Nucleus sampling (opcional) top_k: Top-k sampling (opcional) do_sample: Se True, usa sampling Returns: Texto gerado completo """ return _generate_simple( pipeline=self.pipeline, prompt=self.conversation.messages, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, do_sample=do_sample, ) def chat( self, user_message: str, max_new_tokens: int = 4096, temperature: Optional[float] = None, streaming: bool = False, ) -> Union[str, Iterator[str]]: """ Método conveniente para chat completo (adiciona mensagem + gera + adiciona resposta). Args: user_message: Mensagem do usuário max_new_tokens: Número máximo de tokens a gerar temperature: Temperatura para sampling (opcional) streaming: Se True, retorna iterator; se False, retorna string completa Returns: Resposta do modelo (string ou iterator) """ # Adiciona mensagem do usuário self.add_user_message(user_message) # Gera resposta if streaming: return self.generate_streaming( max_new_tokens=max_new_tokens, temperature=temperature, ) else: response = self.generate( max_new_tokens=max_new_tokens, temperature=temperature, ) # Adiciona resposta ao histórico self.add_assistant_message(response) return response def __repr__(self) -> str: """Representação string do modelo.""" return f"ChatModel({len(self.conversation)} messages)"