Spaces:
Sleeping
Sleeping
Daniel Machado Pedrozo commited on
Commit ·
91c131d
1
Parent(s): 3d96e49
Implement initial project structure with Dockerfile, requirements, and Streamlit app. Added model loading and inference utilities, along with chat management features. Updated entry point and added new dependencies.
Browse files- Dockerfile +2 -1
- requirements.txt +4 -1
- src/app.py +201 -0
- src/backend/__init__.py +20 -0
- src/backend/chat.py +208 -0
- src/backend/chat_model.py +188 -0
- src/backend/inference.py +162 -0
- src/backend/model_loader.py +138 -0
- src/config.py +71 -0
- src/streamlit_app.py +0 -40
Dockerfile
CHANGED
|
@@ -11,10 +11,11 @@ RUN apt-get update && apt-get install -y \
|
|
| 11 |
COPY requirements.txt ./
|
| 12 |
COPY src/ ./src/
|
| 13 |
|
|
|
|
| 14 |
RUN pip3 install -r requirements.txt
|
| 15 |
|
| 16 |
EXPOSE 8501
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "src/
|
|
|
|
| 11 |
COPY requirements.txt ./
|
| 12 |
COPY src/ ./src/
|
| 13 |
|
| 14 |
+
RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
| 15 |
RUN pip3 install -r requirements.txt
|
| 16 |
|
| 17 |
EXPOSE 8501
|
| 18 |
|
| 19 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 20 |
|
| 21 |
+
ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
requirements.txt
CHANGED
|
@@ -1,3 +1,6 @@
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
-
streamlit
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
+
streamlit
|
| 4 |
+
dotenv
|
| 5 |
+
transformers
|
| 6 |
+
pydantic
|
src/app.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Interface de chat com modelo de linguagem."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import base64
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from backend import load_model, ChatModel
|
| 7 |
+
from config import get_model_options, GATED_MODELS
|
| 8 |
+
|
| 9 |
+
st.set_page_config(page_title="Small LLM - Chat", layout="wide")
|
| 10 |
+
|
| 11 |
+
# Caminho da logo (relativo à raiz do projeto)
|
| 12 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 13 |
+
LOGO_PATH = PROJECT_ROOT / "positivo-logo.png"
|
| 14 |
+
|
| 15 |
+
# Header com logo e título usando HTML/CSS para melhor controle
|
| 16 |
+
with open(LOGO_PATH, "rb") as img_file:
|
| 17 |
+
img_base64 = base64.b64encode(img_file.read()).decode()
|
| 18 |
+
|
| 19 |
+
st.markdown(f"""
|
| 20 |
+
<style>
|
| 21 |
+
.logo-header {{
|
| 22 |
+
display: flex;
|
| 23 |
+
align-items: center;
|
| 24 |
+
gap: 20px;
|
| 25 |
+
margin-bottom: 0.5rem;
|
| 26 |
+
}}
|
| 27 |
+
.logo-header img {{
|
| 28 |
+
width: 90px;
|
| 29 |
+
height: 90px;
|
| 30 |
+
object-fit: contain;
|
| 31 |
+
flex-shrink: 0;
|
| 32 |
+
}}
|
| 33 |
+
</style>
|
| 34 |
+
<div class="logo-header">
|
| 35 |
+
<img src="data:image/png;base64,{img_base64}" />
|
| 36 |
+
<h1 style="margin: 0; padding: 0; display: inline-block;">Small LLM - Chat</h1>
|
| 37 |
+
</div>
|
| 38 |
+
""", unsafe_allow_html=True)
|
| 39 |
+
|
| 40 |
+
# ============================================================================
|
| 41 |
+
# FUNÇÕES AUXILIARES
|
| 42 |
+
# ============================================================================
|
| 43 |
+
|
| 44 |
+
def handle_model_load_error(model_name: str, error_msg: str):
|
| 45 |
+
"""Trata erros de carregamento de modelo, especialmente modelos gated."""
|
| 46 |
+
is_gated_error = (
|
| 47 |
+
model_name in GATED_MODELS and (
|
| 48 |
+
"401" in error_msg or
|
| 49 |
+
"gated" in error_msg.lower() or
|
| 50 |
+
"access" in error_msg.lower() or
|
| 51 |
+
"restricted" in error_msg.lower()
|
| 52 |
+
)
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
if is_gated_error:
|
| 56 |
+
st.error(
|
| 57 |
+
f"⚠️ **Modelo gated detectado!**\n\n"
|
| 58 |
+
f"O modelo `{model_name}` requer autenticação.\n\n"
|
| 59 |
+
f"**No Hugging Face Spaces:**\n"
|
| 60 |
+
f"1. Vá em Settings → Repository secrets\n"
|
| 61 |
+
f"2. Adicione `HF_TOKEN` com seu token do Hugging Face\n"
|
| 62 |
+
f"3. Aceite os termos em: https://huggingface.co/{model_name}"
|
| 63 |
+
)
|
| 64 |
+
else:
|
| 65 |
+
st.error(f"❌ Erro ao carregar modelo: {error_msg}")
|
| 66 |
+
|
| 67 |
+
# ============================================================================
|
| 68 |
+
# INTERFACE DE CHAT
|
| 69 |
+
# ============================================================================
|
| 70 |
+
|
| 71 |
+
# Sidebar para configurações
|
| 72 |
+
with st.sidebar:
|
| 73 |
+
st.header("⚙️ Configurações")
|
| 74 |
+
|
| 75 |
+
model_options = get_model_options()
|
| 76 |
+
selected_label = st.selectbox(
|
| 77 |
+
"Selecione um Modelo",
|
| 78 |
+
options=[opt[0] for opt in model_options],
|
| 79 |
+
index=0,
|
| 80 |
+
help="Modelos pré-selecionados para teste"
|
| 81 |
+
)
|
| 82 |
+
selected_model = next(opt[1] for opt in model_options if opt[0] == selected_label)
|
| 83 |
+
|
| 84 |
+
use_custom = st.checkbox("Usar modelo customizado")
|
| 85 |
+
|
| 86 |
+
if use_custom:
|
| 87 |
+
model_name = st.text_input(
|
| 88 |
+
"Nome do Modelo (Hugging Face)",
|
| 89 |
+
value="gpt2",
|
| 90 |
+
help="Digite o nome completo do modelo no Hugging Face"
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
model_name = selected_model
|
| 94 |
+
|
| 95 |
+
use_quantization = st.checkbox(
|
| 96 |
+
"Usar Quantização (8-bit)",
|
| 97 |
+
value=False,
|
| 98 |
+
help="Reduz uso de memória, mas pode ser mais lento"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if st.button("🔄 Carregar Modelo", type="primary"):
|
| 102 |
+
with st.spinner(f"Carregando {model_name}..."):
|
| 103 |
+
try:
|
| 104 |
+
pipeline, model_info = load_model(
|
| 105 |
+
model_name,
|
| 106 |
+
load_in_8bit=use_quantization
|
| 107 |
+
)
|
| 108 |
+
chat_model = ChatModel(pipeline)
|
| 109 |
+
st.session_state.chat_model = chat_model
|
| 110 |
+
st.session_state.model_info = model_info
|
| 111 |
+
st.session_state.model_name = model_name
|
| 112 |
+
st.success("✅ Modelo carregado!")
|
| 113 |
+
if "messages" in st.session_state:
|
| 114 |
+
del st.session_state.messages
|
| 115 |
+
except Exception as e:
|
| 116 |
+
handle_model_load_error(model_name, str(e))
|
| 117 |
+
|
| 118 |
+
if "model_info" in st.session_state:
|
| 119 |
+
st.divider()
|
| 120 |
+
st.subheader("📊 Informações do Modelo")
|
| 121 |
+
st.json(st.session_state.model_info)
|
| 122 |
+
|
| 123 |
+
if "chat_model" in st.session_state:
|
| 124 |
+
chat_model = st.session_state.chat_model
|
| 125 |
+
st.divider()
|
| 126 |
+
st.subheader("💭 Estatísticas da Conversa")
|
| 127 |
+
st.metric("Mensagens", len(chat_model.conversation))
|
| 128 |
+
|
| 129 |
+
if st.button("🗑️ Limpar Histórico", use_container_width=True):
|
| 130 |
+
chat_model.clear_history()
|
| 131 |
+
if "messages" in st.session_state:
|
| 132 |
+
del st.session_state.messages
|
| 133 |
+
st.rerun()
|
| 134 |
+
|
| 135 |
+
# Área principal - Chat
|
| 136 |
+
if "chat_model" not in st.session_state:
|
| 137 |
+
st.info("👈 Use a sidebar para carregar um modelo primeiro.")
|
| 138 |
+
st.markdown("""
|
| 139 |
+
### Modelos disponíveis:
|
| 140 |
+
|
| 141 |
+
**Google Gemma:**
|
| 142 |
+
- `google/gemma-3-4b-it` - 4 bilhões de parâmetros
|
| 143 |
+
- `google/gemma-3-1b-it` - 1 bilhão de parâmetros
|
| 144 |
+
- `google/gemma-3-270m-it` - 270 milhões de parâmetros
|
| 145 |
+
|
| 146 |
+
**Qwen:**
|
| 147 |
+
- `Qwen/Qwen3-0.6B` - 600 milhões de parâmetros
|
| 148 |
+
- `Qwen/Qwen2.5-0.5B-Instruct` - 500 milhões (instruct)
|
| 149 |
+
- `Qwen/Qwen2.5-0.5B` - 500 milhões
|
| 150 |
+
|
| 151 |
+
**Facebook:**
|
| 152 |
+
- `facebook/MobileLLM-R1-950M` - 950 milhões de parâmetros
|
| 153 |
+
""")
|
| 154 |
+
else:
|
| 155 |
+
chat_model = st.session_state.chat_model
|
| 156 |
+
|
| 157 |
+
if "messages" not in st.session_state:
|
| 158 |
+
st.session_state.messages = []
|
| 159 |
+
|
| 160 |
+
if len(chat_model.conversation.messages) != len(st.session_state.messages):
|
| 161 |
+
st.session_state.messages = [
|
| 162 |
+
{"role": msg.role, "content": msg.content}
|
| 163 |
+
for msg in chat_model.conversation.messages
|
| 164 |
+
]
|
| 165 |
+
|
| 166 |
+
chat_container = st.container()
|
| 167 |
+
|
| 168 |
+
with chat_container:
|
| 169 |
+
for message in st.session_state.messages:
|
| 170 |
+
role = message["role"]
|
| 171 |
+
content = message["content"]
|
| 172 |
+
|
| 173 |
+
if role == "system":
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
with st.chat_message(role):
|
| 177 |
+
st.markdown(content)
|
| 178 |
+
|
| 179 |
+
if user_input := st.chat_input("Digite sua mensagem..."):
|
| 180 |
+
chat_model.add_user_message(user_input)
|
| 181 |
+
st.session_state.messages.append({"role": "user", "content": user_input})
|
| 182 |
+
|
| 183 |
+
with st.chat_message("user"):
|
| 184 |
+
st.markdown(user_input)
|
| 185 |
+
|
| 186 |
+
with st.chat_message("assistant"):
|
| 187 |
+
response_placeholder = st.empty()
|
| 188 |
+
full_response = ""
|
| 189 |
+
|
| 190 |
+
try:
|
| 191 |
+
for token in chat_model.generate_streaming(max_new_tokens=512):
|
| 192 |
+
full_response += token
|
| 193 |
+
response_placeholder.markdown(full_response)
|
| 194 |
+
|
| 195 |
+
chat_model.add_assistant_message(full_response)
|
| 196 |
+
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
| 197 |
+
|
| 198 |
+
except Exception as e:
|
| 199 |
+
error_msg = f"Erro na geração: {str(e)}"
|
| 200 |
+
st.error(error_msg)
|
| 201 |
+
st.session_state.messages.append({"role": "assistant", "content": error_msg})
|
src/backend/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Backend module for LLM model loading and inference."""
|
| 2 |
+
|
| 3 |
+
from .model_loader import load_model
|
| 4 |
+
from .chat import Conversation, Message
|
| 5 |
+
from .chat_model import ChatModel
|
| 6 |
+
from .inference import generate_streaming, generate_simple
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
# Model loading
|
| 10 |
+
"load_model",
|
| 11 |
+
# OOP classes (recomendado)
|
| 12 |
+
"Conversation",
|
| 13 |
+
"ChatModel",
|
| 14 |
+
# Functions (compatibilidade)
|
| 15 |
+
"generate_streaming",
|
| 16 |
+
"generate_simple",
|
| 17 |
+
# Types
|
| 18 |
+
"Message",
|
| 19 |
+
]
|
| 20 |
+
|
src/backend/chat.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chat utilities for managing conversation history with chat templates."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Literal
|
| 4 |
+
from pydantic import BaseModel, Field, field_validator
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Message(BaseModel):
|
| 9 |
+
"""
|
| 10 |
+
Mensagem de chat no formato compatível OpenAI.
|
| 11 |
+
|
| 12 |
+
Exemplo:
|
| 13 |
+
msg = Message(role="user", content="Olá!")
|
| 14 |
+
msg_dict = msg.model_dump() # {"role": "user", "content": "Olá!"}
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
role: Literal["user", "assistant", "system"] = Field(
|
| 18 |
+
...,
|
| 19 |
+
description="Role da mensagem: user, assistant ou system"
|
| 20 |
+
)
|
| 21 |
+
content: str = Field(
|
| 22 |
+
...,
|
| 23 |
+
min_length=1,
|
| 24 |
+
description="Conteúdo da mensagem"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
@field_validator("content")
|
| 28 |
+
@classmethod
|
| 29 |
+
def validate_content(cls, v: str) -> str:
|
| 30 |
+
"""Valida que o conteúdo não está vazio."""
|
| 31 |
+
if not v.strip():
|
| 32 |
+
raise ValueError("Content não pode estar vazio")
|
| 33 |
+
return v
|
| 34 |
+
|
| 35 |
+
def model_dump_dict(self) -> dict:
|
| 36 |
+
"""Retorna como dicionário (compatível com transformers)."""
|
| 37 |
+
return {"role": self.role, "content": self.content}
|
| 38 |
+
|
| 39 |
+
class Config:
|
| 40 |
+
"""Configuração do Pydantic."""
|
| 41 |
+
json_schema_extra = {
|
| 42 |
+
"example": {
|
| 43 |
+
"role": "user",
|
| 44 |
+
"content": "Olá! Como você está?"
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _format_chat_prompt(
|
| 50 |
+
tokenizer: PreTrainedTokenizer,
|
| 51 |
+
messages: List[Message],
|
| 52 |
+
add_generation_prompt: bool = True,
|
| 53 |
+
) -> str:
|
| 54 |
+
"""
|
| 55 |
+
Formata histórico de chat usando o template do modelo (função auxiliar interna).
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
tokenizer: Tokenizer do modelo (deve ter chat_template configurado)
|
| 59 |
+
messages: Lista de mensagens (Message ou dict)
|
| 60 |
+
add_generation_prompt: Se True, adiciona prompt de geração ao final
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
String formatada pronta para ser enviada ao modelo
|
| 64 |
+
"""
|
| 65 |
+
# Converte Message para dict se necessário
|
| 66 |
+
messages_dict = [
|
| 67 |
+
msg.model_dump_dict() if isinstance(msg, Message) else msg
|
| 68 |
+
for msg in messages
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None:
|
| 72 |
+
# Fallback: concatena mensagens simplesmente
|
| 73 |
+
formatted = ""
|
| 74 |
+
for msg in messages_dict:
|
| 75 |
+
role = msg.get("role", "user")
|
| 76 |
+
content = msg.get("content", "")
|
| 77 |
+
formatted += f"{role}: {content}\n"
|
| 78 |
+
return formatted.strip()
|
| 79 |
+
|
| 80 |
+
return tokenizer.apply_chat_template(
|
| 81 |
+
messages_dict,
|
| 82 |
+
tokenize=False,
|
| 83 |
+
add_generation_prompt=add_generation_prompt,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _get_conversation_summary(messages: List[Message], max_length: int = 100) -> str:
|
| 88 |
+
"""
|
| 89 |
+
Retorna resumo da conversa (função auxiliar interna).
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
messages: Lista de mensagens
|
| 93 |
+
max_length: Comprimento máximo do resumo
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
String resumida da conversa
|
| 97 |
+
"""
|
| 98 |
+
summary_parts = []
|
| 99 |
+
for msg in messages[-5:]: # Últimas 5 mensagens
|
| 100 |
+
if isinstance(msg, Message):
|
| 101 |
+
role = msg.role
|
| 102 |
+
content = msg.content[:50]
|
| 103 |
+
else:
|
| 104 |
+
role = msg.get("role", "unknown")
|
| 105 |
+
content = msg.get("content", "")[:50]
|
| 106 |
+
summary_parts.append(f"{role}: {content}...")
|
| 107 |
+
|
| 108 |
+
summary = " | ".join(summary_parts)
|
| 109 |
+
if len(summary) > max_length:
|
| 110 |
+
return summary[:max_length] + "..."
|
| 111 |
+
return summary
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class Conversation(BaseModel):
|
| 115 |
+
"""
|
| 116 |
+
Gerencia histórico de conversa de forma orientada a objetos com Pydantic.
|
| 117 |
+
|
| 118 |
+
Exemplo:
|
| 119 |
+
conv = Conversation()
|
| 120 |
+
conv.add_user_message("Olá")
|
| 121 |
+
conv.add_assistant_message("Oi! Como posso ajudar?")
|
| 122 |
+
messages = conv.messages
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
messages: List[Message] = Field(default_factory=list)
|
| 126 |
+
system_prompt: Optional[str] = Field(default=None)
|
| 127 |
+
|
| 128 |
+
def __init__(self, system_prompt: Optional[str] = None, **data):
|
| 129 |
+
"""
|
| 130 |
+
Inicializa uma nova conversa.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
system_prompt: Prompt do sistema (opcional)
|
| 134 |
+
"""
|
| 135 |
+
super().__init__(**data)
|
| 136 |
+
if system_prompt and not self.messages:
|
| 137 |
+
self.set_system_prompt(system_prompt)
|
| 138 |
+
|
| 139 |
+
def add_message(self, role: Literal["user", "assistant", "system"], content: str) -> None:
|
| 140 |
+
"""
|
| 141 |
+
Adiciona uma mensagem ao histórico.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
role: Role da mensagem ("user", "assistant", "system")
|
| 145 |
+
content: Conteúdo da mensagem
|
| 146 |
+
"""
|
| 147 |
+
message = Message(role=role, content=content)
|
| 148 |
+
if role == "system":
|
| 149 |
+
# Mensagens do sistema sempre vão no início
|
| 150 |
+
self.messages.insert(0, message)
|
| 151 |
+
else:
|
| 152 |
+
self.messages.append(message)
|
| 153 |
+
|
| 154 |
+
def add_user_message(self, content: str) -> None:
|
| 155 |
+
"""Adiciona mensagem do usuário."""
|
| 156 |
+
self.add_message("user", content)
|
| 157 |
+
|
| 158 |
+
def add_assistant_message(self, content: str) -> None:
|
| 159 |
+
"""Adiciona mensagem do assistente."""
|
| 160 |
+
self.add_message("assistant", content)
|
| 161 |
+
|
| 162 |
+
def set_system_prompt(self, content: str) -> None:
|
| 163 |
+
"""
|
| 164 |
+
Define ou atualiza o prompt do sistema.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
content: Conteúdo do prompt do sistema
|
| 168 |
+
"""
|
| 169 |
+
# Remove mensagens do sistema existentes
|
| 170 |
+
self.messages = [msg for msg in self.messages if msg.role != "system"]
|
| 171 |
+
# Adiciona nova mensagem do sistema no início
|
| 172 |
+
self.messages.insert(0, Message(role="system", content=content))
|
| 173 |
+
|
| 174 |
+
def clear(self, keep_system: bool = True) -> None:
|
| 175 |
+
"""
|
| 176 |
+
Limpa o histórico de conversa.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
keep_system: Se True, mantém mensagens do sistema
|
| 180 |
+
"""
|
| 181 |
+
if keep_system:
|
| 182 |
+
self.messages = [msg for msg in self.messages if msg.role == "system"]
|
| 183 |
+
else:
|
| 184 |
+
self.messages = []
|
| 185 |
+
|
| 186 |
+
def get_summary(self, max_length: int = 100) -> str:
|
| 187 |
+
"""
|
| 188 |
+
Retorna resumo da conversa.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
max_length: Comprimento máximo do resumo
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
String resumida da conversa
|
| 195 |
+
"""
|
| 196 |
+
return _get_conversation_summary(self.messages, max_length)
|
| 197 |
+
|
| 198 |
+
def model_dump_messages(self) -> List[dict]:
|
| 199 |
+
"""Retorna mensagens como lista de dicionários (compatível com transformers)."""
|
| 200 |
+
return [msg.model_dump_dict() for msg in self.messages]
|
| 201 |
+
|
| 202 |
+
def __len__(self) -> int:
|
| 203 |
+
"""Retorna número de mensagens."""
|
| 204 |
+
return len(self.messages)
|
| 205 |
+
|
| 206 |
+
def __repr__(self) -> str:
|
| 207 |
+
"""Representação string da conversa."""
|
| 208 |
+
return f"Conversation({len(self.messages)} messages)"
|
src/backend/chat_model.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ChatModel class that encapsulates pipeline + conversation history."""
|
| 2 |
+
|
| 3 |
+
from typing import Iterator, Optional, Union, List
|
| 4 |
+
from transformers import Pipeline
|
| 5 |
+
from .chat import Conversation, _format_chat_prompt, Message
|
| 6 |
+
from .inference import generate_streaming as _generate_streaming, generate_simple as _generate_simple
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ChatModel:
|
| 10 |
+
"""
|
| 11 |
+
Encapsula modelo + histórico de conversa para facilitar uso.
|
| 12 |
+
|
| 13 |
+
Exemplo:
|
| 14 |
+
model = ChatModel(pipeline, tokenizer)
|
| 15 |
+
model.add_user_message("Olá")
|
| 16 |
+
response = model.generate_streaming()
|
| 17 |
+
model.add_assistant_message(response)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
pipeline: Pipeline,
|
| 23 |
+
system_prompt: Optional[str] = None,
|
| 24 |
+
):
|
| 25 |
+
"""
|
| 26 |
+
Inicializa ChatModel.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
pipeline: Pipeline do transformers (deve ter model e tokenizer)
|
| 30 |
+
system_prompt: Prompt do sistema (opcional)
|
| 31 |
+
"""
|
| 32 |
+
self.pipeline = pipeline
|
| 33 |
+
self.tokenizer = pipeline.tokenizer
|
| 34 |
+
self.conversation = Conversation(system_prompt=system_prompt)
|
| 35 |
+
|
| 36 |
+
@property
|
| 37 |
+
def messages(self) -> List[Message]:
|
| 38 |
+
"""Retorna lista de mensagens do histórico."""
|
| 39 |
+
return self.conversation.messages
|
| 40 |
+
|
| 41 |
+
@property
|
| 42 |
+
def messages_dict(self) -> List[dict]:
|
| 43 |
+
"""Retorna mensagens como lista de dicionários (compatível com transformers)."""
|
| 44 |
+
return self.conversation.model_dump_messages()
|
| 45 |
+
|
| 46 |
+
def add_user_message(self, content: str) -> None:
|
| 47 |
+
"""Adiciona mensagem do usuário ao histórico."""
|
| 48 |
+
self.conversation.add_user_message(content)
|
| 49 |
+
|
| 50 |
+
def add_assistant_message(self, content: str) -> None:
|
| 51 |
+
"""Adiciona mensagem do assistente ao histórico."""
|
| 52 |
+
self.conversation.add_assistant_message(content)
|
| 53 |
+
|
| 54 |
+
def set_system_prompt(self, content: str) -> None:
|
| 55 |
+
"""Define ou atualiza o prompt do sistema."""
|
| 56 |
+
self.conversation.set_system_prompt(content)
|
| 57 |
+
|
| 58 |
+
def clear_history(self, keep_system: bool = True) -> None:
|
| 59 |
+
"""
|
| 60 |
+
Limpa o histórico de conversa.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
keep_system: Se True, mantém mensagens do sistema
|
| 64 |
+
"""
|
| 65 |
+
self.conversation.clear(keep_system=keep_system)
|
| 66 |
+
|
| 67 |
+
def get_formatted_prompt(self, add_generation_prompt: bool = True) -> str:
|
| 68 |
+
"""
|
| 69 |
+
Retorna prompt formatado com histórico completo.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
add_generation_prompt: Se True, adiciona prompt de geração
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
String formatada pronta para o modelo
|
| 76 |
+
"""
|
| 77 |
+
return _format_chat_prompt(
|
| 78 |
+
self.tokenizer,
|
| 79 |
+
self.conversation.messages,
|
| 80 |
+
add_generation_prompt=add_generation_prompt,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def generate_streaming(
|
| 84 |
+
self,
|
| 85 |
+
max_new_tokens: int = 512,
|
| 86 |
+
temperature: Optional[float] = None,
|
| 87 |
+
top_p: Optional[float] = None,
|
| 88 |
+
top_k: Optional[int] = None,
|
| 89 |
+
do_sample: bool = True,
|
| 90 |
+
stop_sequences: Optional[list[str]] = None,
|
| 91 |
+
) -> Iterator[str]:
|
| 92 |
+
"""
|
| 93 |
+
Gera resposta com streaming usando o histórico completo.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
max_new_tokens: Número máximo de tokens a gerar
|
| 97 |
+
temperature: Temperatura para sampling (opcional)
|
| 98 |
+
top_p: Nucleus sampling (opcional)
|
| 99 |
+
top_k: Top-k sampling (opcional)
|
| 100 |
+
do_sample: Se True, usa sampling
|
| 101 |
+
stop_sequences: Lista de sequências para parar
|
| 102 |
+
|
| 103 |
+
Yields:
|
| 104 |
+
Tokens gerados um por vez
|
| 105 |
+
"""
|
| 106 |
+
return _generate_streaming(
|
| 107 |
+
pipeline=self.pipeline,
|
| 108 |
+
prompt=self.conversation.messages, # List[Message] funciona com _format_chat_prompt
|
| 109 |
+
max_new_tokens=max_new_tokens,
|
| 110 |
+
temperature=temperature,
|
| 111 |
+
top_p=top_p,
|
| 112 |
+
top_k=top_k,
|
| 113 |
+
do_sample=do_sample,
|
| 114 |
+
stop_sequences=stop_sequences,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def generate(
|
| 118 |
+
self,
|
| 119 |
+
max_new_tokens: int = 512,
|
| 120 |
+
temperature: Optional[float] = None,
|
| 121 |
+
top_p: Optional[float] = None,
|
| 122 |
+
top_k: Optional[int] = None,
|
| 123 |
+
do_sample: bool = True,
|
| 124 |
+
) -> str:
|
| 125 |
+
"""
|
| 126 |
+
Gera resposta completa usando o histórico completo.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
max_new_tokens: Número máximo de tokens a gerar
|
| 130 |
+
temperature: Temperatura para sampling (opcional)
|
| 131 |
+
top_p: Nucleus sampling (opcional)
|
| 132 |
+
top_k: Top-k sampling (opcional)
|
| 133 |
+
do_sample: Se True, usa sampling
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
Texto gerado completo
|
| 137 |
+
"""
|
| 138 |
+
return _generate_simple(
|
| 139 |
+
pipeline=self.pipeline,
|
| 140 |
+
prompt=self.conversation.messages,
|
| 141 |
+
max_new_tokens=max_new_tokens,
|
| 142 |
+
temperature=temperature,
|
| 143 |
+
top_p=top_p,
|
| 144 |
+
top_k=top_k,
|
| 145 |
+
do_sample=do_sample,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def chat(
|
| 149 |
+
self,
|
| 150 |
+
user_message: str,
|
| 151 |
+
max_new_tokens: int = 512,
|
| 152 |
+
temperature: Optional[float] = None,
|
| 153 |
+
streaming: bool = False,
|
| 154 |
+
) -> Union[str, Iterator[str]]:
|
| 155 |
+
"""
|
| 156 |
+
Método conveniente para chat completo (adiciona mensagem + gera + adiciona resposta).
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
user_message: Mensagem do usuário
|
| 160 |
+
max_new_tokens: Número máximo de tokens a gerar
|
| 161 |
+
temperature: Temperatura para sampling (opcional)
|
| 162 |
+
streaming: Se True, retorna iterator; se False, retorna string completa
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
Resposta do modelo (string ou iterator)
|
| 166 |
+
"""
|
| 167 |
+
# Adiciona mensagem do usuário
|
| 168 |
+
self.add_user_message(user_message)
|
| 169 |
+
|
| 170 |
+
# Gera resposta
|
| 171 |
+
if streaming:
|
| 172 |
+
return self.generate_streaming(
|
| 173 |
+
max_new_tokens=max_new_tokens,
|
| 174 |
+
temperature=temperature,
|
| 175 |
+
)
|
| 176 |
+
else:
|
| 177 |
+
response = self.generate(
|
| 178 |
+
max_new_tokens=max_new_tokens,
|
| 179 |
+
temperature=temperature,
|
| 180 |
+
)
|
| 181 |
+
# Adiciona resposta ao histórico
|
| 182 |
+
self.add_assistant_message(response)
|
| 183 |
+
return response
|
| 184 |
+
|
| 185 |
+
def __repr__(self) -> str:
|
| 186 |
+
"""Representação string do modelo."""
|
| 187 |
+
return f"ChatModel({len(self.conversation)} messages)"
|
| 188 |
+
|
src/backend/inference.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference utilities with streaming support."""
|
| 2 |
+
|
| 3 |
+
from typing import Iterator, Optional, Union, List
|
| 4 |
+
from transformers import Pipeline, TextIteratorStreamer
|
| 5 |
+
from threading import Thread
|
| 6 |
+
from .chat import _format_chat_prompt, Message
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _build_generation_kwargs(
|
| 10 |
+
max_new_tokens: int,
|
| 11 |
+
do_sample: bool,
|
| 12 |
+
temperature: Optional[float] = None,
|
| 13 |
+
top_p: Optional[float] = None,
|
| 14 |
+
top_k: Optional[int] = None,
|
| 15 |
+
**extra_kwargs
|
| 16 |
+
) -> dict:
|
| 17 |
+
"""Constrói dicionário de kwargs para geração, incluindo apenas parâmetros fornecidos."""
|
| 18 |
+
kwargs = {
|
| 19 |
+
"max_new_tokens": max_new_tokens,
|
| 20 |
+
"do_sample": do_sample,
|
| 21 |
+
**extra_kwargs,
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
if temperature is not None:
|
| 25 |
+
kwargs["temperature"] = temperature
|
| 26 |
+
if top_p is not None:
|
| 27 |
+
kwargs["top_p"] = top_p
|
| 28 |
+
if top_k is not None:
|
| 29 |
+
kwargs["top_k"] = top_k
|
| 30 |
+
|
| 31 |
+
return kwargs
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def generate_streaming(
|
| 35 |
+
pipeline: Pipeline,
|
| 36 |
+
prompt: Union[str, List[Message]],
|
| 37 |
+
max_new_tokens: int = 512,
|
| 38 |
+
temperature: Optional[float] = None,
|
| 39 |
+
top_p: Optional[float] = None,
|
| 40 |
+
top_k: Optional[int] = None,
|
| 41 |
+
do_sample: bool = True,
|
| 42 |
+
stop_sequences: Optional[list[str]] = None,
|
| 43 |
+
) -> Iterator[str]:
|
| 44 |
+
"""
|
| 45 |
+
Gera texto com streaming usando TextIteratorStreamer.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
pipeline: Pipeline do transformers
|
| 49 |
+
prompt: Texto de entrada (str) ou lista de mensagens (List[Message])
|
| 50 |
+
max_new_tokens: Número máximo de tokens a gerar
|
| 51 |
+
temperature: Temperatura para sampling (opcional, usa padrão do modelo se None)
|
| 52 |
+
top_p: Nucleus sampling (opcional, usa padrão do modelo se None)
|
| 53 |
+
top_k: Top-k sampling (opcional, usa padrão do modelo se None)
|
| 54 |
+
do_sample: Se True, usa sampling; caso contrário, usa greedy decoding
|
| 55 |
+
stop_sequences: Lista de sequências para parar a geração
|
| 56 |
+
|
| 57 |
+
Yields:
|
| 58 |
+
Tokens gerados um por vez
|
| 59 |
+
"""
|
| 60 |
+
# Obtém o modelo e tokenizer do pipeline
|
| 61 |
+
model = pipeline.model
|
| 62 |
+
tokenizer = pipeline.tokenizer
|
| 63 |
+
|
| 64 |
+
# Formata prompt se for lista de mensagens
|
| 65 |
+
if isinstance(prompt, list):
|
| 66 |
+
formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True)
|
| 67 |
+
else:
|
| 68 |
+
formatted_prompt = prompt
|
| 69 |
+
|
| 70 |
+
# Tokeniza o prompt
|
| 71 |
+
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
| 72 |
+
|
| 73 |
+
# Cria streamer
|
| 74 |
+
streamer = TextIteratorStreamer(
|
| 75 |
+
tokenizer,
|
| 76 |
+
skip_prompt=True,
|
| 77 |
+
skip_special_tokens=True,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Configurações de geração (usa valores padrão do modelo se não especificados)
|
| 81 |
+
generation_kwargs = _build_generation_kwargs(
|
| 82 |
+
max_new_tokens=max_new_tokens,
|
| 83 |
+
do_sample=do_sample,
|
| 84 |
+
temperature=temperature,
|
| 85 |
+
top_p=top_p,
|
| 86 |
+
top_k=top_k,
|
| 87 |
+
streamer=streamer,
|
| 88 |
+
use_cache=True, # Usa cache de atenção para acelerar
|
| 89 |
+
)
|
| 90 |
+
generation_kwargs.update(inputs)
|
| 91 |
+
|
| 92 |
+
# Thread para geração
|
| 93 |
+
generation_thread = Thread(
|
| 94 |
+
target=model.generate,
|
| 95 |
+
kwargs=generation_kwargs,
|
| 96 |
+
)
|
| 97 |
+
generation_thread.start()
|
| 98 |
+
|
| 99 |
+
# Yield tokens conforme são gerados
|
| 100 |
+
for token in streamer:
|
| 101 |
+
if stop_sequences:
|
| 102 |
+
# Verifica se algum stop_sequence foi encontrado
|
| 103 |
+
for stop_seq in stop_sequences:
|
| 104 |
+
if stop_seq in token:
|
| 105 |
+
generation_thread.join(timeout=1.0)
|
| 106 |
+
return
|
| 107 |
+
yield token
|
| 108 |
+
|
| 109 |
+
generation_thread.join()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def generate_simple(
|
| 113 |
+
pipeline: Pipeline,
|
| 114 |
+
prompt: Union[str, List[Message]],
|
| 115 |
+
max_new_tokens: int = 512,
|
| 116 |
+
temperature: Optional[float] = None,
|
| 117 |
+
top_p: Optional[float] = None,
|
| 118 |
+
top_k: Optional[int] = None,
|
| 119 |
+
do_sample: bool = True,
|
| 120 |
+
num_return_sequences: int = 1,
|
| 121 |
+
) -> str:
|
| 122 |
+
"""
|
| 123 |
+
Gera texto sem streaming (mais simples, útil para testes).
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
pipeline: Pipeline do transformers
|
| 127 |
+
prompt: Texto de entrada (str) ou lista de mensagens (List[Message])
|
| 128 |
+
max_new_tokens: Número máximo de tokens a gerar
|
| 129 |
+
temperature: Temperatura para sampling (opcional, usa padrão do modelo se None)
|
| 130 |
+
top_p: Nucleus sampling (opcional, usa padrão do modelo se None)
|
| 131 |
+
top_k: Top-k sampling (opcional, usa padrão do modelo se None)
|
| 132 |
+
do_sample: Se True, usa sampling; caso contrário, usa greedy decoding
|
| 133 |
+
num_return_sequences: Número de sequências a retornar
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
Texto gerado
|
| 137 |
+
"""
|
| 138 |
+
# Formata prompt se for lista de mensagens
|
| 139 |
+
tokenizer = pipeline.tokenizer
|
| 140 |
+
if isinstance(prompt, list):
|
| 141 |
+
formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True)
|
| 142 |
+
else:
|
| 143 |
+
formatted_prompt = prompt
|
| 144 |
+
|
| 145 |
+
# Prepara parâmetros do pipeline (usa valores padrão do modelo se não especificados)
|
| 146 |
+
pipeline_kwargs = _build_generation_kwargs(
|
| 147 |
+
max_new_tokens=max_new_tokens,
|
| 148 |
+
do_sample=do_sample,
|
| 149 |
+
temperature=temperature,
|
| 150 |
+
top_p=top_p,
|
| 151 |
+
top_k=top_k,
|
| 152 |
+
num_return_sequences=num_return_sequences,
|
| 153 |
+
return_full_text=False,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
outputs = pipeline(formatted_prompt, **pipeline_kwargs)
|
| 157 |
+
|
| 158 |
+
if num_return_sequences == 1:
|
| 159 |
+
return outputs[0]["generated_text"]
|
| 160 |
+
else:
|
| 161 |
+
return [output["generated_text"] for output in outputs]
|
| 162 |
+
|
src/backend/model_loader.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading utilities with Streamlit caching."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import streamlit as st
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional, Dict, Any, Tuple
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForCausalLM,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
pipeline,
|
| 11 |
+
Pipeline,
|
| 12 |
+
)
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# Obtém token do Hugging Face (disponível automaticamente no Spaces)
|
| 16 |
+
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
|
| 17 |
+
|
| 18 |
+
# Define o diretório de cache dentro do projeto
|
| 19 |
+
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
| 20 |
+
MODELS_CACHE_DIR = PROJECT_ROOT / "models"
|
| 21 |
+
MODELS_CACHE_DIR.mkdir(exist_ok=True)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@st.cache_resource
|
| 25 |
+
def load_model(
|
| 26 |
+
model_name: str,
|
| 27 |
+
device_map: Optional[str] = "auto",
|
| 28 |
+
torch_dtype: Optional[torch.dtype] = None,
|
| 29 |
+
load_in_8bit: bool = False,
|
| 30 |
+
load_in_4bit: bool = False,
|
| 31 |
+
) -> Tuple[Pipeline, Dict[str, Any]]:
|
| 32 |
+
"""
|
| 33 |
+
Carrega um modelo do Hugging Face com cache do Streamlit.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
model_name: Nome do modelo no Hugging Face (ex: 'microsoft/DialoGPT-medium')
|
| 37 |
+
device_map: Mapeamento de dispositivo ('auto', 'cpu', 'cuda', etc.)
|
| 38 |
+
torch_dtype: Tipo de dados do torch (ex: torch.float16)
|
| 39 |
+
load_in_8bit: Se True, carrega modelo quantizado em 8-bit
|
| 40 |
+
load_in_4bit: Se True, carrega modelo quantizado em 4-bit
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
Tupla contendo (pipeline, model_info)
|
| 44 |
+
"""
|
| 45 |
+
try:
|
| 46 |
+
# Detecta dispositivo disponível
|
| 47 |
+
has_cuda = torch.cuda.is_available()
|
| 48 |
+
|
| 49 |
+
# Determina o dtype padrão
|
| 50 |
+
if torch_dtype is None:
|
| 51 |
+
if has_cuda:
|
| 52 |
+
torch_dtype = torch.float16
|
| 53 |
+
else:
|
| 54 |
+
torch_dtype = torch.float32
|
| 55 |
+
|
| 56 |
+
# Ajusta device_map: se não há GPU ou device_map é "auto" sem GPU, usa None
|
| 57 |
+
if device_map == "auto" and not has_cuda:
|
| 58 |
+
device_map = None
|
| 59 |
+
elif device_map == "auto" and has_cuda:
|
| 60 |
+
device_map = "auto"
|
| 61 |
+
|
| 62 |
+
# Configurações de quantização
|
| 63 |
+
model_kwargs = {
|
| 64 |
+
"torch_dtype": torch_dtype,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
# Só adiciona device_map se não for None
|
| 68 |
+
if device_map is not None:
|
| 69 |
+
model_kwargs["device_map"] = device_map
|
| 70 |
+
|
| 71 |
+
if load_in_8bit or load_in_4bit:
|
| 72 |
+
try:
|
| 73 |
+
from transformers import BitsAndBytesConfig
|
| 74 |
+
|
| 75 |
+
quantization_config = BitsAndBytesConfig(
|
| 76 |
+
load_in_8bit=load_in_8bit,
|
| 77 |
+
load_in_4bit=load_in_4bit,
|
| 78 |
+
)
|
| 79 |
+
model_kwargs["quantization_config"] = quantization_config
|
| 80 |
+
except ImportError:
|
| 81 |
+
st.warning("bitsandbytes não está instalado. Quantização desabilitada.")
|
| 82 |
+
|
| 83 |
+
# Carrega tokenizer e modelo usando cache do projeto
|
| 84 |
+
cache_dir = str(MODELS_CACHE_DIR)
|
| 85 |
+
|
| 86 |
+
# Prepara kwargs com token de autenticação se disponível
|
| 87 |
+
hf_kwargs = {"cache_dir": cache_dir}
|
| 88 |
+
if HF_TOKEN:
|
| 89 |
+
hf_kwargs["token"] = HF_TOKEN
|
| 90 |
+
|
| 91 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 92 |
+
model_name,
|
| 93 |
+
**hf_kwargs
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Adiciona pad_token se não existir
|
| 97 |
+
if tokenizer.pad_token is None:
|
| 98 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 99 |
+
|
| 100 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 101 |
+
model_name,
|
| 102 |
+
**hf_kwargs,
|
| 103 |
+
**model_kwargs
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# Move modelo para CPU se não há GPU e device_map não foi usado
|
| 107 |
+
if device_map is None and not has_cuda:
|
| 108 |
+
model = model.to("cpu")
|
| 109 |
+
|
| 110 |
+
# Cria pipeline
|
| 111 |
+
pipeline_kwargs = {
|
| 112 |
+
"model": model,
|
| 113 |
+
"tokenizer": tokenizer,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
# Só adiciona device ao pipeline se não usar device_map no modelo
|
| 117 |
+
if device_map is None:
|
| 118 |
+
pipeline_kwargs["device"] = 0 if has_cuda else -1
|
| 119 |
+
else:
|
| 120 |
+
pipeline_kwargs["device_map"] = device_map
|
| 121 |
+
|
| 122 |
+
pipe = pipeline("text-generation", **pipeline_kwargs)
|
| 123 |
+
|
| 124 |
+
# Informações do modelo
|
| 125 |
+
model_info = {
|
| 126 |
+
"model_name": model_name,
|
| 127 |
+
"device": str(next(model.parameters()).device),
|
| 128 |
+
"dtype": str(torch_dtype),
|
| 129 |
+
"quantized": load_in_8bit or load_in_4bit,
|
| 130 |
+
"cache_dir": cache_dir,
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
return pipe, model_info
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
st.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
|
| 137 |
+
raise
|
| 138 |
+
|
src/config.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configurações do projeto."""
|
| 2 |
+
|
| 3 |
+
# Lista de modelos pré-selecionados para teste
|
| 4 |
+
PRESELECTED_MODELS = [
|
| 5 |
+
"Qwen/Qwen3-0.6B", # Modelo padrão
|
| 6 |
+
"google/gemma-3-4b-it",
|
| 7 |
+
"google/gemma-3-1b-it",
|
| 8 |
+
"google/gemma-3-270m-it",
|
| 9 |
+
"Qwen/Qwen2.5-0.5B-Instruct",
|
| 10 |
+
"Qwen/Qwen2.5-0.5B",
|
| 11 |
+
"facebook/MobileLLM-R1-950M",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
# Modelos que requerem autenticação (gated)
|
| 15 |
+
GATED_MODELS = {
|
| 16 |
+
"google/gemma-3-4b-it",
|
| 17 |
+
"google/gemma-3-1b-it",
|
| 18 |
+
"google/gemma-3-270m-it",
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
# Informações sobre os modelos (para exibição)
|
| 22 |
+
MODEL_INFO = {
|
| 23 |
+
"google/gemma-3-4b-it": {
|
| 24 |
+
"name": "Gemma 3 4B IT",
|
| 25 |
+
"params": "4 bilhões",
|
| 26 |
+
"family": "Google Gemma",
|
| 27 |
+
},
|
| 28 |
+
"google/gemma-3-1b-it": {
|
| 29 |
+
"name": "Gemma 3 1B IT",
|
| 30 |
+
"params": "1 bilhão",
|
| 31 |
+
"family": "Google Gemma",
|
| 32 |
+
},
|
| 33 |
+
"google/gemma-3-270m-it": {
|
| 34 |
+
"name": "Gemma 3 270M IT",
|
| 35 |
+
"params": "270 milhões",
|
| 36 |
+
"family": "Google Gemma",
|
| 37 |
+
},
|
| 38 |
+
"Qwen/Qwen3-0.6B": {
|
| 39 |
+
"name": "Qwen3 0.6B",
|
| 40 |
+
"params": "600 milhões",
|
| 41 |
+
"family": "Qwen",
|
| 42 |
+
},
|
| 43 |
+
"Qwen/Qwen2.5-0.5B-Instruct": {
|
| 44 |
+
"name": "Qwen2.5 0.5B Instruct",
|
| 45 |
+
"params": "500 milhões",
|
| 46 |
+
"family": "Qwen",
|
| 47 |
+
},
|
| 48 |
+
"Qwen/Qwen2.5-0.5B": {
|
| 49 |
+
"name": "Qwen2.5 0.5B",
|
| 50 |
+
"params": "500 milhões",
|
| 51 |
+
"family": "Qwen",
|
| 52 |
+
},
|
| 53 |
+
"facebook/MobileLLM-R1-950M": {
|
| 54 |
+
"name": "MobileLLM R1 950M",
|
| 55 |
+
"params": "950 milhões",
|
| 56 |
+
"family": "Facebook",
|
| 57 |
+
},
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_model_label(model_id: str) -> str:
|
| 62 |
+
"""Retorna label amigável para um modelo."""
|
| 63 |
+
if model_id in MODEL_INFO:
|
| 64 |
+
info = MODEL_INFO[model_id]
|
| 65 |
+
return f"{info['name']} ({info['params']})"
|
| 66 |
+
return model_id
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_model_options() -> list[tuple[str, str]]:
|
| 70 |
+
"""Retorna lista de tuplas (label, model_id) para uso em selectbox."""
|
| 71 |
+
return [(get_model_label(model_id), model_id) for model_id in PRESELECTED_MODELS]
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|