# Copied from the Hugging Face Space file view: uploaded by Vo Nhu Tu Anh
# ("Upload 6 files"), commit d76c0bc (verified), file size 2.74 kB.
import platform
import pickle
import os # to read environment variables
from flask import Flask, jsonify, request
from flask_cors import CORS
# The deployment pickle references classes defined in sentiment_deploy_ensemble.py.
# Importing the module here makes those classes resolvable when pickle.load
# runs in this (separate) process. Ship sentiment_deploy_ensemble.py alongside app.py.
import sentiment_deploy_ensemble # noqa: F401 (needed for unpickling)
# ////////////////////////////////////////////////////////////////////////
# TODO: Adjust the code part below
# ////////////////////////////////////////////////////////////////////////
# Identification settings; all three are echoed back in the API's metadata.
GROUP_ID: str = "Modelling-Giants"
# Relative path to the pickled model file, expected at the Space root.
MODEL_FILE: str = "route_c_ensemble_fp16.model"
MODEL_VERSION: str = "v1.0-ensemble"
# Adjust the function below so that it calls your vectorizer and
# classifier functions packaged in the .model file.
def batch_predict(model, items):
    """Predict a sentiment label for every item in a batch.

    Args:
        model: Mapping loaded from the .model pickle. It must provide a
            ``'vectorizer'`` exposing ``transform(texts)`` and a
            ``'classifier'`` exposing ``predict(X)``.
        items: List of dicts, each carrying an ``'id'`` and a ``'text'`` key.

    Returns:
        A list of ``{"id": ..., "label": int}`` dicts in the same order as
        ``items``. An empty ``items`` yields ``[]`` without touching the model.

    Raises:
        KeyError: if an item lacks ``'id'`` or ``'text'``.
        ValueError: if the classifier returns a different number of labels
            than there are items. Without this check ``zip`` would silently
            drop the unmatched items from the response.
    """
    # Nothing to score: skip the model entirely. Some estimators (e.g.
    # scikit-learn ones) raise on zero-sample input.
    if not items:
        return []

    # Pull the batch of texts once (the ensemble classifier batches internally).
    texts = [item['text'] for item in items]
    # 'vectorizer' is a pass-through that applies the same light cleaning used
    # at training time; per-model tokenisation happens inside the classifier.
    X = model['vectorizer'].transform(texts)
    # 'classifier' returns one integer API label in {-1, 0, 1} per input,
    # in the same order as `texts`. list() materialises whatever sequence or
    # iterator it returns so the length can be checked.
    labels = list(model['classifier'].predict(X))
    if len(labels) != len(items):
        raise ValueError(
            f"classifier returned {len(labels)} labels for {len(items)} items"
        )

    # int() normalises each label to a plain Python int for JSON output
    # (e.g. if the classifier yields numpy integer scalars).
    return [
        {"id": item['id'], "label": int(label)}
        for item, label in zip(items, labels)
    ]
# ////////////////////////////////////////////////////////////////////////
# You should not modify the code below.
# ////////////////////////////////////////////////////////////////////////
app = Flask(__name__)  # set up app
CORS(app)  # set up CORS policies

# load model file
# Resolve MODEL_FILE against this script's directory rather than the process's
# current working directory, so startup does not depend on where the server is
# launched from. os.path.join returns MODEL_FILE unchanged if it is absolute.
_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), MODEL_FILE)
# NOTE: pickle.load can execute arbitrary code embedded in the file; only ever
# load a model file you produced yourself.
with open(_model_path, 'rb') as file:
    model = pickle.load(file)

# define meta-data for API (served by the GET handler)
meta_data = {
    "groupID": GROUP_ID,
    "modelFile": MODEL_FILE,
    "modelVersion": MODEL_VERSION,
    "pythonVersion": platform.python_version(),
}
# api route
@app.route("/", methods=['GET', 'POST'])
def main():
    """Serve batch predictions (POST) or API metadata (GET).

    POST expects a JSON body ``{"items": [{"id": ..., "text": ...}, ...]}``
    and responds with ``{"items": [{"id": ..., "label": ...}, ...]}``.
    Any other method allowed by the route (GET) responds with
    ``{"meta": meta_data}``.

    A malformed POST body gets a 400 response carrying an ``error`` message,
    instead of an unhandled exception that Flask would turn into a 500.
    """
    if request.method != 'POST':
        return jsonify({"meta": meta_data})  # meta data

    # silent=True returns None (instead of raising) when the body is missing,
    # is not valid JSON, or is not sent with a JSON content type.
    payload = request.get_json(silent=True)
    items = payload.get('items') if isinstance(payload, dict) else None
    if not isinstance(items, list) or not all(
        isinstance(item, dict) and 'id' in item and 'text' in item
        for item in items
    ):
        return jsonify({
            "error": 'expected a JSON body of the form '
                     '{"items": [{"id": ..., "text": ...}, ...]}'
        }), 400

    return jsonify({"items": batch_predict(model, items)})  # batch predictions
# start the api server when running the script
if __name__ == "__main__":
# Look for a cloud-provided port variable (e.g., 7860 on Hugging Face),
# but fallback to 8000 for local execution.
port = int(os.environ.get("PORT", 8000))
app.run(host="0.0.0.0", port=port, debug=True)