import streamlit as st
import torch
import tiktoken
import sys
import os
import logging
import warnings

# Configure logging and silence known noisy warnings
logging.getLogger('streamlit').setLevel(logging.ERROR)
warnings.filterwarnings('ignore', message='.*torch.classes.*')
warnings.filterwarnings('ignore', category=FutureWarning)

# Add the project root to the Python path so the `src` package imports resolve
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.config.model_config import GPTConfig
from src.models.gpt import LlamaForCausalLM
from src.utils.device_utils import get_device
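
# Expected (assumed) project layout for the imports above; the paths are read
# directly off the import statements, not verified against the repository:
#   src/config/model_config.py -> GPTConfig
#   src/models/gpt.py          -> LlamaForCausalLM
#   src/utils/device_utils.py  -> get_device
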
@st.cache_resource
def load_model():
    """
    Load and prepare the model for inference.

    Returns the loaded model and the device it was placed on.
    """
    device = get_device()
    try:
        # Load the checkpoint dictionary
        checkpoint = torch.load('model.pt', map_location=device)

        # Initialize the model from its config
        config = GPTConfig()
        model = LlamaForCausalLM(config)

        # Extract the state dict, supporting both a full training checkpoint
        # and a bare state dict
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint

        # Remove cached rotary embedding buffers before loading
        state_dict.pop("model.rotary_emb.cos_cached", None)
        state_dict.pop("model.rotary_emb.sin_cached", None)

        model.load_state_dict(state_dict, strict=True)

        # Prepare the model for inference
        model = model.float()
        model.to(device)
        model.eval()
        return model, device
    except Exception as e:
        st.error(f"Detailed error during model loading: {str(e)}")
        raise
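
# Note on the checkpoint format: the branch above accepts either a full
# training checkpoint or a bare state dict. An assumed (illustrative) save
# call that would produce the wrapped form:
#   torch.save({"model_state_dict": model.state_dict()}, "model.pt")
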
def generate_text(model, prompt, max_length=100, num_return_sequences=1, device='cpu'):
    """
    Generate text based on the input prompt.

    Args:
        model: The loaded GPT model
        prompt: Input text prompt
        max_length: Number of additional tokens to generate
        num_return_sequences: Number of different sequences to generate
        device: Device to run inference on

    Returns:
        List of generated text sequences
    """
    tokenizer = tiktoken.get_encoding('gpt2')
    input_tokens = tokenizer.encode(prompt)

    # One row per requested sequence, each starting from the same prompt
    x = torch.tensor(input_tokens).unsqueeze(0).repeat(num_return_sequences, 1)
    x = x.to(device)

    # Final length = input length + requested additional tokens
    input_length = x.size(1)
    target_length = input_length + max_length

    # Generate tokens autoregressively
    with torch.no_grad():
        while x.size(1) < target_length:
            # Get next-token logits at the last position
            logits, _ = model(x)
            next_token_logits = logits[:, -1, :]

            # Apply a temperature of 0.8 to sharpen the distribution
            probs = torch.softmax(next_token_logits / 0.8, dim=-1)

            # Sample the next token from the distribution
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to each sequence
            x = torch.cat((x, next_token), dim=1)

    # Report token counts
    st.text(
        f"Input tokens: {input_length}, "
        f"additional tokens requested: {max_length}, "
        f"total tokens per sequence: {x.size(1)}"
    )

    # Decode the generated sequences
    generated_texts = []
    for i in range(num_return_sequences):
        tokens = x[i].tolist()
        text = tokenizer.decode(tokens)
        generated_texts.append(text)

    return generated_texts
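
# Rough usage sketch outside the Streamlit UI (illustrative only; when run
# without `streamlit run`, the st.* calls inside these helpers may log
# "missing ScriptRunContext" warnings, but generation itself still works):
#   model, device = load_model()
#   texts = generate_text(model, "Once upon a time", max_length=20, device=device)
#   print(texts[0])
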
# Set page config
st.set_page_config(
    page_title="SmolLM2-135 Text Generator",
    page_icon="🐢",
    layout="wide"
)

# Streamlit UI
st.title("🐢 SmolLM2-135 Text Generator")
st.markdown("""
This application uses a fine-tuned SmolLM2-135 model to generate text based on your prompts.
Enter your prompt below and adjust the generation parameters to create unique text sequences.
""")
# Create two columns for the interface
col1, col2 = st.columns([2, 1])

with col1:
    # Input form
    prompt = st.text_area(
        "Enter your prompt:",
        "Once upon a time",
        height=100,
        help="Enter the text you want the model to continue from"
    )

with col2:
    # Generation parameters
    max_length = st.slider(
        "Predict additional text of length:",
        min_value=1,
        max_value=50,
        value=20,
        help="Number of additional tokens to generate"
    )
    num_sequences = st.slider(
        "Number of sequences to generate:",
        min_value=1,
        max_value=5,
        value=1,
        help="Generate multiple different sequences from the same prompt"
    )
# Load model
try:
    model, device = load_model()
    st.success("Model loaded successfully! Ready to generate text.")
except Exception as e:
    st.error(f"Error loading model: {str(e)}")
    st.stop()
# Generate button
if st.button("Generate", type="primary"):
    if not prompt:
        st.warning("Please enter a prompt first!")
    else:
        with st.spinner("Generating text..."):
            try:
                generated_texts = generate_text(
                    model=model,
                    prompt=prompt,
                    max_length=max_length,
                    num_return_sequences=num_sequences,
                    device=device
                )

                # Display results
                st.subheader("Generated Text:")
                for i, text in enumerate(generated_texts, 1):
                    with st.expander(f"Sequence {i}", expanded=True):
                        st.write(text)
            except Exception as e:
                st.error(f"Error during text generation: {str(e)}")
# Add footer
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
    <p>Built with Streamlit and PyTorch | SmolLM2-135 Model</p>
</div>
""", unsafe_allow_html=True)