# Hugging Face Spaces status banner captured with the page (scraper residue,
# not part of the program): "Spaces: Sleeping"
"""Flask API serving GPT-2 next-word predictions with CORS enabled."""

from flask import Flask, jsonify, request
from flask_cors import CORS
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from history import load_dataset, get_unique_next_words_from_dataset

app = Flask(__name__)
CORS(app)  # Allow cross-origin requests from the front-end.

# Most recently generated predictions, shared across requests/handlers.
predicted_words = []
from functools import lru_cache


@lru_cache(maxsize=1)
def _get_model_and_tokenizer():
    """Load GPT-2 exactly once and reuse it.

    The original code re-downloaded/re-loaded the model and tokenizer on
    every request, which dominates the request latency.
    """
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()  # Inference only — disable dropout.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return model, tokenizer


def generate_predicted_words(input_text):
    """Return next-word suggestions for *input_text*.

    Combines two sources, in order:
      1. words that followed *input_text* in ``dataset.txt`` (via
         ``history.get_unique_next_words_from_dataset``), then
      2. GPT-2's top-50 most probable next tokens, with one-character
         tokens and a small list of noise fragments filtered out.

    Args:
        input_text: Text whose continuation should be predicted.

    Returns:
        list[str]: dataset-derived words followed by filtered GPT-2 tokens.
    """
    model, tokenizer = _get_model_and_tokenizer()

    # Words that historically followed this text in the local dataset.
    # (Re-read each call so edits to dataset.txt are picked up.)
    dataset = load_dataset("dataset.txt")
    history_next_text = get_unique_next_words_from_dataset(input_text, dataset)

    # Score the next token with GPT-2; no gradients needed for inference.
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)

    # Probability distribution over the vocabulary for the next token.
    last_token_logits = outputs.logits[:, -1, :]
    probabilities = torch.softmax(last_token_logits, dim=-1)

    # Decode the 50 most probable candidate tokens back to text.
    _, top_50_indices = torch.topk(probabilities, 50)
    top_50_tokens = [tokenizer.decode([idx]) for idx in top_50_indices[0]]

    # Drop one-character tokens (mostly punctuation) and known noise pieces.
    removable_words = [' (', ' a', "'s", ' "', ' -', ' as', " '"]
    words = [token for token in top_50_tokens
             if len(token) != 1 and token not in removable_words]

    return history_next_text + words  # Dataset words first, then GPT-2 words.
PAGE_SIZE = 9  # Number of words returned per page.


# NOTE(review): no @app.route decorator is visible here — it was presumably
# lost when this file was copied; confirm the original URL rule.
def get_display_words():
    """Serve one page of ``predicted_words`` selected by ``?count=N``.

    Lazily generates the word list on first use (with a default input),
    then returns the N-th page of PAGE_SIZE words, wrapping back to the
    first page once N runs past the end of the list.

    Returns:
        A JSON array of up to PAGE_SIZE words.
    """
    # A non-numeric ?count= previously crashed with an unhandled
    # ValueError (HTTP 500); treat it as page 0 instead.
    try:
        count = int(request.args.get('count', 0))
    except (TypeError, ValueError):
        count = 0

    if not predicted_words:
        # Generate the list only once if it's not generated yet.
        input_text = "Are"  # Default input, can be changed as needed.
        predicted_words.extend(generate_predicted_words(input_text))

    # Serve the slice of predicted words selected by the page count.
    start_index = PAGE_SIZE * count
    end_index = start_index + PAGE_SIZE
    if start_index >= len(predicted_words):  # Wrap around when out of range.
        start_index = 0
        end_index = PAGE_SIZE

    return jsonify(predicted_words[start_index:end_index])
# NOTE(review): no @app.route decorator is visible here — it was presumably
# lost when this file was copied; confirm the original URL rule.
def predict_words():
    """Regenerate the global prediction list from the POSTed message.

    Expects a JSON body of the form ``{"message": "<text>"}`` and replies
    with the freshly generated word list, or a JSON error payload with an
    appropriate HTTP status.
    """
    global predicted_words
    try:
        payload = request.get_json()
        print("data", payload)

        # Reject bodies that did not parse into a JSON object.
        if not isinstance(payload, dict):
            return jsonify({'error': 'Invalid JSON format'}), 400

        message = payload.get('message', '')
        if not message:
            return jsonify({'error': 'No input text provided'}), 400

        # Rebuild the shared prediction cache for the new input text.
        predicted_words = generate_predicted_words(message)
        return jsonify(predicted_words)
    except Exception as e:
        # Top-level request boundary: report any failure as a JSON 500.
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Listen on all interfaces; debug mode is for development only.
    app.run(debug=True, host='0.0.0.0', port=5000)