import streamlit as st
import torch
import tiktoken
import sys
import os
import logging
import warnings
# Configure logging and warnings
logging.getLogger('streamlit').setLevel(logging.ERROR)
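# Streamlit's file watcher is known to trip over torch.classes introspection,
# hence the targeted filter below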
warnings.filterwarnings('ignore', message='.*torch.classes.*')
warnings.filterwarnings('ignore', category=FutureWarning)
# Add the project root to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config.model_config import GPTConfig
from src.models.gpt import LlamaForCausalLM
from src.utils.device_utils import get_device
@st.cache_resource
def load_model():
"""
Load and prepare the model for inference.
Returns the loaded model and device.
"""
device = get_device()
try:
# Load the checkpoint dictionary
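        # NOTE: newer PyTorch versions default torch.load to weights_only=True,
        # which can reject checkpoints carrying non-tensor objects; passing
        # weights_only=False may be needed there (assumes a trusted file)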
checkpoint = torch.load('model.pt', map_location=device)
# Initialize model with config
config = GPTConfig()
model = LlamaForCausalLM(config)
# Load state dict - extract model_state_dict from checkpoint
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
else:
state_dict = checkpoint
        # Drop cached rotary-embedding buffers saved with the checkpoint;
        # the model recomputes them at runtime, and leaving them in would
        # trip strict state-dict loading
state_dict.pop("model.rotary_emb.cos_cached", None)
state_dict.pop("model.rotary_emb.sin_cached", None)
model.load_state_dict(state_dict, strict=True)
# Prepare model for inference
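        # Cast to fp32 in case the checkpoint was saved in half precision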
model = model.float()
model.to(device)
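        # eval() disables dropout and other train-time behavior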
model.eval()
return model, device
except Exception as e:
st.error(f"Detailed error during model loading: {str(e)}")
raise e
def generate_text(model, prompt, max_length=100, num_return_sequences=1, device='cpu'):
"""
Generate text based on the input prompt.
Args:
        model: The loaded causal language model
prompt: Input text prompt
max_length: Maximum number of tokens to generate
num_return_sequences: Number of different sequences to generate
device: Device to run inference on
Returns:
List of generated text sequences
"""
tokenizer = tiktoken.get_encoding('gpt2')
input_tokens = tokenizer.encode(prompt)
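    # Batch the prompt: one identical row per requested sequence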
x = torch.tensor(input_tokens).unsqueeze(0).repeat(num_return_sequences, 1)
x = x.to(device)
# Calculate final length (input length + requested additional tokens)
input_length = x.size(1)
target_length = input_length + max_length
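    # NOTE: this assumes target_length fits within the model's context window
    # (set in GPTConfig); overly long prompts may need truncation first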
# Generate text
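    # no_grad() skips autograd bookkeeping, reducing memory use during decoding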
with torch.no_grad():
while x.size(1) < target_length:
# Get predictions
logits, _ = model(x)
next_token_logits = logits[:, -1, :]
            # A fixed temperature of 0.8 (<1) sharpens the distribution
            probs = torch.softmax(next_token_logits / 0.8, dim=-1)
# Sample from the distribution
next_token = torch.multinomial(probs, num_samples=1)
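            # multinomial() draws from the full distribution; replacing it with
            # probs.argmax(dim=-1, keepdim=True) would give greedy decoding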
# Append to the sequence
x = torch.cat((x, next_token), dim=1)
    # Report token counts
    st.text(f"Prompt tokens: {input_length}, new tokens requested: {max_length}, total tokens generated: {x.size(1)}")
# Decode generated sequences
generated_texts = []
for i in range(num_return_sequences):
tokens = x[i].tolist()
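        # The decoded text includes the original prompt tokens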
text = tokenizer.decode(tokens)
generated_texts.append(text)
return generated_texts
# Set page config
st.set_page_config(
    page_title="SmolLM2-135M Text Generator",
page_icon="🐒",
layout="wide"
)
# Streamlit UI
st.title("🐒 SmolLM2-135M Text Generator")
st.markdown("""
This application uses a fine-tuned SmolLM2-135M model to generate text based on your prompts.
Enter your prompt below and adjust the generation parameters to create unique text sequences.
""")
# Create two columns for the interface
col1, col2 = st.columns([2, 1])
with col1:
# Input form
prompt = st.text_area(
"Enter your prompt:",
"Once upon a time",
height=100,
help="Enter the text you want the model to continue from"
)
with col2:
# Generation parameters
max_length = st.slider(
"Predict additional text of length:",
min_value=1,
max_value=50,
value=20,
help="Number of additional tokens to generate"
)
num_sequences = st.slider(
"Number of sequences to generate:",
min_value=1,
max_value=5,
value=1,
help="Generate multiple different sequences from the same prompt"
)
# Load model
try:
model, device = load_model()
    st.success("Model loaded successfully! Ready to generate text.")
except Exception as e:
st.error(f"Error loading model: {str(e)}")
st.stop()
# Generate button
if st.button("Generate", type="primary"):
    if not prompt.strip():
st.warning("Please enter a prompt first!")
else:
with st.spinner("Generating text..."):
try:
generated_texts = generate_text(
model=model,
prompt=prompt,
max_length=max_length,
num_return_sequences=num_sequences,
device=device
)
# Display results
st.subheader("Generated Text:")
for i, text in enumerate(generated_texts, 1):
with st.expander(f"Sequence {i}", expanded=True):
st.write(text)
except Exception as e:
st.error(f"Error during text generation: {str(e)}")
# Add footer
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
    <p>Built with Streamlit and PyTorch | SmolLM2-135M Model</p>
</div>
""", unsafe_allow_html=True)