# File size: 2,242 Bytes
# Revision: 3d48e06
# Template for API integration script for {{phase_name}} (using Flask example)
import os

import torch # Example PyTorch
from flask import Flask, request, jsonify
from transformers import AutoModelForSequenceClassification, AutoTokenizer
app = Flask(__name__)

# --- Model and Tokenizer Loading ---
# Both locations can be overridden through environment variables so the same
# script can serve another checkpoint without code edits; the defaults are the
# template's original values.
model_name = os.environ.get("MODEL_PATH", "models/fine_tuned_model")  # Replace with your actual model path
tokenizer_name = os.environ.get("TOKENIZER_NAME", "bert-base-uncased")  # Use the tokenizer the model was trained with, likely the base model's tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()  # Put dropout/normalization layers into inference behavior
    print("Model and tokenizer loaded successfully.")
except Exception:
    # Deliberately best-effort: the module stays importable so the app can
    # still start, and /predict answers 500 while these are None. Log the
    # full traceback instead of printing only the exception message.
    app.logger.exception("Error loading model or tokenizer")
    tokenizer = None
    model = None
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the ``text`` field of a JSON request body.

    Expects a JSON object such as ``{"text": "some input"}``.

    Returns:
        200 with ``{"prediction": <label>, "class_id": <int>}`` on success.
        400 with ``{"error": ...}`` when the body is not a JSON object, or
        when ``text`` is missing, empty, or not a string.
        500 with ``{"error": ...}`` when the model/tokenizer were not loaded
        or inference raised.
    """
    if tokenizer is None or model is None:
        return jsonify({"error": "Model or tokenizer not loaded."}), 500

    # Without silent=True, Flask raises BadRequest / UnsupportedMediaType for
    # a malformed or non-JSON body. The broad handler below would then report
    # a client mistake as a server error. With silent=True we get None instead
    # and can answer 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    text = data.get('text')
    if not text:
        return jsonify({"error": "No text input provided."}), 400
    if not isinstance(text, str):
        # A list would be tokenized as a batch and break the .item() call below;
        # other types make the tokenizer raise.
        return jsonify({"error": "'text' must be a string."}), 400

    try:
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")  # Tokenize input text
        with torch.no_grad():  # Inference only: no autograd bookkeeping
            outputs = model(**inputs)
        predicted_class_id = torch.argmax(outputs.logits, dim=-1).item()  # Get predicted class
    except Exception:
        # Request boundary: keep the traceback in the log, return a generic error.
        app.logger.exception("Prediction error")
        return jsonify({"error": "Error during prediction."}), 500

    # --- Map class ID to label (if applicable) ---
    # Example for binary classification (class 0 and 1)
    labels = ["Negative", "Positive"]  # Replace with your actual labels
    if predicted_class_id < len(labels):
        predicted_label = labels[predicted_class_id]
    else:
        predicted_label = f"Class {predicted_class_id}"
    return jsonify({"prediction": predicted_label, "class_id": predicted_class_id})
@app.route("/", methods=["GET"])
def health_check():
    """Report that the API process is up and serving requests.

    Always answers HTTP 200 with a fixed status message. It does not check
    whether the model and tokenizer loaded successfully.
    """
    payload = {"status": "API is healthy"}
    return jsonify(payload), 200
if __name__ == "__main__":
    # Start Flask's built-in development server on every interface (0.0.0.0),
    # port 5000, with the debugger off. For production, serve `app` from a
    # WSGI server instead.
    app.run(host="0.0.0.0", port=5000, debug=False)