# NOTE(review): the lines "Spaces: Sleeping" above the source were a Hugging Face
# Spaces status banner captured along with the page scrape — not part of the program.
# src/app.py
from flask import Flask, request, render_template
import torch
from model import TransformerModel
from utils import load_vocab, tokenize
import time
import random
import os

app = Flask(__name__, template_folder='templates')

# Configuration — these must match the hyperparameters the checkpoint was trained with.
MODEL_PATH = 'models/3ed0k4_model_epoch10.pth'  # Update this path based on the latest model
VOCAB_PATH = 'vocab.json'
EMBED_SIZE = 256
NUM_HEADS = 8
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.1
MAX_LENGTH = 100  # Maximum tokens to generate

# Load vocabulary (assumed to be a token -> index mapping; see generate_text).
vocab = load_vocab(VOCAB_PATH)
vocab_size = len(vocab)

# Initialize model with the same architecture as the saved checkpoint.
model = TransformerModel(
    vocab_size=vocab_size,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)

# Load model weights; fail fast with a clear message if the checkpoint is absent.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please train the model first.")
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()  # inference mode: disables dropout
def generate_text(prompt, max_length=MAX_LENGTH):
    """Greedily generate up to ``max_length`` tokens continuing ``prompt``.

    The prompt is tokenized and numericalized via the module-level ``vocab``
    (unknown tokens map to ``<UNK>``), then tokens are generated one at a time
    by taking the argmax of the model's logits for the last position.
    Generation stops early if the model emits the ``<PAD>`` token.

    Args:
        prompt: Raw input text to condition on.
        max_length: Maximum number of NEW tokens to generate (default MAX_LENGTH).

    Returns:
        A space-joined string of the prompt tokens followed by the generated tokens.
    """
    tokens = tokenize(prompt)
    numericalized = [vocab.get(token, vocab['<UNK>']) for token in tokens]
    generated = list(numericalized)
    # NOTE(review): an empty prompt yields a (1, 0) input tensor — presumably the
    # model rejects that; confirm upstream that messages are non-empty.
    input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)  # batch size 1
    with torch.no_grad():
        for _ in range(max_length):
            src_mask = model.generate_square_subsequent_mask(input_seq.size(1)).to(input_seq.device)
            outputs = model(input_seq, src_mask)
            next_token = torch.argmax(outputs[0, -1, :]).item()
            if next_token == vocab['<PAD>']:
                break
            generated.append(next_token)
            # Append the new token in place rather than rebuilding the whole
            # tensor from the Python list each step (same values, less copying).
            input_seq = torch.cat(
                [input_seq, torch.tensor([[next_token]], dtype=torch.long).to(input_seq.device)],
                dim=1,
            )
    # Convert numerical tokens back to words.
    inv_vocab = {idx: word for word, idx in vocab.items()}
    return ' '.join(inv_vocab.get(tok, '<UNK>') for tok in generated)
# NOTE(review): the scraped source carried no @app.route decorators, so the app
# registered no URLs; this restores the evident landing-page route — confirm path.
@app.route('/')
def index():
    """Render the chat landing page."""
    return render_template('index.html')
# NOTE(review): decorator missing in the scraped source; this handler reads
# request.form, so it must be a POST route — confirm the exact path/form action.
@app.route('/chat', methods=['POST'])
def chat():
    """Handle a chat form submission and render the page with the model's reply.

    Falls back to the bare page when no message was submitted.
    """
    message = request.form.get('message')
    if not message:
        return render_template('index.html')
    # Simulate thinking delay — deliberate UX effect (1-10 s), not a bug.
    delay = random.randint(1, 10)
    print(f"Thinking for {delay} seconds...")
    time.sleep(delay)
    response = generate_text(message)
    return render_template('index.html', message=message, response=response)
if __name__ == '__main__':
    # Bind to all interfaces so the server is reachable from outside a container.
    app.run(host='0.0.0.0', port=5000)