public_chat / src /backend /chat_model.py
Daniel Machado Pedrozo
chore: increase max_new_tokens parameter to 4096 across chat model and inference functions for improved response generation
c81f16e
"""ChatModel class that encapsulates pipeline + conversation history."""
from typing import Iterator, Optional, Union, List
from transformers import Pipeline
from .chat import Conversation, _format_chat_prompt, Message
from .inference import generate_streaming as _generate_streaming, generate_simple as _generate_simple
class ChatModel:
"""
Encapsula modelo + histórico de conversa para facilitar uso.
Exemplo:
model = ChatModel(pipeline, tokenizer)
model.add_user_message("Olá")
response = model.generate_streaming()
model.add_assistant_message(response)
"""
def __init__(
self,
pipeline: Pipeline,
system_prompt: Optional[str] = None,
):
"""
Inicializa ChatModel.
Args:
pipeline: Pipeline do transformers (deve ter model e tokenizer)
system_prompt: Prompt do sistema (opcional)
"""
self.pipeline = pipeline
self.tokenizer = pipeline.tokenizer
self.conversation = Conversation(system_prompt=system_prompt)
@property
def messages(self) -> List[Message]:
"""Retorna lista de mensagens do histórico."""
return self.conversation.messages
@property
def messages_dict(self) -> List[dict]:
"""Retorna mensagens como lista de dicionários (compatível com transformers)."""
return self.conversation.model_dump_messages()
def add_user_message(self, content: str) -> None:
"""Adiciona mensagem do usuário ao histórico."""
self.conversation.add_user_message(content)
def add_assistant_message(self, content: str) -> None:
"""Adiciona mensagem do assistente ao histórico."""
self.conversation.add_assistant_message(content)
def set_system_prompt(self, content: str) -> None:
"""Define ou atualiza o prompt do sistema."""
self.conversation.set_system_prompt(content)
def clear_history(self, keep_system: bool = True) -> None:
"""
Limpa o histórico de conversa.
Args:
keep_system: Se True, mantém mensagens do sistema
"""
self.conversation.clear(keep_system=keep_system)
def get_formatted_prompt(self, add_generation_prompt: bool = True) -> str:
"""
Retorna prompt formatado com histórico completo.
Args:
add_generation_prompt: Se True, adiciona prompt de geração
Returns:
String formatada pronta para o modelo
"""
return _format_chat_prompt(
self.tokenizer,
self.conversation.messages,
add_generation_prompt=add_generation_prompt,
)
def generate_streaming(
self,
max_new_tokens: int = 4096,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
do_sample: bool = True,
stop_sequences: Optional[list[str]] = None,
) -> Iterator[str]:
"""
Gera resposta com streaming usando o histórico completo.
Args:
max_new_tokens: Número máximo de tokens a gerar
temperature: Temperatura para sampling (opcional)
top_p: Nucleus sampling (opcional)
top_k: Top-k sampling (opcional)
do_sample: Se True, usa sampling
stop_sequences: Lista de sequências para parar
Yields:
Tokens gerados um por vez
"""
return _generate_streaming(
pipeline=self.pipeline,
prompt=self.conversation.messages, # List[Message] funciona com _format_chat_prompt
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
stop_sequences=stop_sequences,
)
def generate(
self,
max_new_tokens: int = 4096,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
do_sample: bool = True,
) -> str:
"""
Gera resposta completa usando o histórico completo.
Args:
max_new_tokens: Número máximo de tokens a gerar
temperature: Temperatura para sampling (opcional)
top_p: Nucleus sampling (opcional)
top_k: Top-k sampling (opcional)
do_sample: Se True, usa sampling
Returns:
Texto gerado completo
"""
return _generate_simple(
pipeline=self.pipeline,
prompt=self.conversation.messages,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
)
def chat(
self,
user_message: str,
max_new_tokens: int = 4096,
temperature: Optional[float] = None,
streaming: bool = False,
) -> Union[str, Iterator[str]]:
"""
Método conveniente para chat completo (adiciona mensagem + gera + adiciona resposta).
Args:
user_message: Mensagem do usuário
max_new_tokens: Número máximo de tokens a gerar
temperature: Temperatura para sampling (opcional)
streaming: Se True, retorna iterator; se False, retorna string completa
Returns:
Resposta do modelo (string ou iterator)
"""
# Adiciona mensagem do usuário
self.add_user_message(user_message)
# Gera resposta
if streaming:
return self.generate_streaming(
max_new_tokens=max_new_tokens,
temperature=temperature,
)
else:
response = self.generate(
max_new_tokens=max_new_tokens,
temperature=temperature,
)
# Adiciona resposta ao histórico
self.add_assistant_message(response)
return response
def __repr__(self) -> str:
"""Representação string do modelo."""
return f"ChatModel({len(self.conversation)} messages)"