Daniel Machado Pedrozo commited on
Commit
91c131d
·
1 Parent(s): 3d96e49

Implement initial project structure with Dockerfile, requirements, and Streamlit app. Added model loading and inference utilities, along with chat management features. Updated entry point and added new dependencies.

Browse files
Dockerfile CHANGED
@@ -11,10 +11,11 @@ RUN apt-get update && apt-get install -y \
11
  COPY requirements.txt ./
12
  COPY src/ ./src/
13
 
 
14
  RUN pip3 install -r requirements.txt
15
 
16
  EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
11
  COPY requirements.txt ./
12
  COPY src/ ./src/
13
 
14
+ RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
15
  RUN pip3 install -r requirements.txt
16
 
17
  EXPOSE 8501
18
 
19
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
 
21
+ ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ dotenv
5
+ transformers
6
+ pydantic
src/app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Interface de chat com modelo de linguagem."""
2
+
3
+ import streamlit as st
4
+ import base64
5
+ from pathlib import Path
6
+ from backend import load_model, ChatModel
7
+ from config import get_model_options, GATED_MODELS
8
+
9
+ st.set_page_config(page_title="Small LLM - Chat", layout="wide")
10
+
11
+ # Caminho da logo (relativo à raiz do projeto)
12
+ PROJECT_ROOT = Path(__file__).parent.parent
13
+ LOGO_PATH = PROJECT_ROOT / "positivo-logo.png"
14
+
15
+ # Header com logo e título usando HTML/CSS para melhor controle
16
+ with open(LOGO_PATH, "rb") as img_file:
17
+ img_base64 = base64.b64encode(img_file.read()).decode()
18
+
19
+ st.markdown(f"""
20
+ <style>
21
+ .logo-header {{
22
+ display: flex;
23
+ align-items: center;
24
+ gap: 20px;
25
+ margin-bottom: 0.5rem;
26
+ }}
27
+ .logo-header img {{
28
+ width: 90px;
29
+ height: 90px;
30
+ object-fit: contain;
31
+ flex-shrink: 0;
32
+ }}
33
+ </style>
34
+ <div class="logo-header">
35
+ <img src="data:image/png;base64,{img_base64}" />
36
+ <h1 style="margin: 0; padding: 0; display: inline-block;">Small LLM - Chat</h1>
37
+ </div>
38
+ """, unsafe_allow_html=True)
39
+
40
+ # ============================================================================
41
+ # FUNÇÕES AUXILIARES
42
+ # ============================================================================
43
+
44
+ def handle_model_load_error(model_name: str, error_msg: str):
45
+ """Trata erros de carregamento de modelo, especialmente modelos gated."""
46
+ is_gated_error = (
47
+ model_name in GATED_MODELS and (
48
+ "401" in error_msg or
49
+ "gated" in error_msg.lower() or
50
+ "access" in error_msg.lower() or
51
+ "restricted" in error_msg.lower()
52
+ )
53
+ )
54
+
55
+ if is_gated_error:
56
+ st.error(
57
+ f"⚠️ **Modelo gated detectado!**\n\n"
58
+ f"O modelo `{model_name}` requer autenticação.\n\n"
59
+ f"**No Hugging Face Spaces:**\n"
60
+ f"1. Vá em Settings → Repository secrets\n"
61
+ f"2. Adicione `HF_TOKEN` com seu token do Hugging Face\n"
62
+ f"3. Aceite os termos em: https://huggingface.co/{model_name}"
63
+ )
64
+ else:
65
+ st.error(f"❌ Erro ao carregar modelo: {error_msg}")
66
+
67
+ # ============================================================================
68
+ # INTERFACE DE CHAT
69
+ # ============================================================================
70
+
71
+ # Sidebar para configurações
72
+ with st.sidebar:
73
+ st.header("⚙️ Configurações")
74
+
75
+ model_options = get_model_options()
76
+ selected_label = st.selectbox(
77
+ "Selecione um Modelo",
78
+ options=[opt[0] for opt in model_options],
79
+ index=0,
80
+ help="Modelos pré-selecionados para teste"
81
+ )
82
+ selected_model = next(opt[1] for opt in model_options if opt[0] == selected_label)
83
+
84
+ use_custom = st.checkbox("Usar modelo customizado")
85
+
86
+ if use_custom:
87
+ model_name = st.text_input(
88
+ "Nome do Modelo (Hugging Face)",
89
+ value="gpt2",
90
+ help="Digite o nome completo do modelo no Hugging Face"
91
+ )
92
+ else:
93
+ model_name = selected_model
94
+
95
+ use_quantization = st.checkbox(
96
+ "Usar Quantização (8-bit)",
97
+ value=False,
98
+ help="Reduz uso de memória, mas pode ser mais lento"
99
+ )
100
+
101
+ if st.button("🔄 Carregar Modelo", type="primary"):
102
+ with st.spinner(f"Carregando {model_name}..."):
103
+ try:
104
+ pipeline, model_info = load_model(
105
+ model_name,
106
+ load_in_8bit=use_quantization
107
+ )
108
+ chat_model = ChatModel(pipeline)
109
+ st.session_state.chat_model = chat_model
110
+ st.session_state.model_info = model_info
111
+ st.session_state.model_name = model_name
112
+ st.success("✅ Modelo carregado!")
113
+ if "messages" in st.session_state:
114
+ del st.session_state.messages
115
+ except Exception as e:
116
+ handle_model_load_error(model_name, str(e))
117
+
118
+ if "model_info" in st.session_state:
119
+ st.divider()
120
+ st.subheader("📊 Informações do Modelo")
121
+ st.json(st.session_state.model_info)
122
+
123
+ if "chat_model" in st.session_state:
124
+ chat_model = st.session_state.chat_model
125
+ st.divider()
126
+ st.subheader("💭 Estatísticas da Conversa")
127
+ st.metric("Mensagens", len(chat_model.conversation))
128
+
129
+ if st.button("🗑️ Limpar Histórico", use_container_width=True):
130
+ chat_model.clear_history()
131
+ if "messages" in st.session_state:
132
+ del st.session_state.messages
133
+ st.rerun()
134
+
135
+ # Área principal - Chat
136
+ if "chat_model" not in st.session_state:
137
+ st.info("👈 Use a sidebar para carregar um modelo primeiro.")
138
+ st.markdown("""
139
+ ### Modelos disponíveis:
140
+
141
+ **Google Gemma:**
142
+ - `google/gemma-3-4b-it` - 4 bilhões de parâmetros
143
+ - `google/gemma-3-1b-it` - 1 bilhão de parâmetros
144
+ - `google/gemma-3-270m-it` - 270 milhões de parâmetros
145
+
146
+ **Qwen:**
147
+ - `Qwen/Qwen3-0.6B` - 600 milhões de parâmetros
148
+ - `Qwen/Qwen2.5-0.5B-Instruct` - 500 milhões (instruct)
149
+ - `Qwen/Qwen2.5-0.5B` - 500 milhões
150
+
151
+ **Facebook:**
152
+ - `facebook/MobileLLM-R1-950M` - 950 milhões de parâmetros
153
+ """)
154
+ else:
155
+ chat_model = st.session_state.chat_model
156
+
157
+ if "messages" not in st.session_state:
158
+ st.session_state.messages = []
159
+
160
+ if len(chat_model.conversation.messages) != len(st.session_state.messages):
161
+ st.session_state.messages = [
162
+ {"role": msg.role, "content": msg.content}
163
+ for msg in chat_model.conversation.messages
164
+ ]
165
+
166
+ chat_container = st.container()
167
+
168
+ with chat_container:
169
+ for message in st.session_state.messages:
170
+ role = message["role"]
171
+ content = message["content"]
172
+
173
+ if role == "system":
174
+ continue
175
+
176
+ with st.chat_message(role):
177
+ st.markdown(content)
178
+
179
+ if user_input := st.chat_input("Digite sua mensagem..."):
180
+ chat_model.add_user_message(user_input)
181
+ st.session_state.messages.append({"role": "user", "content": user_input})
182
+
183
+ with st.chat_message("user"):
184
+ st.markdown(user_input)
185
+
186
+ with st.chat_message("assistant"):
187
+ response_placeholder = st.empty()
188
+ full_response = ""
189
+
190
+ try:
191
+ for token in chat_model.generate_streaming(max_new_tokens=512):
192
+ full_response += token
193
+ response_placeholder.markdown(full_response)
194
+
195
+ chat_model.add_assistant_message(full_response)
196
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
197
+
198
+ except Exception as e:
199
+ error_msg = f"Erro na geração: {str(e)}"
200
+ st.error(error_msg)
201
+ st.session_state.messages.append({"role": "assistant", "content": error_msg})
src/backend/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backend module for LLM model loading and inference."""
2
+
3
+ from .model_loader import load_model
4
+ from .chat import Conversation, Message
5
+ from .chat_model import ChatModel
6
+ from .inference import generate_streaming, generate_simple
7
+
8
+ __all__ = [
9
+ # Model loading
10
+ "load_model",
11
+ # OOP classes (recomendado)
12
+ "Conversation",
13
+ "ChatModel",
14
+ # Functions (compatibilidade)
15
+ "generate_streaming",
16
+ "generate_simple",
17
+ # Types
18
+ "Message",
19
+ ]
20
+
src/backend/chat.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chat utilities for managing conversation history with chat templates."""
2
+
3
+ from typing import List, Optional, Literal
4
+ from pydantic import BaseModel, Field, field_validator
5
+ from transformers import PreTrainedTokenizer
6
+
7
+
8
+ class Message(BaseModel):
9
+ """
10
+ Mensagem de chat no formato compatível OpenAI.
11
+
12
+ Exemplo:
13
+ msg = Message(role="user", content="Olá!")
14
+ msg_dict = msg.model_dump() # {"role": "user", "content": "Olá!"}
15
+ """
16
+
17
+ role: Literal["user", "assistant", "system"] = Field(
18
+ ...,
19
+ description="Role da mensagem: user, assistant ou system"
20
+ )
21
+ content: str = Field(
22
+ ...,
23
+ min_length=1,
24
+ description="Conteúdo da mensagem"
25
+ )
26
+
27
+ @field_validator("content")
28
+ @classmethod
29
+ def validate_content(cls, v: str) -> str:
30
+ """Valida que o conteúdo não está vazio."""
31
+ if not v.strip():
32
+ raise ValueError("Content não pode estar vazio")
33
+ return v
34
+
35
+ def model_dump_dict(self) -> dict:
36
+ """Retorna como dicionário (compatível com transformers)."""
37
+ return {"role": self.role, "content": self.content}
38
+
39
+ class Config:
40
+ """Configuração do Pydantic."""
41
+ json_schema_extra = {
42
+ "example": {
43
+ "role": "user",
44
+ "content": "Olá! Como você está?"
45
+ }
46
+ }
47
+
48
+
49
+ def _format_chat_prompt(
50
+ tokenizer: PreTrainedTokenizer,
51
+ messages: List[Message],
52
+ add_generation_prompt: bool = True,
53
+ ) -> str:
54
+ """
55
+ Formata histórico de chat usando o template do modelo (função auxiliar interna).
56
+
57
+ Args:
58
+ tokenizer: Tokenizer do modelo (deve ter chat_template configurado)
59
+ messages: Lista de mensagens (Message ou dict)
60
+ add_generation_prompt: Se True, adiciona prompt de geração ao final
61
+
62
+ Returns:
63
+ String formatada pronta para ser enviada ao modelo
64
+ """
65
+ # Converte Message para dict se necessário
66
+ messages_dict = [
67
+ msg.model_dump_dict() if isinstance(msg, Message) else msg
68
+ for msg in messages
69
+ ]
70
+
71
+ if not hasattr(tokenizer, "apply_chat_template") or tokenizer.chat_template is None:
72
+ # Fallback: concatena mensagens simplesmente
73
+ formatted = ""
74
+ for msg in messages_dict:
75
+ role = msg.get("role", "user")
76
+ content = msg.get("content", "")
77
+ formatted += f"{role}: {content}\n"
78
+ return formatted.strip()
79
+
80
+ return tokenizer.apply_chat_template(
81
+ messages_dict,
82
+ tokenize=False,
83
+ add_generation_prompt=add_generation_prompt,
84
+ )
85
+
86
+
87
+ def _get_conversation_summary(messages: List[Message], max_length: int = 100) -> str:
88
+ """
89
+ Retorna resumo da conversa (função auxiliar interna).
90
+
91
+ Args:
92
+ messages: Lista de mensagens
93
+ max_length: Comprimento máximo do resumo
94
+
95
+ Returns:
96
+ String resumida da conversa
97
+ """
98
+ summary_parts = []
99
+ for msg in messages[-5:]: # Últimas 5 mensagens
100
+ if isinstance(msg, Message):
101
+ role = msg.role
102
+ content = msg.content[:50]
103
+ else:
104
+ role = msg.get("role", "unknown")
105
+ content = msg.get("content", "")[:50]
106
+ summary_parts.append(f"{role}: {content}...")
107
+
108
+ summary = " | ".join(summary_parts)
109
+ if len(summary) > max_length:
110
+ return summary[:max_length] + "..."
111
+ return summary
112
+
113
+
114
+ class Conversation(BaseModel):
115
+ """
116
+ Gerencia histórico de conversa de forma orientada a objetos com Pydantic.
117
+
118
+ Exemplo:
119
+ conv = Conversation()
120
+ conv.add_user_message("Olá")
121
+ conv.add_assistant_message("Oi! Como posso ajudar?")
122
+ messages = conv.messages
123
+ """
124
+
125
+ messages: List[Message] = Field(default_factory=list)
126
+ system_prompt: Optional[str] = Field(default=None)
127
+
128
+ def __init__(self, system_prompt: Optional[str] = None, **data):
129
+ """
130
+ Inicializa uma nova conversa.
131
+
132
+ Args:
133
+ system_prompt: Prompt do sistema (opcional)
134
+ """
135
+ super().__init__(**data)
136
+ if system_prompt and not self.messages:
137
+ self.set_system_prompt(system_prompt)
138
+
139
+ def add_message(self, role: Literal["user", "assistant", "system"], content: str) -> None:
140
+ """
141
+ Adiciona uma mensagem ao histórico.
142
+
143
+ Args:
144
+ role: Role da mensagem ("user", "assistant", "system")
145
+ content: Conteúdo da mensagem
146
+ """
147
+ message = Message(role=role, content=content)
148
+ if role == "system":
149
+ # Mensagens do sistema sempre vão no início
150
+ self.messages.insert(0, message)
151
+ else:
152
+ self.messages.append(message)
153
+
154
+ def add_user_message(self, content: str) -> None:
155
+ """Adiciona mensagem do usuário."""
156
+ self.add_message("user", content)
157
+
158
+ def add_assistant_message(self, content: str) -> None:
159
+ """Adiciona mensagem do assistente."""
160
+ self.add_message("assistant", content)
161
+
162
+ def set_system_prompt(self, content: str) -> None:
163
+ """
164
+ Define ou atualiza o prompt do sistema.
165
+
166
+ Args:
167
+ content: Conteúdo do prompt do sistema
168
+ """
169
+ # Remove mensagens do sistema existentes
170
+ self.messages = [msg for msg in self.messages if msg.role != "system"]
171
+ # Adiciona nova mensagem do sistema no início
172
+ self.messages.insert(0, Message(role="system", content=content))
173
+
174
+ def clear(self, keep_system: bool = True) -> None:
175
+ """
176
+ Limpa o histórico de conversa.
177
+
178
+ Args:
179
+ keep_system: Se True, mantém mensagens do sistema
180
+ """
181
+ if keep_system:
182
+ self.messages = [msg for msg in self.messages if msg.role == "system"]
183
+ else:
184
+ self.messages = []
185
+
186
+ def get_summary(self, max_length: int = 100) -> str:
187
+ """
188
+ Retorna resumo da conversa.
189
+
190
+ Args:
191
+ max_length: Comprimento máximo do resumo
192
+
193
+ Returns:
194
+ String resumida da conversa
195
+ """
196
+ return _get_conversation_summary(self.messages, max_length)
197
+
198
+ def model_dump_messages(self) -> List[dict]:
199
+ """Retorna mensagens como lista de dicionários (compatível com transformers)."""
200
+ return [msg.model_dump_dict() for msg in self.messages]
201
+
202
+ def __len__(self) -> int:
203
+ """Retorna número de mensagens."""
204
+ return len(self.messages)
205
+
206
+ def __repr__(self) -> str:
207
+ """Representação string da conversa."""
208
+ return f"Conversation({len(self.messages)} messages)"
src/backend/chat_model.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ChatModel class that encapsulates pipeline + conversation history."""
2
+
3
+ from typing import Iterator, Optional, Union, List
4
+ from transformers import Pipeline
5
+ from .chat import Conversation, _format_chat_prompt, Message
6
+ from .inference import generate_streaming as _generate_streaming, generate_simple as _generate_simple
7
+
8
+
9
+ class ChatModel:
10
+ """
11
+ Encapsula modelo + histórico de conversa para facilitar uso.
12
+
13
+ Exemplo:
14
+ model = ChatModel(pipeline, tokenizer)
15
+ model.add_user_message("Olá")
16
+ response = model.generate_streaming()
17
+ model.add_assistant_message(response)
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ pipeline: Pipeline,
23
+ system_prompt: Optional[str] = None,
24
+ ):
25
+ """
26
+ Inicializa ChatModel.
27
+
28
+ Args:
29
+ pipeline: Pipeline do transformers (deve ter model e tokenizer)
30
+ system_prompt: Prompt do sistema (opcional)
31
+ """
32
+ self.pipeline = pipeline
33
+ self.tokenizer = pipeline.tokenizer
34
+ self.conversation = Conversation(system_prompt=system_prompt)
35
+
36
+ @property
37
+ def messages(self) -> List[Message]:
38
+ """Retorna lista de mensagens do histórico."""
39
+ return self.conversation.messages
40
+
41
+ @property
42
+ def messages_dict(self) -> List[dict]:
43
+ """Retorna mensagens como lista de dicionários (compatível com transformers)."""
44
+ return self.conversation.model_dump_messages()
45
+
46
+ def add_user_message(self, content: str) -> None:
47
+ """Adiciona mensagem do usuário ao histórico."""
48
+ self.conversation.add_user_message(content)
49
+
50
+ def add_assistant_message(self, content: str) -> None:
51
+ """Adiciona mensagem do assistente ao histórico."""
52
+ self.conversation.add_assistant_message(content)
53
+
54
+ def set_system_prompt(self, content: str) -> None:
55
+ """Define ou atualiza o prompt do sistema."""
56
+ self.conversation.set_system_prompt(content)
57
+
58
+ def clear_history(self, keep_system: bool = True) -> None:
59
+ """
60
+ Limpa o histórico de conversa.
61
+
62
+ Args:
63
+ keep_system: Se True, mantém mensagens do sistema
64
+ """
65
+ self.conversation.clear(keep_system=keep_system)
66
+
67
+ def get_formatted_prompt(self, add_generation_prompt: bool = True) -> str:
68
+ """
69
+ Retorna prompt formatado com histórico completo.
70
+
71
+ Args:
72
+ add_generation_prompt: Se True, adiciona prompt de geração
73
+
74
+ Returns:
75
+ String formatada pronta para o modelo
76
+ """
77
+ return _format_chat_prompt(
78
+ self.tokenizer,
79
+ self.conversation.messages,
80
+ add_generation_prompt=add_generation_prompt,
81
+ )
82
+
83
+ def generate_streaming(
84
+ self,
85
+ max_new_tokens: int = 512,
86
+ temperature: Optional[float] = None,
87
+ top_p: Optional[float] = None,
88
+ top_k: Optional[int] = None,
89
+ do_sample: bool = True,
90
+ stop_sequences: Optional[list[str]] = None,
91
+ ) -> Iterator[str]:
92
+ """
93
+ Gera resposta com streaming usando o histórico completo.
94
+
95
+ Args:
96
+ max_new_tokens: Número máximo de tokens a gerar
97
+ temperature: Temperatura para sampling (opcional)
98
+ top_p: Nucleus sampling (opcional)
99
+ top_k: Top-k sampling (opcional)
100
+ do_sample: Se True, usa sampling
101
+ stop_sequences: Lista de sequências para parar
102
+
103
+ Yields:
104
+ Tokens gerados um por vez
105
+ """
106
+ return _generate_streaming(
107
+ pipeline=self.pipeline,
108
+ prompt=self.conversation.messages, # List[Message] funciona com _format_chat_prompt
109
+ max_new_tokens=max_new_tokens,
110
+ temperature=temperature,
111
+ top_p=top_p,
112
+ top_k=top_k,
113
+ do_sample=do_sample,
114
+ stop_sequences=stop_sequences,
115
+ )
116
+
117
+ def generate(
118
+ self,
119
+ max_new_tokens: int = 512,
120
+ temperature: Optional[float] = None,
121
+ top_p: Optional[float] = None,
122
+ top_k: Optional[int] = None,
123
+ do_sample: bool = True,
124
+ ) -> str:
125
+ """
126
+ Gera resposta completa usando o histórico completo.
127
+
128
+ Args:
129
+ max_new_tokens: Número máximo de tokens a gerar
130
+ temperature: Temperatura para sampling (opcional)
131
+ top_p: Nucleus sampling (opcional)
132
+ top_k: Top-k sampling (opcional)
133
+ do_sample: Se True, usa sampling
134
+
135
+ Returns:
136
+ Texto gerado completo
137
+ """
138
+ return _generate_simple(
139
+ pipeline=self.pipeline,
140
+ prompt=self.conversation.messages,
141
+ max_new_tokens=max_new_tokens,
142
+ temperature=temperature,
143
+ top_p=top_p,
144
+ top_k=top_k,
145
+ do_sample=do_sample,
146
+ )
147
+
148
+ def chat(
149
+ self,
150
+ user_message: str,
151
+ max_new_tokens: int = 512,
152
+ temperature: Optional[float] = None,
153
+ streaming: bool = False,
154
+ ) -> Union[str, Iterator[str]]:
155
+ """
156
+ Método conveniente para chat completo (adiciona mensagem + gera + adiciona resposta).
157
+
158
+ Args:
159
+ user_message: Mensagem do usuário
160
+ max_new_tokens: Número máximo de tokens a gerar
161
+ temperature: Temperatura para sampling (opcional)
162
+ streaming: Se True, retorna iterator; se False, retorna string completa
163
+
164
+ Returns:
165
+ Resposta do modelo (string ou iterator)
166
+ """
167
+ # Adiciona mensagem do usuário
168
+ self.add_user_message(user_message)
169
+
170
+ # Gera resposta
171
+ if streaming:
172
+ return self.generate_streaming(
173
+ max_new_tokens=max_new_tokens,
174
+ temperature=temperature,
175
+ )
176
+ else:
177
+ response = self.generate(
178
+ max_new_tokens=max_new_tokens,
179
+ temperature=temperature,
180
+ )
181
+ # Adiciona resposta ao histórico
182
+ self.add_assistant_message(response)
183
+ return response
184
+
185
+ def __repr__(self) -> str:
186
+ """Representação string do modelo."""
187
+ return f"ChatModel({len(self.conversation)} messages)"
188
+
src/backend/inference.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference utilities with streaming support."""
2
+
3
+ from typing import Iterator, Optional, Union, List
4
+ from transformers import Pipeline, TextIteratorStreamer
5
+ from threading import Thread
6
+ from .chat import _format_chat_prompt, Message
7
+
8
+
9
+ def _build_generation_kwargs(
10
+ max_new_tokens: int,
11
+ do_sample: bool,
12
+ temperature: Optional[float] = None,
13
+ top_p: Optional[float] = None,
14
+ top_k: Optional[int] = None,
15
+ **extra_kwargs
16
+ ) -> dict:
17
+ """Constrói dicionário de kwargs para geração, incluindo apenas parâmetros fornecidos."""
18
+ kwargs = {
19
+ "max_new_tokens": max_new_tokens,
20
+ "do_sample": do_sample,
21
+ **extra_kwargs,
22
+ }
23
+
24
+ if temperature is not None:
25
+ kwargs["temperature"] = temperature
26
+ if top_p is not None:
27
+ kwargs["top_p"] = top_p
28
+ if top_k is not None:
29
+ kwargs["top_k"] = top_k
30
+
31
+ return kwargs
32
+
33
+
34
+ def generate_streaming(
35
+ pipeline: Pipeline,
36
+ prompt: Union[str, List[Message]],
37
+ max_new_tokens: int = 512,
38
+ temperature: Optional[float] = None,
39
+ top_p: Optional[float] = None,
40
+ top_k: Optional[int] = None,
41
+ do_sample: bool = True,
42
+ stop_sequences: Optional[list[str]] = None,
43
+ ) -> Iterator[str]:
44
+ """
45
+ Gera texto com streaming usando TextIteratorStreamer.
46
+
47
+ Args:
48
+ pipeline: Pipeline do transformers
49
+ prompt: Texto de entrada (str) ou lista de mensagens (List[Message])
50
+ max_new_tokens: Número máximo de tokens a gerar
51
+ temperature: Temperatura para sampling (opcional, usa padrão do modelo se None)
52
+ top_p: Nucleus sampling (opcional, usa padrão do modelo se None)
53
+ top_k: Top-k sampling (opcional, usa padrão do modelo se None)
54
+ do_sample: Se True, usa sampling; caso contrário, usa greedy decoding
55
+ stop_sequences: Lista de sequências para parar a geração
56
+
57
+ Yields:
58
+ Tokens gerados um por vez
59
+ """
60
+ # Obtém o modelo e tokenizer do pipeline
61
+ model = pipeline.model
62
+ tokenizer = pipeline.tokenizer
63
+
64
+ # Formata prompt se for lista de mensagens
65
+ if isinstance(prompt, list):
66
+ formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True)
67
+ else:
68
+ formatted_prompt = prompt
69
+
70
+ # Tokeniza o prompt
71
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
72
+
73
+ # Cria streamer
74
+ streamer = TextIteratorStreamer(
75
+ tokenizer,
76
+ skip_prompt=True,
77
+ skip_special_tokens=True,
78
+ )
79
+
80
+ # Configurações de geração (usa valores padrão do modelo se não especificados)
81
+ generation_kwargs = _build_generation_kwargs(
82
+ max_new_tokens=max_new_tokens,
83
+ do_sample=do_sample,
84
+ temperature=temperature,
85
+ top_p=top_p,
86
+ top_k=top_k,
87
+ streamer=streamer,
88
+ use_cache=True, # Usa cache de atenção para acelerar
89
+ )
90
+ generation_kwargs.update(inputs)
91
+
92
+ # Thread para geração
93
+ generation_thread = Thread(
94
+ target=model.generate,
95
+ kwargs=generation_kwargs,
96
+ )
97
+ generation_thread.start()
98
+
99
+ # Yield tokens conforme são gerados
100
+ for token in streamer:
101
+ if stop_sequences:
102
+ # Verifica se algum stop_sequence foi encontrado
103
+ for stop_seq in stop_sequences:
104
+ if stop_seq in token:
105
+ generation_thread.join(timeout=1.0)
106
+ return
107
+ yield token
108
+
109
+ generation_thread.join()
110
+
111
+
112
+ def generate_simple(
113
+ pipeline: Pipeline,
114
+ prompt: Union[str, List[Message]],
115
+ max_new_tokens: int = 512,
116
+ temperature: Optional[float] = None,
117
+ top_p: Optional[float] = None,
118
+ top_k: Optional[int] = None,
119
+ do_sample: bool = True,
120
+ num_return_sequences: int = 1,
121
+ ) -> str:
122
+ """
123
+ Gera texto sem streaming (mais simples, útil para testes).
124
+
125
+ Args:
126
+ pipeline: Pipeline do transformers
127
+ prompt: Texto de entrada (str) ou lista de mensagens (List[Message])
128
+ max_new_tokens: Número máximo de tokens a gerar
129
+ temperature: Temperatura para sampling (opcional, usa padrão do modelo se None)
130
+ top_p: Nucleus sampling (opcional, usa padrão do modelo se None)
131
+ top_k: Top-k sampling (opcional, usa padrão do modelo se None)
132
+ do_sample: Se True, usa sampling; caso contrário, usa greedy decoding
133
+ num_return_sequences: Número de sequências a retornar
134
+
135
+ Returns:
136
+ Texto gerado
137
+ """
138
+ # Formata prompt se for lista de mensagens
139
+ tokenizer = pipeline.tokenizer
140
+ if isinstance(prompt, list):
141
+ formatted_prompt = _format_chat_prompt(tokenizer, prompt, add_generation_prompt=True)
142
+ else:
143
+ formatted_prompt = prompt
144
+
145
+ # Prepara parâmetros do pipeline (usa valores padrão do modelo se não especificados)
146
+ pipeline_kwargs = _build_generation_kwargs(
147
+ max_new_tokens=max_new_tokens,
148
+ do_sample=do_sample,
149
+ temperature=temperature,
150
+ top_p=top_p,
151
+ top_k=top_k,
152
+ num_return_sequences=num_return_sequences,
153
+ return_full_text=False,
154
+ )
155
+
156
+ outputs = pipeline(formatted_prompt, **pipeline_kwargs)
157
+
158
+ if num_return_sequences == 1:
159
+ return outputs[0]["generated_text"]
160
+ else:
161
+ return [output["generated_text"] for output in outputs]
162
+
src/backend/model_loader.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model loading utilities with Streamlit caching."""
2
+
3
+ import os
4
+ import streamlit as st
5
+ from pathlib import Path
6
+ from typing import Optional, Dict, Any, Tuple
7
+ from transformers import (
8
+ AutoModelForCausalLM,
9
+ AutoTokenizer,
10
+ pipeline,
11
+ Pipeline,
12
+ )
13
+ import torch
14
+
15
+ # Obtém token do Hugging Face (disponível automaticamente no Spaces)
16
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
17
+
18
+ # Define o diretório de cache dentro do projeto
19
+ PROJECT_ROOT = Path(__file__).parent.parent.parent
20
+ MODELS_CACHE_DIR = PROJECT_ROOT / "models"
21
+ MODELS_CACHE_DIR.mkdir(exist_ok=True)
22
+
23
+
24
+ @st.cache_resource
25
+ def load_model(
26
+ model_name: str,
27
+ device_map: Optional[str] = "auto",
28
+ torch_dtype: Optional[torch.dtype] = None,
29
+ load_in_8bit: bool = False,
30
+ load_in_4bit: bool = False,
31
+ ) -> Tuple[Pipeline, Dict[str, Any]]:
32
+ """
33
+ Carrega um modelo do Hugging Face com cache do Streamlit.
34
+
35
+ Args:
36
+ model_name: Nome do modelo no Hugging Face (ex: 'microsoft/DialoGPT-medium')
37
+ device_map: Mapeamento de dispositivo ('auto', 'cpu', 'cuda', etc.)
38
+ torch_dtype: Tipo de dados do torch (ex: torch.float16)
39
+ load_in_8bit: Se True, carrega modelo quantizado em 8-bit
40
+ load_in_4bit: Se True, carrega modelo quantizado em 4-bit
41
+
42
+ Returns:
43
+ Tupla contendo (pipeline, model_info)
44
+ """
45
+ try:
46
+ # Detecta dispositivo disponível
47
+ has_cuda = torch.cuda.is_available()
48
+
49
+ # Determina o dtype padrão
50
+ if torch_dtype is None:
51
+ if has_cuda:
52
+ torch_dtype = torch.float16
53
+ else:
54
+ torch_dtype = torch.float32
55
+
56
+ # Ajusta device_map: se não há GPU ou device_map é "auto" sem GPU, usa None
57
+ if device_map == "auto" and not has_cuda:
58
+ device_map = None
59
+ elif device_map == "auto" and has_cuda:
60
+ device_map = "auto"
61
+
62
+ # Configurações de quantização
63
+ model_kwargs = {
64
+ "torch_dtype": torch_dtype,
65
+ }
66
+
67
+ # Só adiciona device_map se não for None
68
+ if device_map is not None:
69
+ model_kwargs["device_map"] = device_map
70
+
71
+ if load_in_8bit or load_in_4bit:
72
+ try:
73
+ from transformers import BitsAndBytesConfig
74
+
75
+ quantization_config = BitsAndBytesConfig(
76
+ load_in_8bit=load_in_8bit,
77
+ load_in_4bit=load_in_4bit,
78
+ )
79
+ model_kwargs["quantization_config"] = quantization_config
80
+ except ImportError:
81
+ st.warning("bitsandbytes não está instalado. Quantização desabilitada.")
82
+
83
+ # Carrega tokenizer e modelo usando cache do projeto
84
+ cache_dir = str(MODELS_CACHE_DIR)
85
+
86
+ # Prepara kwargs com token de autenticação se disponível
87
+ hf_kwargs = {"cache_dir": cache_dir}
88
+ if HF_TOKEN:
89
+ hf_kwargs["token"] = HF_TOKEN
90
+
91
+ tokenizer = AutoTokenizer.from_pretrained(
92
+ model_name,
93
+ **hf_kwargs
94
+ )
95
+
96
+ # Adiciona pad_token se não existir
97
+ if tokenizer.pad_token is None:
98
+ tokenizer.pad_token = tokenizer.eos_token
99
+
100
+ model = AutoModelForCausalLM.from_pretrained(
101
+ model_name,
102
+ **hf_kwargs,
103
+ **model_kwargs
104
+ )
105
+
106
+ # Move modelo para CPU se não há GPU e device_map não foi usado
107
+ if device_map is None and not has_cuda:
108
+ model = model.to("cpu")
109
+
110
+ # Cria pipeline
111
+ pipeline_kwargs = {
112
+ "model": model,
113
+ "tokenizer": tokenizer,
114
+ }
115
+
116
+ # Só adiciona device ao pipeline se não usar device_map no modelo
117
+ if device_map is None:
118
+ pipeline_kwargs["device"] = 0 if has_cuda else -1
119
+ else:
120
+ pipeline_kwargs["device_map"] = device_map
121
+
122
+ pipe = pipeline("text-generation", **pipeline_kwargs)
123
+
124
+ # Informações do modelo
125
+ model_info = {
126
+ "model_name": model_name,
127
+ "device": str(next(model.parameters()).device),
128
+ "dtype": str(torch_dtype),
129
+ "quantized": load_in_8bit or load_in_4bit,
130
+ "cache_dir": cache_dir,
131
+ }
132
+
133
+ return pipe, model_info
134
+
135
+ except Exception as e:
136
+ st.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
137
+ raise
138
+
src/config.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configurações do projeto."""
2
+
3
+ # Lista de modelos pré-selecionados para teste
4
+ PRESELECTED_MODELS = [
5
+ "Qwen/Qwen3-0.6B", # Modelo padrão
6
+ "google/gemma-3-4b-it",
7
+ "google/gemma-3-1b-it",
8
+ "google/gemma-3-270m-it",
9
+ "Qwen/Qwen2.5-0.5B-Instruct",
10
+ "Qwen/Qwen2.5-0.5B",
11
+ "facebook/MobileLLM-R1-950M",
12
+ ]
13
+
14
+ # Modelos que requerem autenticação (gated)
15
+ GATED_MODELS = {
16
+ "google/gemma-3-4b-it",
17
+ "google/gemma-3-1b-it",
18
+ "google/gemma-3-270m-it",
19
+ }
20
+
21
+ # Informações sobre os modelos (para exibição)
22
+ MODEL_INFO = {
23
+ "google/gemma-3-4b-it": {
24
+ "name": "Gemma 3 4B IT",
25
+ "params": "4 bilhões",
26
+ "family": "Google Gemma",
27
+ },
28
+ "google/gemma-3-1b-it": {
29
+ "name": "Gemma 3 1B IT",
30
+ "params": "1 bilhão",
31
+ "family": "Google Gemma",
32
+ },
33
+ "google/gemma-3-270m-it": {
34
+ "name": "Gemma 3 270M IT",
35
+ "params": "270 milhões",
36
+ "family": "Google Gemma",
37
+ },
38
+ "Qwen/Qwen3-0.6B": {
39
+ "name": "Qwen3 0.6B",
40
+ "params": "600 milhões",
41
+ "family": "Qwen",
42
+ },
43
+ "Qwen/Qwen2.5-0.5B-Instruct": {
44
+ "name": "Qwen2.5 0.5B Instruct",
45
+ "params": "500 milhões",
46
+ "family": "Qwen",
47
+ },
48
+ "Qwen/Qwen2.5-0.5B": {
49
+ "name": "Qwen2.5 0.5B",
50
+ "params": "500 milhões",
51
+ "family": "Qwen",
52
+ },
53
+ "facebook/MobileLLM-R1-950M": {
54
+ "name": "MobileLLM R1 950M",
55
+ "params": "950 milhões",
56
+ "family": "Facebook",
57
+ },
58
+ }
59
+
60
+
61
+ def get_model_label(model_id: str) -> str:
62
+ """Retorna label amigável para um modelo."""
63
+ if model_id in MODEL_INFO:
64
+ info = MODEL_INFO[model_id]
65
+ return f"{info['name']} ({info['params']})"
66
+ return model_id
67
+
68
+
69
+ def get_model_options() -> list[tuple[str, str]]:
70
+ """Retorna lista de tuplas (label, model_id) para uso em selectbox."""
71
+ return [(get_model_label(model_id), model_id) for model_id in PRESELECTED_MODELS]
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))