import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # BitsAndBytesConfig is needed for the 8-bit quantization setup below
import torch
from huggingface_hub import login
import os
##################################################################
def setup_llama3_auth():
    """Configure Hugging Face authentication for Llama 3."""
    if 'HUGGING_FACE_TOKEN_3' in st.secrets:
        token = st.secrets['HUGGING_FACE_TOKEN_3']
        login(token)
        return True
    else:
        st.error("Llama 3 token not found in Streamlit secrets")
        st.stop()  # halts the script here; the return below is never reached
        return False
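# The token is read from Streamlit secrets. When running locally, that means a
# .streamlit/secrets.toml file next to the app (the key name matches the code
# above; the value shown is a placeholder):
#
#   HUGGING_FACE_TOKEN_3 = "hf_..."
#
# On Streamlit Cloud or Hugging Face Spaces, set the same key in the app's
# secrets settings instead.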
class Llama3Demo:
    def __init__(self):
        setup_llama3_auth()
        self.model_name = "meta-llama/Llama-3.2-3B-Instruct"
        # Model and tokenizer are loaded lazily on first access (see the
        # properties below) so the app starts quickly.
        self._model = None
        self._tokenizer = None
        # 8-bit quantization via bitsandbytes. The bnb_4bit_* options only
        # apply to 4-bit loading, so none are set here.
        self.quantization_config = BitsAndBytesConfig(load_in_8bit=True)
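    # Rough memory math: a 3B-parameter model needs about 3 GB of GPU memory
    # for weights in 8-bit (1 byte/param) versus about 6 GB in float16
    # (2 bytes/param), plus activation and KV-cache overhead during generation.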
    @property
    def model(self):
        # @property is required: the rest of the code accesses self.model as
        # an attribute (e.g. self.model.device), not as a method call.
        if self._model is None:
            try:
                self._model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    quantization_config=self.quantization_config,
                    token=st.secrets['HUGGING_FACE_TOKEN_3'],  # 'token' replaces the deprecated 'use_auth_token'
                )
            except Exception as e:
                st.error(f"Error loading the model: {str(e)}")
                raise
        return self._model
    @property
    def tokenizer(self):
        if self._tokenizer is None:
            try:
                self._tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    token=st.secrets['HUGGING_FACE_TOKEN_3'],
                )
            except Exception as e:
                st.error(f"Error loading the tokenizer: {str(e)}")
                raise
        return self._tokenizer
##################################################################
    def generate_response(self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.6,
                          top_p: float = 0.85, repetition_penalty: float = 1.2, top_k: int = 50) -> str:
        # Build the prompt with the tokenizer's chat template so the special
        # tokens match what Llama 3.2 Instruct was trained on (the previous
        # hand-written <|system|>/</s> format belongs to other chat models).
        messages = [
            {"role": "system", "content": (
                "You are a helpful AI assistant. Always provide accurate, "
                "detailed, and well-reasoned responses. If you're unsure about "
                "something, acknowledge the uncertainty. Break down complex "
                "topics into clear explanations."
            )},
            {"role": "user", "content": prompt},
        ]
        inputs = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=temperature,
                do_sample=True,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        torch.cuda.empty_cache()
        # Decode only the newly generated tokens, skipping the prompt.
        response = self.tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
        return response.strip()
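    # --- Optional streaming variant (not in the original app): a minimal
    # sketch using transformers' TextIteratorStreamer. generate() runs in a
    # background thread while this generator yields decoded text chunks,
    # which st.write_stream() can consume directly.
    def generate_response_stream(self, prompt: str, **generate_kwargs):
        from threading import Thread
        from transformers import TextIteratorStreamer

        messages = [{"role": "user", "content": prompt}]
        inputs = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        Thread(
            target=self.model.generate,
            kwargs=dict(inputs=inputs, streamer=streamer, **generate_kwargs),
        ).start()
        yield from streamer  # yields text chunks as they are produced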
##################################################################
def main():
    st.set_page_config(page_title="Llama 3.2 Chat", page_icon="🦙")
    st.title("🦙 Llama 3.2 Chat")
    # Check configuration
    with st.expander("🔧 Status", expanded=True):
        try:
            token_status = setup_llama3_auth()
            st.write("Llama 3 token:", "✅" if token_status else "❌")
            if torch.cuda.is_available():
                st.write("GPU:", torch.cuda.get_device_name(0))
                st.write("GPU memory:", f"{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
            else:
                st.warning("No GPU available")
        except Exception as e:
            st.error(f"Configuration error: {str(e)}")
    # Sidebar with generation controls
    with st.sidebar:
        st.markdown("### Generation Parameters")
        generation_params = {
            'temperature': st.slider(
                "Temperature (creativity vs. precision)",
                min_value=0.1,
                max_value=1.0,
                value=0.6,
                step=0.1,
                help="Lower values = more precise answers"
            ),
            'max_new_tokens': st.slider(
                "Maximum length",
                min_value=64,
                max_value=1024,
                value=512,
                step=64,
                help="Maximum length of the response"
            ),
            'top_p': st.slider(
                "Top-p (nucleus sampling)",
                min_value=0.1,
                max_value=1.0,
                value=0.85,
                step=0.05
            )
        }
        with st.expander("Advanced Parameters"):
            generation_params.update({
                'repetition_penalty': st.slider(
                    "Repetition penalty",
                    min_value=1.0,
                    max_value=2.0,
                    value=1.2,
                    step=0.1
                ),
                'top_k': st.slider(
                    "Top-k tokens",
                    min_value=1,
                    max_value=100,
                    value=50,
                    step=1
                )
            })
        st.markdown("""
        ### Parameter Guide
        - **Temperature**: lower = more precise, higher = more creative
        - **Top-p**: controls the variability of responses
        - **Length**: adjust to the level of detail needed
        """)
        if st.button("Clear Chat"):
            st.session_state.messages = []
            st.rerun()  # st.experimental_rerun() was removed in recent Streamlit versions
    # Initialize the model once per session
    if 'llama' not in st.session_state:
        with st.spinner("Initializing Llama 3.2... this may take a few minutes..."):
            try:
                st.session_state.llama = Llama3Demo()
            except Exception as e:
                st.error(f"Error initializing the model: {str(e)}")
                st.stop()
    # Chat history management
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    # Display history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    # Chat interface
    if prompt := st.chat_input("Type your message here"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            try:
                response = st.session_state.llama.generate_response(prompt, **generation_params)
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
if __name__ == "__main__":
    main()
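# To run locally (assuming this file is saved as app.py):
#   pip install streamlit transformers torch bitsandbytes accelerate
#   streamlit run app.py
# Note: accelerate is needed for device_map="auto", and 8-bit loading via
# bitsandbytes generally requires a CUDA GPU.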