Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import os | |
| import pickle | |
| from torch.functional import F | |
| import numpy as np | |
| import gradio as gr | |
| import torchtext | |
# Inference is pinned to the CPU. To use a GPU when one is present, pick the
# device with torch.cuda.is_available() instead of hard-coding 'cpu'.
device = torch.device('cpu')

# Dimensions handed to LSTMModel when the network is constructed.
VOCAB_SIZE = 10000
EMBEDDING_DIM = 100
N_UNITS = 128

# Remaining settings are not referenced in the visible part of this file;
# presumably left over from the training script -- TODO confirm before removing.
MAX_LEN = 200
VALIDATION_SPLIT = 0.2
SEED = 42
LOAD_MODEL = False
BATCH_SIZE = 128
EPOCHS = 25
class LSTMModel(nn.Module):
    """Embedding -> single LSTM -> linear projection -> log-softmax.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output layer.
        embedding_dim: Size of each token embedding (LSTM input size).
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Submodule names are part of the checkpoint's state_dict keys, and
        # the construction order fixes the order of random initialisation.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        """Return per-position log-probabilities over the vocabulary.

        Args:
            x: Integer token indices of shape (batch, sequence).

        Returns:
            Tensor of shape (batch, sequence, vocab_size) whose last axis
            holds log-probabilities.
        """
        embedded = self.embedding(x)
        # The final (hidden, cell) state is not needed here.
        hidden_states, _ = self.lstm(embedded)
        scores = self.fc(hidden_states)
        return self.log_softmax(scores)
# Build the network and restore its trained parameters from the checkpoint.
# load_state_dict needs the architecture above to match the saved parameters.
model = LSTMModel(VOCAB_SIZE, EMBEDDING_DIM, N_UNITS).to(device)
checkpoint_path = 'recipe_generator_LSTM.pth'
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; on PyTorch versions that
# support it, consider passing weights_only=True.
# map_location places the loaded tensors on `device`, wherever they were saved.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
print('Loaded model from checkpoint')
def load_vocab(file_path):
    """Load a pickled vocabulary object from ``file_path``.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The object stored in the pickle file.

    Raises:
        OSError: If the file cannot be opened.

    Note:
        ``pickle.load`` can execute arbitrary code while deserialising, so
        this must only be used on files from a trusted source.
    """
    with open(file_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)
    print(f"Vocabulary loaded from {file_path}")
    return vocab
# Vocabulary object used to map between tokens and indices during generation.
vocab = load_vocab("vocab.pkl")
class TextGenerator:
    """Sample text from a next-token model, one token at a time.

    Args:
        vocab: Vocabulary object exposing ``get_stoi()`` (token -> index
            mapping) and ``get_itos()`` (index -> token sequence).
        top_k: Stored on the instance; the sampling code in this class does
            not read it.
    """

    def __init__(self, vocab, top_k=10):
        self.vocab = vocab
        self.top_k = top_k

    def sample_from(self, logits, temperature):
        """Draw one token index from temperature-scaled ``logits``.

        Args:
            logits: 1-D tensor of scores, one per vocabulary entry.
            temperature: Positive divisor applied to the scores before the
                softmax.

        Returns:
            The sampled index as an ``int``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        # Dividing by zero yields NaN probabilities, and a negative value
        # silently inverts the distribution; reject both up front.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        probs = F.softmax(logits / temperature, dim=-1).cpu().numpy()
        return int(np.random.choice(len(probs), p=probs))

    def generate(self, model, device, start_prompt, max_tokens, temperature):
        """Extend ``start_prompt`` by ``max_tokens`` sampled tokens.

        Args:
            model: Callable returning scores of shape (batch, sequence,
                vocab) for a (batch, sequence) tensor of token indices; its
                ``eval()`` method is called before sampling.
            device: Device on which the token tensors are created.
            start_prompt: Whitespace-separated tokens, each present in the
                vocabulary.
            max_tokens: Number of tokens to sample after the prompt.
            temperature: Sampling temperature, forwarded to ``sample_from``.

        Returns:
            The prompt and sampled tokens joined by single spaces, with
            every ``'<pad>'`` token removed.

        Raises:
            KeyError: If a prompt token is missing from the vocabulary.
            ValueError: If the prompt is empty while tokens are requested,
                or if ``temperature`` is not positive.
        """
        model.eval()
        # Fetch both mappings once rather than once per token.
        stoi = self.vocab.get_stoi()
        itos = self.vocab.get_itos()

        prompt_ids = [stoi[token] for token in start_prompt.split()]
        if not prompt_ids and max_tokens > 0:
            # The model needs at least one position to predict from.
            raise ValueError("start_prompt must contain at least one token")

        tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        with torch.no_grad():
            for _ in range(max_tokens):
                # The whole sequence is re-fed each step; only the scores at
                # the last position are used to pick the next token.
                output = model(tokens)
                next_token_logits = output[0, -1, :]
                next_token = self.sample_from(next_token_logits, temperature)
                next_column = torch.tensor(
                    [[next_token]], dtype=torch.long, device=device
                )
                tokens = torch.cat([tokens, next_column], dim=1)

        words = [itos[index] for index in tokens[0].tolist()]
        return ' '.join(word for word in words if word != '<pad>')
# Create the shared generator and print one sample at start-up.
text_generator = TextGenerator(vocab=vocab, top_k=10)
generated_text = text_generator.generate(
    model=model,
    device=device,
    start_prompt="recipe for",
    max_tokens=100,
    temperature=0.5,
)
print(f"\nGenerated Text: {generated_text}")
def generate_recipe():
    """Return a freshly sampled recipe for the Gradio interface.

    Takes no inputs: the prompt ("recipe for"), length (100 tokens) and
    temperature (0.5) are fixed, and the module-level model, device and
    text generator are used.
    """
    recipe = text_generator.generate(
        model=model,
        device=device,
        start_prompt="recipe for",
        max_tokens=100,
        temperature=0.5,
    )
    return recipe
# Input-less Gradio UI: each submit calls generate_recipe and shows the
# returned text.
iface = gr.Interface(
    fn=generate_recipe,
    inputs=[],
    outputs="text",
    title="Recipe Generator",
    description=(
        "This is a LSTM based Recurrent Neural Network trained to generate recipes. "
        "Press submit to generate a new recipe that can sometimes provide humor!"
    ),
)

# Start serving the interface.
iface.launch()