import streamlit as st
import torch
import tiktoken
import sys
import os
import logging
import warnings
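
# Reduce console noise: silence Streamlit's internal logger and known-noisy warnings.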
logging.getLogger('streamlit').setLevel(logging.ERROR)
warnings.filterwarnings('ignore', message='.*torch.classes.*')
warnings.filterwarnings('ignore', category=FutureWarning)
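
# Make the project root importable so the local src package can be found.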
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.config.model_config import GPTConfig
from src.models.gpt import LlamaForCausalLM
from src.utils.device_utils import get_device


@st.cache_resource
def load_model():
    """
    Load and prepare the model for inference.

    Returns the loaded model and device.
    """
    device = get_device()

    try:
        checkpoint = torch.load('model.pt', map_location=device)
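
        # Build the model from the repo's configuration before loading the checkpoint weights.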
        config = GPTConfig()
        model = LlamaForCausalLM(config)
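
        # A training checkpoint may wrap the weights under "model_state_dict";
        # otherwise the checkpoint itself is already a bare state dict.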
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
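
        # Drop cached rotary-embedding buffers (if present) so they do not
        # conflict with strict state-dict loading.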
        state_dict.pop("model.rotary_emb.cos_cached", None)
        state_dict.pop("model.rotary_emb.sin_cached", None)

        model.load_state_dict(state_dict, strict=True)
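
        # Inference setup: full precision, target device, eval mode.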
        model = model.float()
        model.to(device)
        model.eval()

        return model, device

    except Exception as e:
        st.error(f"Detailed error during model loading: {str(e)}")
        raise e


def generate_text(model, prompt, max_length=100, num_return_sequences=1, device='cpu'):
    """
    Generate text based on the input prompt.

    Args:
        model: The loaded GPT model
        prompt: Input text prompt
        max_length: Number of additional tokens to generate beyond the prompt
        num_return_sequences: Number of different sequences to generate
        device: Device to run inference on

    Returns:
        List of generated text sequences
    """
    tokenizer = tiktoken.get_encoding('gpt2')
    input_tokens = tokenizer.encode(prompt)
    x = torch.tensor(input_tokens).unsqueeze(0).repeat(num_return_sequences, 1)
    x = x.to(device)

    input_length = x.size(1)
    target_length = input_length + max_length
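
    # Autoregressive sampling loop: append one sampled token per step until
    # target_length is reached, drawing from a temperature-scaled (0.8) distribution.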
    with torch.no_grad():
        while x.size(1) < target_length:
            logits, _ = model(x)
            next_token_logits = logits[:, -1, :]

            probs = torch.softmax(next_token_logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            x = torch.cat((x, next_token), dim=1)

    st.text(f"Input tokens: {input_length}, additional tokens generated: {max_length}, total tokens: {x.size(1)}")
    generated_texts = []
    for i in range(num_return_sequences):
        tokens = x[i].tolist()
        text = tokenizer.decode(tokens)
        generated_texts.append(text)

    return generated_texts


st.set_page_config(
    page_title="SmolLM2-135 Text Generator",
    page_icon="🤖",
    layout="wide"
)

st.title("🤖 SmolLM2-135 Text Generator")
st.markdown("""
This application uses a fine-tuned SmolLM2-135 model to generate text based on your prompts.
Enter your prompt below and adjust the generation parameters to create unique text sequences.
""")

col1, col2 = st.columns([2, 1])
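
# Left column: prompt input; right column: generation controls.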
with col1:
    prompt = st.text_area(
        "Enter your prompt:",
        "Once upon a time",
        height=100,
        help="Enter the text you want the model to continue from"
    )

with col2:
    max_length = st.slider(
        "Predict additional text of length:",
        min_value=1,
        max_value=50,
        value=20,
        help="Number of additional tokens to generate"
    )

    num_sequences = st.slider(
        "Number of sequences to generate:",
        min_value=1,
        max_value=5,
        value=1,
        help="Generate multiple different sequences from the same prompt"
    )
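
# Load the model once (cached via st.cache_resource); stop the app if loading fails.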
try:
    model, device = load_model()
    model_status = st.success("Model loaded successfully! Ready to generate text.")
except Exception as e:
    st.error(f"Error loading model: {str(e)}")
    st.stop()
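
# Generate and display text when the button is pressed.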
if st.button("Generate", type="primary"):
    if not prompt:
        st.warning("Please enter a prompt first!")
    else:
        with st.spinner("Generating text..."):
            try:
                generated_texts = generate_text(
                    model=model,
                    prompt=prompt,
                    max_length=max_length,
                    num_return_sequences=num_sequences,
                    device=device
                )

                st.subheader("Generated Text:")
                for i, text in enumerate(generated_texts, 1):
                    with st.expander(f"Sequence {i}", expanded=True):
                        st.write(text)

            except Exception as e:
                st.error(f"Error during text generation: {str(e)}")

st.markdown("---")
st.markdown("""
<div style='text-align: center'>
    <p>Built with Streamlit and PyTorch | SmolLM2-135 Model</p>
</div>
""", unsafe_allow_html=True)