baseerah-api / app.py
Reema Alharbi
first commit
4120fd2
import pickle, re, numpy as np, torch, os
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModel, AutoTokenizer as AT, AutoModelForSequenceClassification as AM
import joblib
# All inference runs on CPU; get_vector() moves the MARBERT inputs to this device.
device = torch.device("cpu")
# Arabic pipeline, stage 1: MARBERT encoder used to embed text (see get_vector).
# NOTE(review): "UBC-NLP/MARBERT" is a hub id, presumably resolved from the
# Hugging Face Hub or the local HF cache at import time — confirm for offline deploys.
marbert_tokenizer = AutoTokenizer.from_pretrained("UBC-NLP/MARBERT")
marbert_model = AutoModel.from_pretrained("UBC-NLP/MARBERT").to(device)
marbert_model.eval()  # evaluation mode (e.g. disables dropout) for inference
# Arabic pipeline, stage 2: the classifier applied to MARBERT vectors and the
# label encoder that maps its predictions back to class names.
# joblib files are pickle-based and can execute code when loaded, so these
# paths (relative to the working directory) must only ever hold trusted files.
svm_model = joblib.load("arabic_model/final_arabic_model.pkl")
le = joblib.load("arabic_model/final_label_encoder.pkl")
# English pipeline: tokenizer + sequence-classification model from a local directory.
# NOTE(review): en_model is not moved to `device`; harmless while device is CPU.
en_tokenizer = AT.from_pretrained("./english_models")
en_model = AM.from_pretrained("./english_models")
en_model.eval()
# Maps en_model's argmax logit index to a sentiment name.
# NOTE(review): assumes the model's output order is negative/neutral/positive —
# confirm against en_model.config.id2label.
EN_LABELS = ["negative", "neutral", "positive"]
def clean_arabic(text):
    """Normalize Arabic *text* the way the Arabic classifier expects.

    * hamza/madda alef forms (أ إ آ) become bare alef (ا)
    * alef maqsura (ى) becomes yeh (ي)
    * teh marbuta (ة) becomes heh (ه)
    * harakat/tanween U+064B..U+0652 are deleted
    * leading/trailing whitespace is stripped

    Every rule is a single code point mapped to one code point (or to nothing),
    so one translation table applies all of them in a single pass.
    """
    table = {ord(alef_form): "ا" for alef_form in "أإآ"}
    table[ord("ى")] = "ي"
    table[ord("ة")] = "ه"
    # None in a translate table deletes the character; range end is exclusive.
    table.update(dict.fromkeys(range(0x064B, 0x0653)))
    normalized = text.translate(table)
    return normalized.strip()
def get_vector(text):
    """Embed *text* with MARBERT by mean-pooling its last hidden layer.

    The input is tokenized (truncated to 128 tokens), moved to ``device`` and run
    through ``marbert_model`` without gradient tracking. The token embeddings are
    averaged over dimension 1 and returned as a NumPy array on the CPU.
    NOTE(review): for a single string this is presumably shape (1, hidden_size),
    which is what svm_model.predict is fed — confirm.
    """
    encoded = marbert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        hidden_states = marbert_model(**encoded).last_hidden_state
    pooled = hidden_states.mean(dim=1)
    return pooled.cpu().numpy()
def detect_lang(text):
    """Guess whether *text* should be routed to the Arabic or English model.

    Returns ``"arabic"`` when strictly more than 30% of all characters are in the
    Arabic Unicode block U+0600..U+06FF. Spaces, digits and punctuation count
    toward the total. Anything else, including the empty string, returns
    ``"english"``.
    """
    arabic_chars = sum(1 for ch in text if "\u0600" <= ch <= "\u06FF")
    # Keep the multiply-then-compare form so float boundary cases are unchanged.
    if arabic_chars > len(text) * 0.3:
        return "arabic"
    return "english"
# Flask application exposing the /predict and /health routes.
app = Flask(__name__)
# Accept cross-origin requests from any origin (presumably so a separately hosted
# front end can call the API). NOTE(review): "*" is fully open — restrict it to the
# known front-end origin(s) if there are any.
CORS(app, origins="*")
def _predict_arabic(text):
    """Classify Arabic *text* with the MARBERT-embedding + SVM pipeline.

    Returns the JSON-serialisable response payload. The label encoder's class
    names are normalised to "positive"/"negative"; any other class is reported
    as "neutral".
    """
    vec = get_vector(clean_arabic(text))
    pred = svm_model.predict(vec)[0]
    label = le.inverse_transform([pred])[0]
    if label in ["ايجابي", "إيجابي", "positive"]:
        sentiment = "positive"
    elif label in ["سلبي", "negative"]:
        sentiment = "negative"
    else:
        sentiment = "neutral"
    return {"sentiment": sentiment, "language": "arabic", "model": "MARBERT+SVM"}


def _predict_english(text):
    """Classify English *text* with the local sequence-classification model."""
    inputs = en_tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        outputs = en_model(**inputs)
    idx = torch.argmax(outputs.logits, dim=1).item()
    return {"sentiment": EN_LABELS[idx], "language": "english", "model": "RoBERTa"}


@app.route('/predict', methods=['POST', 'OPTIONS'])
def predict():
    """Return the sentiment of the ``text`` field of a JSON POST body.

    Request:  ``{"text": "<string>"}``
    Success:  200 ``{"sentiment": ..., "language": ..., "model": ...}``
    Errors:   400 ``{"error": ...}`` when the body is not a JSON object or
              ``text`` is missing, empty or not a string;
              500 ``{"error": ...}`` when model inference fails.
    OPTIONS requests get an empty 200 response.
    """
    if request.method == 'OPTIONS':
        return jsonify({}), 200

    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so bad client input becomes a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400
    text = data.get('text', '')
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string"}), 400
    text = text.strip()
    if not text:
        # Without this guard, empty input silently went to the English model.
        return jsonify({"error": "'text' must be a non-empty string"}), 400

    try:
        if detect_lang(text) == "arabic":
            return jsonify(_predict_arabic(text))
        return jsonify(_predict_english(text))
    except Exception as e:
        # Top-level request boundary: record the traceback server-side and keep
        # the existing 500 {"error": ...} response shape for clients.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: always answers ``{"status": "running"}``."""
    return jsonify(status="running")
if __name__ == '__main__':
    # Defaults keep the previous hard-coded bind address (0.0.0.0:7860).
    # HOST / PORT environment variables override them so the same file runs on
    # hosts that assign their own port.
    app.run(
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "7860")),
    )