File size: 2,242 Bytes
3d48e06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Template for API integration script for {{phase_name}} (using Flask example)

from flask import Flask, request, jsonify
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch # Example PyTorch

app = Flask(__name__)

# --- Model and Tokenizer Loading ---
# Placeholders: point model_name at your fine-tuned checkpoint directory and
# tokenizer_name at the tokenizer used during training (usually the base model).
model_name = "models/fine_tuned_model" # Replace with your actual model path
tokenizer_name = "bert-base-uncased" # Replace with the tokenizer used for training, likely the base model tokenizer

# Default both to None so the /predict handler can detect a failed load.
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    print("Model and tokenizer loaded successfully.")
    model.eval() # evaluation mode: disables dropout etc. for inference
except Exception as e:
    # Keep the app importable even if loading fails; endpoints report 500.
    print(f"Error loading model or tokenizer: {e}")
    tokenizer = None
    model = None


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the posted text with the fine-tuned sequence-classification model.

    Expects a JSON body ``{"text": "..."}``. Returns JSON
    ``{"prediction": <label>, "class_id": <int>}`` on success,
    400 for a missing/invalid body, 500 for load or inference failures.
    """
    if not tokenizer or not model:
        return jsonify({"error": "Model or tokenizer not loaded."}), 500

    try:
        # silent=True makes a missing or malformed JSON body yield None
        # instead of raising, so the client gets a 400 rather than the
        # generic 500 from the except-block below.
        data = request.get_json(silent=True)
        text = data.get('text') if data else None

        if not text:
            return jsonify({"error": "No text input provided."}), 400

        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt") # Tokenize input text

        with torch.no_grad(): # Inference mode: no gradient bookkeeping
            outputs = model(**inputs)
            logits = outputs.logits
            predicted_class_id = torch.argmax(logits, dim=-1).item() # Get predicted class

        # --- Map class ID to label (if applicable) ---
        # Fall back to a generic "Class N" string for ids beyond the label list.
        labels = ["Negative", "Positive"] # Replace with your actual labels
        predicted_label = labels[predicted_class_id] if predicted_class_id < len(labels) else f"Class {predicted_class_id}"

        return jsonify({"prediction": predicted_label, "class_id": predicted_class_id})

    except Exception as e:
        # Boundary handler: log and return an opaque 500 to the client.
        print(f"Prediction error: {e}")
        return jsonify({"error": "Error during prediction."}), 500

@app.route('/', methods=['GET'])
def health_check():
    """Liveness probe: always reports the API as healthy with HTTP 200."""
    body = jsonify({"status": "API is healthy"})
    return body, 200


if __name__ == '__main__':
    # Dev entry point: bind to all interfaces on port 5000 with debug off.
    # For production, serve via a WSGI server (gunicorn/uwsgi) instead.
    app.run(debug=False, host='0.0.0.0', port=5000) # Run Flask app