Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from streamlit import session_state as ss | |
| import os | |
| import login # Importa il file login.py che hai creato | |
| # Stile del chat input e sfondo della pagina | |
| st.markdown(""" | |
| <style> | |
| section[data-testid="stTextInput"] input { | |
| color: black !important; | |
| background-color: #F0F2F6 !important; | |
| font-size: 16px; | |
| border-radius: 10px; | |
| padding: 10px; | |
| } | |
| .main { | |
| background-color: #0A0A1A; | |
| color: #FFFFFF; | |
| } | |
| .stChatMessage div[data-baseweb="block"] { | |
| background-color: rgba(255, 255, 255, 0.1) !important; | |
| color: #FFFFFF !important; | |
| border-radius: 10px !important; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # Inizializza lo stato di login se non esiste | |
| if "is_logged_in" not in st.session_state: | |
| st.session_state["is_logged_in"] = False | |
| # Mostra la pagina di login solo se l'utente non è loggato | |
| if not st.session_state["is_logged_in"]: | |
| login.login_page() | |
| st.stop() | |
| # Recupera le secrets da Hugging Face | |
| model_repo = st.secrets["MODEL_REPO"] | |
| hf_token = st.secrets["HF_TOKEN"] | |
| def load_model(): | |
| tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=hf_token) | |
| model = AutoModelForCausalLM.from_pretrained(model_repo, use_auth_token=hf_token) | |
| model.config.use_cache = True | |
| return tokenizer, model | |
| # Funzione per generare una risposta in tempo reale con supporto per l'interruzione | |
| def generate_llama_response_stream(user_input, tokenizer, model, max_length=512): | |
| eos_token = tokenizer.eos_token if tokenizer.eos_token else "" | |
| input_ids = tokenizer.encode(user_input + eos_token, return_tensors="pt") | |
| response_text = "" | |
| response_placeholder = st.empty() | |
| # Genera un token alla volta e aggiorna il testo in response_text | |
| for i in range(max_length): | |
| if ss.get("stop_generation", False): | |
| break # Interrompe il ciclo se l'utente ha premuto "stop" | |
| output = model.generate(input_ids, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id, use_cache=True) | |
| new_token_id = output[:, -1].item() | |
| new_token = tokenizer.decode([new_token_id], skip_special_tokens=True) | |
| response_text += new_token | |
| response_placeholder.markdown(f"RacoGPT: {response_text}", unsafe_allow_html=True) | |
| # Salva il testo parziale in session_state per preservarlo in caso di interruzione | |
| ss["response_text_partial"] = response_text | |
| # Aggiungi il nuovo token alla sequenza di input | |
| input_ids = torch.cat([input_ids, output[:, -1:]], dim=-1) | |
| # Interrompe se il token generato è <|endoftext|> o eos_token_id | |
| if new_token_id == tokenizer.eos_token_id: | |
| break | |
| # Reimposta lo stato di "stop" | |
| ss.stop_generation = False | |
| return response_text | |
| # Inizializza lo stato della sessione | |
| if 'is_chat_input_disabled' not in ss: | |
| ss.is_chat_input_disabled = False | |
| if 'msg' not in ss: | |
| ss.msg = [] | |
| if 'chat_history' not in ss: | |
| ss.chat_history = None | |
| if 'stop_generation' not in ss: | |
| ss.stop_generation = False | |
| # Carica il modello e tokenizer | |
| tokenizer, model = load_model() | |
| # Mostra la cronologia dei messaggi con le label personalizzate | |
| for message in ss.msg: | |
| if message["role"] == "user": | |
| with st.chat_message("user"): | |
| st.markdown(f"Tu: {message['content']}") | |
| elif message["role"] == "RacoGPT": | |
| with st.chat_message("RacoGPT"): | |
| st.markdown(f"RacoGPT: {message['content']}") | |
| # Contenitore per gestire la mutua esclusione tra input e pulsante di stop | |
| input_container = st.empty() | |
| if not ss.is_chat_input_disabled: | |
| # Mostra la barra di input per inviare il messaggio | |
| with input_container: | |
| prompt = st.chat_input("Scrivi il tuo messaggio...") | |
| if prompt: | |
| ss.msg.append({"role": "user", "content": prompt}) | |
| with st.chat_message("user"): | |
| ss.is_chat_input_disabled = True | |
| st.markdown(f"Tu: {prompt}") | |
| st.rerun() | |
| else: | |
| # Mostra il pulsante di "Stop Generazione" al posto della barra di input | |
| with input_container: | |
| if st.button("🛑 Stop Generazione", key="stop_button"): | |
| ss.stop_generation = True # Interrompe la generazione impostando il flag | |
| # Genera la risposta del bot con digitazione in tempo reale | |
| with st.spinner("RacoGPT sta generando una risposta..."): | |
| response = generate_llama_response_stream(ss.msg[-1]['content'], tokenizer, model) | |
| # Usa il testo parziale se presente | |
| final_response = response or ss.get("response_text_partial", "") | |
| # Aggiungi la risposta finale nella cronologia dei messaggi | |
| ss.msg.append({"role": "RacoGPT", "content": final_response}) | |
| with st.chat_message("RacoGPT"): | |
| st.markdown(f"RacoGPT: {final_response}") | |
| # Pulisce il testo parziale dalla sessione e riabilita l'input | |
| ss.pop("response_text_partial", None) | |
| ss.is_chat_input_disabled = False | |
| # Rerun per aggiornare l'interfaccia | |
| st.rerun() |