Spaces:
Sleeping
Sleeping
| # app.py | |
| import streamlit as st | |
| import torch | |
| from src.model import TransformerModel | |
| from src.utils import load_vocab, tokenize | |
| import time | |
| import random | |
| import os | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Checkpoint file to serve. Update this path based on the latest model.
MODEL_PATH: str = "models/3ed0k4_model_epoch10.pth"
# File handed to load_vocab to obtain the token -> index mapping.
VOCAB_PATH: str = "vocab.json"

# Hyperparameters passed to the TransformerModel constructor.
# NOTE(review): presumably these must match the values used when the
# checkpoint was trained, or load_state_dict will reject it -- confirm.
EMBED_SIZE: int = 256
NUM_HEADS: int = 8
HIDDEN_DIM: int = 512
NUM_LAYERS: int = 4
DROPOUT: float = 0.1

# Maximum tokens to generate per request (default for generate_text).
MAX_LENGTH: int = 100
# Title and Description
# NOTE(review): the trailing "π" in the title looks like a mis-encoded emoji
# from the original source -- confirm the intended character.
st.title("3ed0k4 NLP Text Generation Model π")
st.write(
    "Enter a prompt, and the model will generate text based on your input. "
    "It will take 1 to 10 seconds to respond to simulate 'thinking'."
)
# Load vocabulary
def load_resources():
    """Return the vocabulary that ``load_vocab`` reads from ``VOCAB_PATH``."""
    return load_vocab(VOCAB_PATH)


# Module-level vocabulary: the model is sized from its length, and text
# generation uses it to map tokens to indices and back.
vocab = load_resources()
vocab_size = len(vocab)
# Initialize model
def _cache_resource(func):
    """Wrap *func* with ``st.cache_resource`` when this Streamlit provides it.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would rebuild the network and re-read the checkpoint each
    time. On Streamlit releases without ``cache_resource`` the function is
    returned unchanged, so the app still works (just without caching).
    """
    decorator = getattr(st, "cache_resource", None)
    return func if decorator is None else decorator(func)


@_cache_resource
def _load_trained_model():
    """Build a ``TransformerModel`` and load the weights stored at ``MODEL_PATH``.

    Returns:
        The model with its state dict loaded on the CPU and switched to
        evaluation mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for an unreadable
        or incompatible checkpoint; these are not caught here.
    """
    network = TransformerModel(
        vocab_size=vocab_size,
        embed_size=EMBED_SIZE,
        num_heads=NUM_HEADS,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
    )
    # NOTE(review): torch.load unpickles the file, so MODEL_PATH must only
    # point at trusted checkpoints; consider weights_only=True on PyTorch
    # versions that support it.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    network.load_state_dict(state_dict)
    network.eval()  # inference only: disables dropout
    return network


def load_model():
    """Return the trained model, or ``None`` if the checkpoint file is missing.

    The existence check is deliberately kept outside the cached helper so a
    missing file is reported on every run instead of being memoized as
    ``None``; it also avoids constructing the network when there is nothing
    to load.

    Returns:
        The ready-to-use model in evaluation mode, or ``None`` after showing
        an error in the Streamlit page when ``MODEL_PATH`` does not exist.
    """
    if not os.path.exists(MODEL_PATH):
        st.error(f"Model file not found at {MODEL_PATH}. Please ensure the model is trained and the path is correct.")
        return None
    return _load_trained_model()


model = load_model()
def generate_text(prompt, max_length=MAX_LENGTH):
    """Greedily extend *prompt* with up to *max_length* model-predicted tokens.

    The prompt is tokenized and mapped to vocabulary indices (unknown tokens
    become ``'<UNK>'``). At each step the whole sequence so far is fed to the
    module-level ``model`` and the highest-scoring next token is appended.
    Generation stops early when the predicted token is ``'<PAD>'``.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of tokens to append to the prompt.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single
        spaces. An empty string if the prompt yields no tokens.

    Raises:
        KeyError: If the vocabulary has no ``'<UNK>'`` entry.
    """
    # Look the special ids up once instead of once per token / per step.
    unk_id = vocab['<UNK>']
    pad_id = vocab.get('<PAD>', 0)

    generated = [vocab.get(token, unk_id) for token in tokenize(prompt)]
    if not generated:
        # A zero-length sequence has no last position to read logits from,
        # so there is nothing to condition on; bail out instead of crashing.
        return ''

    with torch.no_grad():
        for _ in range(max_length):
            # Re-encode the full history each step (batch size 1).
            input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)
            src_mask = model.generate_square_subsequent_mask(input_seq.size(1)).to(input_seq.device)
            outputs = model(input_seq, src_mask)
            # Scores for the last position of the single batch item
            # (assumes batch-first output -- TODO confirm against TransformerModel).
            next_token = torch.argmax(outputs[0, -1, :]).item()
            if next_token == pad_id:
                break
            generated.append(next_token)

    # Convert numerical tokens back to words
    inv_vocab = {idx: word for word, idx in vocab.items()}
    return ' '.join(inv_vocab.get(tok, '<UNK>') for tok in generated)
# User Inputs
prompt = st.text_input("Enter your prompt:", "")
delay = st.slider("Select thinking delay (seconds):", min_value=1, max_value=10, value=3)

if st.button("Generate"):
    # Compare against None explicitly: load_model() returns None on failure,
    # and truthiness of a model object is not a reliable "is loaded" signal.
    if model is None:
        st.error("Model is not loaded. Please check the model path.")
    elif not prompt.strip():
        st.warning("Please enter a prompt to generate text.")
    else:
        with st.spinner("Thinking..."):
            # Deliberate pause chosen by the user to simulate "thinking".
            time.sleep(delay)
            response = generate_text(prompt)
        st.success("Here's the generated text:")
        st.write(response)