# NOTE: removed scrape artifact ("Spaces: Sleeping" Hugging Face status banner) — not part of the source file.
| """Interface de chat com modelo de linguagem.""" | |
| import streamlit as st | |
| import base64 | |
| from pathlib import Path | |
| from backend import load_model, ChatModel | |
| from config import get_model_options | |
| from send_to_hub import send_message_buckt | |
| import uuid | |
| import logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
# Streamlit page setup: wide layout, custom browser-tab title, sidebar hidden on load.
st.set_page_config(
    page_title="Small LLM - Chat",
    layout="wide",
    initial_sidebar_state="collapsed",
)
def set_new_message_uuid():
    """Create a fresh UUID for the current exchange and store it in session state."""
    fresh_uuid = uuid.uuid4()
    st.session_state.message_uuid = fresh_uuid
    logging.info(f'[NEW MESSAGE UUID] {fresh_uuid}')
def save_messages():
    """Upload the current conversation to the Hugging Face Hub bucket.

    No-op when no model is loaded yet or the conversation is still empty.
    Model metadata (stringified) plus the model name is attached to each upload.
    """
    chat = st.session_state.get("chat_model")
    if chat is None or not chat.conversation.messages:
        return
    history = [{'role': msg.role, 'content': msg.content} for msg in chat.conversation.messages]
    # Flatten model settings to plain strings for serialization.
    metadata = {}
    if "model_info" in st.session_state:
        metadata = {str(key): str(val) for key, val in st.session_state.model_info.items()}
    metadata['model_name'] = str(st.session_state.get('model_name', 'unknown'))
    send_message_buckt(history, st.session_state.message_uuid, metadata, REPO_ID="CEIA-POSITIVO2/smol_chat")
# Logo path, relative to the project root (one directory above this file).
PROJECT_ROOT = Path(__file__).parent.parent
LOGO_PATH = PROJECT_ROOT / "positivo-logo.png"

# Header logo as an inline base64 <img> tag; falls back to an empty string
# when the image is missing or unreadable, so the header still renders.
logo_html = ""
if LOGO_PATH.exists():
    try:
        encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode()
        logo_html = f'<img src="data:image/png;base64,{encoded}" />'
    except Exception:
        logo_html = ""
# Render the page header (logo, when available, next to the title) with
# inline CSS for layout control; unsafe_allow_html is required for raw HTML.
st.markdown(f"""
<style>
    .logo-header {{
        display: flex;
        align-items: center;
        gap: 20px;
        margin-bottom: 0.5rem;
    }}
    .logo-header img {{
        width: 90px;
        height: 90px;
        object-fit: contain;
        flex-shrink: 0;
    }}
</style>
<div class="logo-header">
    {logo_html}
    <h1 style="margin: 0; padding: 0; display: inline-block;">Small LLM - Chat</h1>
</div>
""", unsafe_allow_html=True)
def handle_model_load_error(model_name: str, error_msg: str):
    """Surface a model-loading failure to the user via a Streamlit error box.

    NOTE(review): a previous revision apparently special-cased gated models
    (GATED_MODELS, defined in config); that check is not present here — confirm
    whether it is still needed.
    """
    st.error(f"❌ Erro ao carregar modelo {model_name}: {error_msg}")
# ============================================================================
# AUTOMATIC MODEL LOADING (replaces the sidebar model picker)
# ============================================================================
if "chat_model" not in st.session_state:
    with st.spinner("Inicializando o sistema e carregando o modelo..."):
        # Pre-initialize so the error handler always has a name to report,
        # instead of the fragile `'model_name' in locals()` probe.
        model_name = "Unknown"
        try:
            # 1. Fetch configured model options: sequence of (label, config) pairs.
            model_options = get_model_options()
            if not model_options:
                st.error("Nenhum modelo encontrado nas configurações.")
                st.stop()
            # 2. Automatically select the FIRST model in the list.
            #    Bug fix: this previously indexed model_options[2], contradicting
            #    the stated intent and raising IndexError with fewer than 3 models.
            selected_label, selected_model_config = model_options[0]
            logging.info(f'[MODEL] Loading {selected_label} {selected_model_config}')
            # Extract the model id and force bfloat16 precision.
            model_name = selected_model_config.pop('model_id')
            selected_model_config['dtype'] = 'bfloat16'
            logging.info(f'Carregando modelo automático: {model_name}')
            # 3. Load the model pipeline with the remaining config as kwargs.
            pipeline, model_info = load_model(
                model_name,
                **selected_model_config
            )
            # 4. Persist everything in the session for subsequent reruns.
            st.session_state.chat_model = ChatModel(pipeline)
            st.session_state.model_info = selected_model_config
            st.session_state.model_name = model_name
            set_new_message_uuid()
            # Rerun so the UI refreshes without the loading spinner.
            st.rerun()
        except Exception as e:
            handle_model_load_error(model_name, str(e))
            st.stop()
# ============================================================================
# CHAT INTERFACE
# ============================================================================
chat_model = st.session_state.chat_model
st.session_state.setdefault("messages", [])
# Re-sync the display list whenever it drifts from the model-side conversation.
if len(st.session_state.messages) != len(chat_model.conversation.messages):
    st.session_state.messages = [
        {"role": m.role, "content": m.content}
        for m in chat_model.conversation.messages
    ]
chat_placeholder = st.empty()
with chat_placeholder.container():
    # Replay the visible history; system-role messages stay hidden.
    for entry in st.session_state.messages:
        if entry["role"] == "system":
            continue
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])
if user_input := st.chat_input("Digite sua mensagem..."):
    # Every submission starts a fresh single-turn conversation: wipe all history.
    st.session_state.messages = []
    chat_model.conversation.clear()
    chat_placeholder.empty()
    set_new_message_uuid()

    # Record and echo the user's message (persisted immediately).
    chat_model.add_user_message(user_input)
    save_messages()
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        stream_slot = st.empty()
        answer = ""
        try:
            # Stream tokens, re-rendering the partial answer as each arrives.
            for chunk in chat_model.generate_streaming(max_new_tokens=4096):
                answer += chunk
                stream_slot.markdown(answer)
            chat_model.add_assistant_message(answer)
            st.session_state.messages.append({"role": "assistant", "content": answer})
            save_messages()
        except Exception as e:
            # Generation failure: show the error and keep it in the transcript.
            error_msg = f"Erro na geração: {str(e)}"
            st.error(error_msg)
            st.session_state.messages.append({"role": "assistant", "content": error_msg})