"""Flask service that predicts sentiment for Arabic or English text.

Arabic text is embedded with the ``UBC-NLP/MARBERT`` encoder and classified
by an estimator loaded from ``arabic_model/final_arabic_model.pkl``.  All
other text is classified by the sequence-classification model stored in
``./english_models`` (reported to clients as "RoBERTa").

All models are loaded once, at import time.
"""
import os
import pickle
import re

import joblib
import numpy as np
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModel, AutoTokenizer as AT, AutoModelForSequenceClassification as AM

device = torch.device("cpu")

# --- Arabic pipeline: MARBERT embeddings followed by a pickled classifier ---
marbert_tokenizer = AutoTokenizer.from_pretrained("UBC-NLP/MARBERT")
marbert_model = AutoModel.from_pretrained("UBC-NLP/MARBERT").to(device)
marbert_model.eval()

# NOTE(security): joblib.load unpickles arbitrary objects; these files must
# come from a trusted source.
svm_model = joblib.load("arabic_model/final_arabic_model.pkl")
le = joblib.load("arabic_model/final_label_encoder.pkl")

# --- English pipeline: fine-tuned sequence classifier ---
en_tokenizer = AT.from_pretrained("./english_models")
en_model = AM.from_pretrained("./english_models").to(device)
en_model.eval()

# Index -> label for the English model's logits.
# NOTE(review): assumes the model was trained with exactly this class order.
EN_LABELS = ["negative", "neutral", "positive"]

# Patterns are compiled once instead of on every request.
_ALEF_VARIANTS = re.compile(r'[أإآ]')
_ALEF_MAKSURA = re.compile(r'ى')
_TEH_MARBUTA = re.compile(r'ة')
_DIACRITICS = re.compile(r'[\u064B-\u0652]')
_ARABIC_CHARS = re.compile(r'[\u0600-\u06FF]')

# Label-encoder outputs that map onto the canonical English labels; any other
# decoded label is reported as "neutral".
_POSITIVE_LABELS = frozenset(["ايجابي", "إيجابي", "positive"])
_NEGATIVE_LABELS = frozenset(["سلبي", "negative"])


def clean_arabic(text: str) -> str:
    """Normalize Arabic text before it is embedded.

    Collapses the alef variants to a bare alef, alef maksura to yeh and
    teh marbuta to heh, removes the characters U+064B..U+0652, and strips
    surrounding whitespace.
    """
    text = _ALEF_VARIANTS.sub('ا', text)
    text = _ALEF_MAKSURA.sub('ي', text)
    text = _TEH_MARBUTA.sub('ه', text)
    text = _DIACRITICS.sub('', text)
    return text.strip()


def get_vector(text: str) -> np.ndarray:
    """Return the mean of MARBERT's last hidden states for *text*.

    The input is truncated to 128 tokens.  The mean is taken over the token
    dimension (``dim=1``) and the result is returned as a NumPy array on the
    CPU, ready to be passed to ``svm_model.predict``.
    """
    inputs = marbert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        output = marbert_model(**inputs)
    # Unmasked mean: fine for one text per call; a batched caller would also
    # average over padding positions.
    return output.last_hidden_state.mean(dim=1).cpu().numpy()


def detect_lang(text: str) -> str:
    """Return ``"arabic"`` if more than 30% of *text*'s characters fall in
    U+0600..U+06FF, otherwise ``"english"``.

    Every character (including whitespace and punctuation) counts towards
    the length, so an empty string is reported as ``"english"``.
    """
    arabic_count = len(_ARABIC_CHARS.findall(text))
    return "arabic" if arabic_count > len(text) * 0.3 else "english"


def _predict_arabic(text: str) -> str:
    """Classify Arabic *text* and return a canonical English label."""
    vec = get_vector(clean_arabic(text))
    pred = svm_model.predict(vec)[0]
    label = le.inverse_transform([pred])[0]
    if label in _POSITIVE_LABELS:
        return "positive"
    if label in _NEGATIVE_LABELS:
        return "negative"
    return "neutral"


def _predict_english(text: str) -> str:
    """Classify non-Arabic *text* with the English model."""
    # Inputs go to the same device as the model, matching the Arabic path.
    inputs = en_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        outputs = en_model(**inputs)
    idx = torch.argmax(outputs.logits, dim=1).item()
    return EN_LABELS[idx]


app = Flask(__name__)
CORS(app, origins="*")


@app.route('/predict', methods=['POST', 'OPTIONS'])
def predict():
    """Predict the sentiment of the ``text`` field of a JSON request body.

    Responses:
        200: ``{"sentiment", "language", "model"}`` on success (an
            ``OPTIONS`` request gets an empty JSON object).
        400: ``{"error": ...}`` when the body is not a JSON object or
            ``text`` is not a string.
        500: ``{"error": ...}`` when inference itself fails.

    A missing or empty ``text`` is still classified (as English), as before.
    """
    if request.method == 'OPTIONS':
        return jsonify({}), 200

    # silent=True yields None for a missing, non-JSON or malformed body so
    # that client mistakes are answered with 400 rather than a generic 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get('text', '')
    if not isinstance(text, str):
        return jsonify({"error": "'text' must be a string"}), 400
    text = text.strip()

    try:
        if detect_lang(text) == "arabic":
            return jsonify({
                "sentiment": _predict_arabic(text),
                "language": "arabic",
                "model": "MARBERT+SVM",
            })
        return jsonify({
            "sentiment": _predict_english(text),
            "language": "english",
            "model": "RoBERTa",
        })
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log while
        # returning the same error payload clients already receive.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500


@app.route('/health', methods=['GET'])
def health():
    """Liveness probe; always reports that the service is running."""
    return jsonify({"status": "running"})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)