import streamlit as st
import torch
import tiktoken
import sys
import os
import logging
import warnings

# Configure logging and warnings
logging.getLogger('streamlit').setLevel(logging.ERROR)
warnings.filterwarnings('ignore', message='.*torch.classes.*')
warnings.filterwarnings('ignore', category=FutureWarning)

# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.config.model_config import GPTConfig
from src.models.gpt import LlamaForCausalLM
from src.utils.device_utils import get_device


@st.cache_resource
def load_model():
    """
    Load and prepare the model for inference.
    Returns the loaded model and device.
    """
    device = get_device()
    try:
        # Load the checkpoint dictionary
        checkpoint = torch.load('model.pt', map_location=device)

        # Initialize the model with its config
        config = GPTConfig()
        model = LlamaForCausalLM(config)

        # Extract model_state_dict if the checkpoint wraps it; otherwise
        # assume the checkpoint is the state dict itself
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

        # Remove cached rotary embedding buffers; they are recomputed at runtime
        state_dict.pop("model.rotary_emb.cos_cached", None)
        state_dict.pop("model.rotary_emb.sin_cached", None)

        model.load_state_dict(state_dict, strict=True)

        # Prepare the model for inference
        model = model.float()
        model.to(device)
        model.eval()

        return model, device
    except Exception as e:
        st.error(f"Detailed error during model loading: {str(e)}")
        raise


def generate_text(model, prompt, max_length=100, num_return_sequences=1, device='cpu'):
    """
    Generate text based on the input prompt.

    Args:
        model: The loaded GPT model
        prompt: Input text prompt
        max_length: Number of additional tokens to generate beyond the prompt
        num_return_sequences: Number of different sequences to generate
        device: Device to run inference on

    Returns:
        List of generated text sequences
    """
    tokenizer = tiktoken.get_encoding('gpt2')
    input_tokens = tokenizer.encode(prompt)
    x = torch.tensor(input_tokens).unsqueeze(0).repeat(num_return_sequences, 1)
    x = x.to(device)

    # Calculate the final length (input length + requested additional tokens)
    input_length = x.size(1)
    target_length = input_length + max_length

    # Generate tokens autoregressively; the model is re-run on the full
    # sequence at every step (no KV cache)
    with torch.no_grad():
        while x.size(1) < target_length:
            # Get the next-token logits from the last position
            logits, _ = model(x)
            next_token_logits = logits[:, -1, :]

            # Apply a temperature of 0.8 to sharpen the distribution
            probs = torch.softmax(next_token_logits / 0.8, dim=-1)

            # Sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)

            # Append the sampled token to each sequence
            x = torch.cat((x, next_token), dim=1)

    # Report token counts
    st.text(f"Input tokens: {input_length}, additional tokens requested: {max_length}, total tokens generated: {x.size(1)}")

    # Decode the generated sequences
    generated_texts = []
    for i in range(num_return_sequences):
        tokens = x[i].tolist()
        text = tokenizer.decode(tokens)
        generated_texts.append(text)

    return generated_texts
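
# A minimal sketch of an alternative decoding strategy, for reference only:
# top-k sampling restricts the draw to the k most likely tokens before
# renormalizing. This helper is not wired into the app; its name and the
# default k are assumptions for illustration, not part of the original code.
def sample_top_k(next_token_logits, k=50, temperature=0.8):
    """Sample one token id per batch row from the top-k logits."""
    # Keep only the k largest logits per row
    topk_logits, topk_indices = torch.topk(next_token_logits, k, dim=-1)
    # Renormalize over the reduced candidate set
    probs = torch.softmax(topk_logits / temperature, dim=-1)
    # Sample within the top-k set, then map back to vocabulary ids
    choice = torch.multinomial(probs, num_samples=1)
    return torch.gather(topk_indices, -1, choice)
# Usage would mirror the generation loop above, e.g.
#   next_token = sample_top_k(next_token_logits)
# in place of the plain temperature-scaled multinomial draw.
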
""") # Create two columns for the interface col1, col2 = st.columns([2, 1]) with col1: # Input form prompt = st.text_area( "Enter your prompt:", "Once upon a time", height=100, help="Enter the text you want the model to continue from" ) with col2: # Generation parameters max_length = st.slider( "Predict additional text of length:", min_value=1, max_value=50, value=20, help="Number of additional tokens to generate" ) num_sequences = st.slider( "Number of sequences to generate:", min_value=1, max_value=5, value=1, help="Generate multiple different sequences from the same prompt" ) # Load model try: model, device = load_model() model_status = st.success("Model loaded successfully! Ready to generate text.") except Exception as e: st.error(f"Error loading model: {str(e)}") st.stop() # Generate button if st.button("Generate", type="primary"): if not prompt: st.warning("Please enter a prompt first!") else: with st.spinner("Generating text..."): try: generated_texts = generate_text( model=model, prompt=prompt, max_length=max_length, num_return_sequences=num_sequences, device=device ) # Display results st.subheader("Generated Text:") for i, text in enumerate(generated_texts, 1): with st.expander(f"Sequence {i}", expanded=True): st.write(text) except Exception as e: st.error(f"Error during text generation: {str(e)}") # Add footer st.markdown("---") st.markdown("""
st.markdown("""
Built with Streamlit and PyTorch | SmolLM2-135 Model
""")
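
# To run the app locally (assuming this file is the Streamlit entry point,
# saved as app.py, with the model.pt checkpoint in the working directory
# as load_model expects):
#   streamlit run app.py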