import streamlit as st
import torch
import tiktoken
import sys
import os
import logging
import warnings

# Configure logging and warnings
logging.getLogger('streamlit').setLevel(logging.ERROR)
warnings.filterwarnings('ignore', message='.*torch.classes.*')
warnings.filterwarnings('ignore', category=FutureWarning)

# Add the project root to Python path
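# (the parent directory of this file is assumed to contain the `src` package imported below)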
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.config.model_config import GPTConfig
from src.models.gpt import LlamaForCausalLM
from src.utils.device_utils import get_device

@st.cache_resource
def load_model():
    """
    Load and prepare the model for inference.
    Returns the loaded model and device.
    """
    device = get_device()
    
    try:
        # Load the checkpoint dictionary
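        # Note: on PyTorch >= 2.6, torch.load defaults to weights_only=True;
        # pass weights_only=False if the checkpoint contains non-tensor objects.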
        checkpoint = torch.load('model.pt', map_location=device)
        
        # Initialize model with config
        config = GPTConfig()
        model = LlamaForCausalLM(config)
        
        # Load state dict - extract model_state_dict from checkpoint
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
            
        # Remove cached rotary embedding buffers
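        # (these cos/sin caches are typically rebuilt on the first forward pass,
        # and stale cached shapes can break strict state-dict loading)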
        state_dict.pop("model.rotary_emb.cos_cached", None)
        state_dict.pop("model.rotary_emb.sin_cached", None)
        
        model.load_state_dict(state_dict, strict=True)
        
        # Prepare model for inference
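        # float() casts any half-precision weights up to fp32; eval() disables
        # dropout and other training-only behavior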
        model = model.float()
        model.to(device)
        model.eval()
        
        return model, device
        
    except Exception as e:
        st.error(f"Error while loading the model: {e}")
        raise

def generate_text(model, prompt, max_length=100, num_return_sequences=1, device='cpu'):
    """
    Generate text based on the input prompt.
    
    Args:
        model: The loaded GPT model
        prompt: Input text prompt
        max_length: Maximum number of tokens to generate
        num_return_sequences: Number of different sequences to generate
        device: Device to run inference on
    
    Returns:
        List of generated text sequences
    """
    tokenizer = tiktoken.get_encoding('gpt2')
    input_tokens = tokenizer.encode(prompt)
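    # Tile the encoded prompt across the batch dimension: one row per requested sequence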
    x = torch.tensor(input_tokens).unsqueeze(0).repeat(num_return_sequences, 1)
    x = x.to(device)
    
    # Calculate final length (input length + requested additional tokens)
    input_length = x.size(1)
    target_length = input_length + max_length
    
    # Generate text
    with torch.no_grad():
        while x.size(1) < target_length:
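            # Each step re-runs the model on the full sequence (no KV cache),
            # which is acceptable for short demo generations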
            # Get predictions
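            # (the model's forward is assumed to return a (logits, aux) pair;
            # the second value, typically a loss, is ignored here)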
            logits, _ = model(x)
            next_token_logits = logits[:, -1, :]
            
            # Apply a temperature of 0.8 (values below 1 sharpen the distribution)
            probs = torch.softmax(next_token_logits / 0.8, dim=-1)
            
            # Sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append to the sequence
            x = torch.cat((x, next_token), dim=1)
    
    # Display token counts in the UI
    st.text(f"Input tokens: {input_length}, new tokens generated: {max_length}, total tokens in output: {x.size(1)}")
    
    # Decode generated sequences
    generated_texts = []
    for i in range(num_return_sequences):
        tokens = x[i].tolist()
        text = tokenizer.decode(tokens)
        generated_texts.append(text)
    
    return generated_texts

# Set page config
st.set_page_config(
    page_title="SmolLM2-135 Text Generator",
    page_icon="🐢",
    layout="wide"
)

# Streamlit UI
st.title("🐢 SmolLM2-135 Text Generator")
st.markdown("""
This application uses a fine-tuned SmolLM2-135 model to generate text based on your prompts.
Enter your prompt below and adjust the generation parameters to create unique text sequences.
""")

# Create two columns for the interface
col1, col2 = st.columns([2, 1])

with col1:
    # Input form
    prompt = st.text_area(
        "Enter your prompt:",
        "Once upon a time",
        height=100,
        help="Enter the text you want the model to continue from"
    )

with col2:
    # Generation parameters
    max_length = st.slider(
        "Predict additional text of length:",
        min_value=1,
        max_value=50,
        value=20,
        help="Number of additional tokens to generate"
    )
    
    num_sequences = st.slider(
        "Number of sequences to generate:",
        min_value=1,
        max_value=5,
        value=1,
        help="Generate multiple different sequences from the same prompt"
    )

# Load model
try:
    model, device = load_model()
    st.success("Model loaded successfully! Ready to generate text.")
except Exception as e:
    st.error(f"Error loading model: {str(e)}")
    st.stop()

# Generate button
if st.button("Generate", type="primary"):
    if not prompt.strip():
        st.warning("Please enter a prompt first!")
    else:
        with st.spinner("Generating text..."):
            try:
                generated_texts = generate_text(
                    model=model,
                    prompt=prompt,
                    max_length=max_length,
                    num_return_sequences=num_sequences,
                    device=device
                )
                
                # Display results
                st.subheader("Generated Text:")
                for i, text in enumerate(generated_texts, 1):
                    with st.expander(f"Sequence {i}", expanded=True):
                        st.write(text)
                        
            except Exception as e:
                st.error(f"Error during text generation: {str(e)}")

# Add footer
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
    <p>Built with Streamlit and PyTorch | SmolLM2-135 Model</p>
</div>
""", unsafe_allow_html=True)