# GPT text-generation demo (Streamlit app).
import streamlit as st
import torch
import tiktoken
from transformer import GPT, GPTConfig  # Ensure you import your model class
# Load the trained model
def load_model():
    """Instantiate the GPT model and pull weights from the local checkpoint.

    Best-effort: on any load failure the error is surfaced in the UI and the
    freshly initialized (untrained) model is returned so the app stays up.
    """
    model = GPT(GPTConfig())
    try:
        # map_location pins tensors to CPU so CPU-only hosts work.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints you trust. strict=False silently ignores key
        # mismatches between checkpoint and GPTConfig — confirm they match.
        checkpoint = torch.load('trained_model_quantized.pt', map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint, strict=False)
        model.eval()  # evaluation mode for inference (disables dropout etc.)
        st.success("Model loaded successfully!")
    except Exception as err:
        st.error(f"Error loading model: {err}")
    return model
# Load the tokenizer
def load_tokenizer():
    """Return the GPT-2 byte-pair encoding tokenizer from tiktoken."""
    encoding = tiktoken.get_encoding('gpt2')
    return encoding
# Generate text function
def generate_text(model, tokenizer, input_text, length, num_sequences):
    """Autoregressively sample continuations of *input_text*.

    Args:
        model: callable whose first return element is logits of shape
            (batch, seq, vocab) — assumed from ``model(x)[0]`` usage.
        tokenizer: object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str`` (tiktoken-style).
        input_text: prompt string to continue.
        length: number of new tokens to sample per sequence.
        num_sequences: how many independent continuations to produce.

    Returns:
        List of ``num_sequences`` decoded strings, each consisting of the
        prompt followed by ``length`` sampled tokens.

    Bug fixed: the original built ``input_tensor`` once and kept appending
    to it across sequences, so sequence k+1 continued from sequence k's
    output (and grew longer each time) instead of restarting from the
    prompt. The prompt tensor is now rebuilt for every sequence.
    """
    prompt_ids = tokenizer.encode(input_text)
    generated_sequences = []
    with torch.no_grad():  # inference only; no gradients needed
        for _ in range(num_sequences):
            # Restart from the original prompt for each sequence.
            input_tensor = torch.tensor(prompt_ids).unsqueeze(0)  # shape [1, T]
            for _ in range(length):
                logits = model(input_tensor)[0]           # [1, T, vocab]
                next_token_logits = logits[:, -1, :]      # last position only
                next_token_probs = torch.softmax(next_token_logits, dim=-1)
                # Sample from the distribution rather than argmax.
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                next_token = next_token.view(1, -1)       # ensure [1, 1]
                input_tensor = torch.cat((input_tensor, next_token), dim=1)
            generated_sequences.append(tokenizer.decode(input_tensor[0].tolist()))
    return generated_sequences
# --- Streamlit app layout ---
st.title("GPT Text Generator")
st.write("Enter your text and specify the length of additional text to generate.")

# Prompt capped at 512 characters via the widget itself.
input_text = st.text_area("Input Text", "Once upon a time", max_chars=512)
length = st.slider("Predict Additional Text of Length", 1, 50, 10)
num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)

if st.button("Generate"):
    # Model and tokenizer are (re)loaded on every click.
    model = load_model()
    tokenizer = load_tokenizer()
    st.write("Generating text...")
    generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
    st.write("Text generation complete.")
    st.write("Generated Texts:")
    for seq_index, seq_text in enumerate(generated_texts, start=1):
        st.subheader(f"Sequence {seq_index}")
        st.write(seq_text)