File size: 6,434 Bytes
91c131d
 
 
 
 
 
05c318d
23aa078
88881cd
a5721dd
14fffd7
c4935d0
88881cd
d911a66
 
 
566e08a
d911a66
 
88881cd
 
 
 
 
14fffd7
 
 
e8e5036
88881cd
509c44e
 
14fffd7
 
 
509c44e
14fffd7
509c44e
3821b68
88881cd
66a875b
91c131d
 
 
 
 
02c3825
 
 
 
 
 
 
 
91c131d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02c3825
91c131d
 
 
 
a5721dd
14fffd7
 
 
 
a5721dd
14fffd7
 
 
a5721dd
14fffd7
 
 
 
 
 
 
 
 
d911a66
14fffd7
66a875b
 
14fffd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d911a66
91c131d
 
 
 
14fffd7
 
 
 
 
 
 
093af2b
14fffd7
566e08a
14fffd7
566e08a
 
 
 
504bdc7
566e08a
 
 
 
 
91c131d
991811d
566e08a
e28ba46
566e08a
 
e28ba46
91c131d
e28ba46
 
 
91c131d
e28ba46
566e08a
91c131d
566e08a
e28ba46
 
 
 
 
 
14fffd7
e28ba46
 
 
91c131d
e28ba46
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Interface de chat com modelo de linguagem."""

import streamlit as st
import base64
from pathlib import Path
from backend import load_model, ChatModel
from config import get_model_options
from send_to_hub import send_message_buckt
import uuid
import logging

# Root-logger setup: timestamped INFO-level messages for the whole app.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Page configuration must run before any other Streamlit UI call.
st.set_page_config(
    page_title="Small LLM - Chat",
    layout="wide",
    initial_sidebar_state="collapsed",
)

def set_new_message_uuid():
    """Mint a fresh UUID for the current conversation and store it in session state."""
    fresh_id = uuid.uuid4()
    st.session_state.message_uuid = fresh_id
    logging.info(f'[NEW MESSAGE UUID] {fresh_id}')

def save_messages():
    """Persist the current conversation (plus model metadata) to the hub bucket.

    No-op when no model has been loaded yet or the conversation is empty.
    """
    state = st.session_state
    if "chat_model" not in state:
        return
    history = state.chat_model.conversation.messages
    if not history:
        return

    payload = [{'role': msg.role, 'content': msg.content} for msg in history]

    # Metadata: stringified model_info entries (when present) plus the model name.
    if "model_info" in state:
        metadata = {str(key): str(value) for key, value in state.model_info.items()}
    else:
        metadata = {}
    metadata['model_name'] = str(state.get('model_name', 'unknown'))

    send_message_buckt(payload, state.message_uuid, metadata, REPO_ID="CEIA-POSITIVO2/smol_chat")


# Logo path, resolved relative to the project root (one level above this file).
PROJECT_ROOT = Path(__file__).parent.parent
LOGO_PATH = PROJECT_ROOT / "positivo-logo.png"

# Build an inline base64 <img> tag for the header; any read failure simply
# drops the logo rather than breaking the page.
logo_html = ""
if LOGO_PATH.exists():
    try:
        img_base64 = base64.b64encode(LOGO_PATH.read_bytes()).decode()
        logo_html = f'<img src="data:image/png;base64,{img_base64}" />'
    except Exception:
        logo_html = ""

# Header rendered via HTML/CSS so the logo and title align on one row.
st.markdown(f"""
<style>
.logo-header {{
    display: flex;
    align-items: center;
    gap: 20px;
    margin-bottom: 0.5rem;
}}
.logo-header img {{
    width: 90px;
    height: 90px;
    object-fit: contain;
    flex-shrink: 0;
}}
</style>
<div class="logo-header">
    {logo_html}
    <h1 style="margin: 0; padding: 0; display: inline-block;">Small LLM - Chat</h1>
</div>
""", unsafe_allow_html=True)

def handle_model_load_error(model_name: str, error_msg: str):
    """Show a user-facing error for a failed model load.

    NOTE(review): a GATED_MODELS check was once considered here; adding it
    would require GATED_MODELS to be defined in config or globally first.
    """
    st.error(f"❌ Erro ao carregar modelo {model_name}: {error_msg}")

# ============================================================================
# AUTOMATIC MODEL LOADING (replaces the sidebar model picker)
# ============================================================================

if "chat_model" not in st.session_state:
    with st.spinner("Inicializando o sistema e carregando o modelo..."):
        try:
            # 1. Fetch the configured model options: list of (model_id, config).
            model_options = get_model_options()

            if not model_options:
                st.error("Nenhum modelo encontrado nas configurações.")
                st.stop()

            # 2. Automatically select the FIRST model in the list.
            #    BUGFIX: this previously indexed model_options[2], which
            #    contradicted the stated "first model" intent and raised
            #    IndexError whenever fewer than three models were configured.
            selected_label, selected_model_config = model_options[0]
            logging.info(f'[MODEL] Loading {selected_label} {selected_model_config}')

            # Prepare the configuration: extract the id and force bfloat16.
            model_name = selected_model_config.pop('model_id')
            selected_model_config['dtype'] = 'bfloat16'

            print(f'Carregando modelo automático: {model_name}')

            # 3. Load the model pipeline.
            pipeline, model_info = load_model(
                model_name,
                **selected_model_config
            )

            # 4. Store everything in the session for later reruns.
            chat_model = ChatModel(pipeline)
            st.session_state.chat_model = chat_model
            st.session_state.model_info = selected_model_config
            st.session_state.model_name = model_name
            set_new_message_uuid()

            # Rerun so the UI refreshes without the spinner.
            st.rerun()

        except Exception as e:
            # model_name may not exist yet if the failure happened early.
            handle_model_load_error(model_name if 'model_name' in locals() else "Unknown", str(e))
            st.stop()

# ============================================================================
# CHAT INTERFACE
# ============================================================================

chat_model = st.session_state.chat_model

if "messages" not in st.session_state:
    st.session_state.messages = []

# Mirror the chat_model conversation into session_state whenever they diverge.
if len(chat_model.conversation.messages) != len(st.session_state.messages):
    st.session_state.messages = [
        {"role": msg.role, "content": msg.content}
        for msg in chat_model.conversation.messages
    ]

chat_placeholder = st.empty()

# Render the visible history; system messages are never shown to the user.
with chat_placeholder.container():
    for entry in st.session_state.messages:
        if entry["role"] == "system":
            continue
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

if user_input := st.chat_input("Digite sua mensagem..."):
    # Every submission starts a fresh conversation: wipe the history,
    # reset the rendered transcript, and mint a new conversation id.
    st.session_state.messages = []
    chat_model.conversation.clear()
    chat_placeholder.empty()
    set_new_message_uuid()

    chat_model.add_user_message(user_input)
    save_messages()
    st.session_state.messages.append({"role": "user", "content": user_input})

    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        response_slot = st.empty()
        assistant_text = ""
        try:
            # Stream tokens, repainting the placeholder as each one arrives.
            for piece in chat_model.generate_streaming(max_new_tokens=4096):
                assistant_text += piece
                response_slot.markdown(assistant_text)

            chat_model.add_assistant_message(assistant_text)
            st.session_state.messages.append({"role": "assistant", "content": assistant_text})
            save_messages()

        except Exception as e:
            failure_text = f"Erro na geração: {str(e)}"
            st.error(failure_text)
            st.session_state.messages.append({"role": "assistant", "content": failure_text})