# Status banner captured alongside this source ("Spaces: Sleeping");
# it is not part of the program.
import os
import re

import tensorflow as tf
from flask import Flask, request, jsonify
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
# Flask application instance for this service.
app = Flask(import_name=__name__)
# NOTE(review): this copy of the file had no route decorator on health_check(),
# so Flask never registered it; the "/" GET route below is an assumption —
# confirm it matches what the hosting platform / clients probe.
@app.route("/", methods=["GET"])
def health_check():
    """Report that the service is up.

    Returns:
        A ``{"status": "healthy"}`` JSON body with HTTP status 200.
    """
    return jsonify({"status": "healthy"}), 200
# --- Load Model and Tokenizer ---
# Load the fine-tuned model and tokenizer from the saved directory. The
# location defaults to ./saved_model (resolved against the current working
# directory) and can be overridden with the MODEL_PATH environment variable.
model_path = os.environ.get("MODEL_PATH", "./saved_model")
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
# flush=True: when stdout is not a terminal (typical in containers) it is
# block-buffered, so without an explicit flush this message can be delayed.
print("Model and Tokenizer loaded successfully!", flush=True)
# --- Helper Function for Text Cleaning ---
# Ordered (pattern, replacement) steps applied by clean_text(); compiled once
# at import time instead of on every call.
_CLEAN_TEXT_STEPS = (
    (re.compile(r'@[a-zA-Z0-9_]+'), ''),           # @mentions
    (re.compile(r'https?://[A-Za-z0-9./]+'), ''),  # http(s) links
    (re.compile(r'[^a-zA-Z\s]'), ''),              # all but ASCII letters / whitespace
    (re.compile(r'\s+'), ' '),                     # collapse whitespace runs
)


def clean_text(text):
    """Normalize raw input text before tokenization.

    Lowercases the text, strips ``@mentions`` and http(s) URLs, removes every
    character that is not an ASCII letter or whitespace (digits, punctuation,
    accented letters, emoji), collapses whitespace runs to a single space and
    trims both ends.

    The URL pattern only consumes ``[A-Za-z0-9./]``, so a URL containing other
    characters (``-``, ``?``, ``=`` ...) is removed only up to that character;
    the remainder then passes through the letter filter.

    NOTE(review): this presumably mirrors the preprocessing used when the model
    was fine-tuned — keep the two in sync to avoid train/serve skew.

    Args:
        text: The raw string to clean.

    Returns:
        The cleaned string (possibly empty).
    """
    normalized = text.lower()
    for pattern, replacement in _CLEAN_TEXT_STEPS:
        normalized = pattern.sub(replacement, normalized)
    return normalized.strip()
# --- Define the Prediction Endpoint ---
# NOTE(review): this copy of the file had no route decorator on predict(), so
# Flask never registered it; the "/predict" POST route below is an assumption —
# confirm it matches what clients call.
@app.route("/predict", methods=["POST"])
def predict():
    """Classify the ``"text"`` field of the JSON request body.

    Returns:
        200 with ``{"text", "prediction", "confidence"}`` on success, where
        ``text`` echoes the original input, ``prediction`` is one of the two
        label strings and ``confidence`` is the winning class probability
        formatted as a string with four decimal places.
        400 with ``{"error": ...}`` when the body is not a JSON object, or
        ``"text"`` is missing, blank, or not a string.
        500 with ``{"error": ...}`` if cleaning, tokenization or inference fails.
    """
    # silent=True returns None for a missing/malformed JSON body instead of
    # raising, so bad input is reported as a 400 rather than a generic 500.
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if text is None or (isinstance(text, str) and not text.strip()):
        return jsonify({"error": "Text field is required"}), 400
    if not isinstance(text, str):
        # e.g. a number or list: clean_text() calls .lower() and would fail.
        return jsonify({"error": "Text field must be a string"}), 400

    try:
        cleaned_text = clean_text(text)
        # At most 128 tokens, returned as TensorFlow tensors.
        inputs = tokenizer(
            cleaned_text,
            return_tensors="tf",
            truncation=True,
            padding=True,
            max_length=128,
        )
        outputs = model(inputs)
        logits = outputs.logits
        # A single string was tokenized, so row 0 holds its class probabilities.
        probabilities = tf.nn.softmax(logits, axis=-1)[0].numpy()
        prediction = int(tf.argmax(logits, axis=-1).numpy()[0])
        # NOTE(review): assumes label id 0 = non-toxic, 1 = toxic, matching the
        # fine-tuning setup — confirm.
        labels = ['Non-Toxic', 'Toxic (Hate Speech/Offensive)']
        result_label = labels[prediction]
        confidence = float(probabilities[prediction])
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log.
        app.logger.exception("Prediction failed")
        # NOTE(review): str(e) exposes internal error details to the client;
        # consider returning a generic message instead.
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "text": text,
        "prediction": result_label,
        "confidence": f"{confidence:.4f}"
    })
# Run the app
if __name__ == '__main__':
    # Use 0.0.0.0 to make it accessible from other Docker containers.
    # The PORT environment variable, if set, overrides the default 5000 so the
    # listening port can be changed without editing code.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", 5000)))