Chatbot / scripts /code_templates /api_template.py.txt
rogerthat11's picture
push full mechanism
3d48e06
# Template for API integration script for {{phase_name}} (using Flask example)
import os

import torch # Example PyTorch
from flask import Flask, request, jsonify
from transformers import AutoModelForSequenceClassification, AutoTokenizer
app = Flask(__name__)

# --- Model and Tokenizer Loading ---
# Both locations can be overridden through environment variables so the same
# script can be pointed at a different checkpoint without editing the source.
# Replace the defaults with your actual model path / training tokenizer.
model_name = os.environ.get("MODEL_PATH", "models/fine_tuned_model")
# Tokenizer used for training; likely the base model's tokenizer.
tokenizer_name = os.environ.get("TOKENIZER_NAME", "bert-base-uncased")

try:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()  # Set model to evaluation mode before serving predictions
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    # Deliberate best-effort at startup: the app stays up (the health check
    # still answers) and /predict returns a 500 while these are None.
    # logger.exception also records the full traceback for debugging.
    app.logger.exception("Error loading model or tokenizer: %s", e)
    tokenizer = None
    model = None
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the ``text`` field of a JSON request body.

    Expects a JSON object such as ``{"text": "some input"}``.

    Returns:
        A ``(response, status)`` pair:
        - 200 with ``{"prediction": <label>, "class_id": <int>}`` on success;
        - 400 if the body is not a JSON object, or ``text`` is missing,
          empty, or not a string;
        - 500 if the model/tokenizer failed to load at startup or
          inference raised.
    """
    if tokenizer is None or model is None:
        return jsonify({"error": "Model or tokenizer not loaded."}), 500

    # silent=True returns None instead of raising on a missing or malformed
    # JSON body, so bad requests get a 400 here rather than ending up in the
    # generic 500 handler below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    text = data.get('text')
    if not text:
        return jsonify({"error": "No text input provided."}), 400
    if not isinstance(text, str):
        # A list would be tokenized as a batch, and .item() below only
        # supports a single prediction.
        return jsonify({"error": "'text' must be a string."}), 400

    try:
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():  # Inference only: no gradient tracking needed
            logits = model(**inputs).logits
        predicted_class_id = torch.argmax(logits, dim=-1).item()  # Get predicted class
    except Exception:
        # Request boundary: log the traceback server-side, but do not leak
        # internal details to the client.
        app.logger.exception("Prediction error")
        return jsonify({"error": "Error during prediction."}), 500

    # --- Map class ID to label (if applicable) ---
    # Example for binary classification (class 0 and 1)
    labels = ["Negative", "Positive"]  # Replace with your actual labels
    if predicted_class_id < len(labels):
        predicted_label = labels[predicted_class_id]
    else:
        predicted_label = f"Class {predicted_class_id}"
    return jsonify({"prediction": predicted_label, "class_id": predicted_class_id})
@app.route('/', methods=['GET'])
def health_check():
    """Liveness endpoint: always answers 200 (does not inspect model state)."""
    payload = {"status": "API is healthy"}
    return jsonify(payload), 200
if __name__ == '__main__':
    # Starts Flask's built-in development server; it is not intended for
    # production traffic (use a WSGI server there). HOST and PORT can be set
    # via environment variables; the defaults match the original template.
    app.run(
        debug=False,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "5000")),
    )