Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import torch | |
| import logging | |
| import gc | |
| import sys | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Dict, List, Optional | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from tokenizers.normalizers import Sequence, Replace, Strip | |
| from tokenizers import Regex | |
# =====================================================
# Environment configuration and settings
# =====================================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# All Hugging Face / torch caches live under one writable tmp directory.
CACHE_DIR = "/tmp/huggingface_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point every HF/torch cache variable at the shared directory.
_hf_env = {
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HUGGINGFACE_HUB_CACHE": CACHE_DIR,
    "TORCH_HOME": CACHE_DIR,
    "TOKENIZERS_PARALLELISM": "false",  # avoid tokenizer threading issues
    "TRANSFORMERS_OFFLINE": "0",  # allow downloads from the hub
}
os.environ.update(_hf_env)

# PyTorch memory tuning (GPU only).
if torch.cuda.is_available():
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
    torch.backends.cudnn.benchmark = True

# =====================================================
# Device selection (GPU if present, else CPU)
# =====================================================
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info(f"🖥️ Using device: {device}")
if torch.cuda.is_available():
    logger.info(f"🎮 CUDA Device: {torch.cuda.get_device_name(0)}")
    logger.info(f"💾 CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
# =====================================================
# Label map for the classifier head
# =====================================================
# Ordered names for the 41 output classes; index 24 is the 'human' label.
_LABEL_NAMES = [
    '13B', '30B', '65B', '7B', 'GLM130B', 'bloom_7b',
    'bloomz', 'cohere', 'davinci', 'dolly', 'dolly-v2-12b',
    'flan_t5_base', 'flan_t5_large', 'flan_t5_small',
    'flan_t5_xl', 'flan_t5_xxl', 'gemma-7b-it', 'gemma2-9b-it',
    'gpt-3.5-turbo', 'gpt-35', 'gpt4', 'gpt4o',
    'gpt_j', 'gpt_neox', 'human', 'llama3-70b', 'llama3-8b',
    'mixtral-8x7b', 'opt_1.3b', 'opt_125m', 'opt_13b',
    'opt_2.7b', 'opt_30b', 'opt_350m', 'opt_6.7b',
    'opt_iml_30b', 'opt_iml_max_1.3b', 't0_11b', 't0_3b',
    'text-davinci-002', 'text-davinci-003',
]
label_mapping = dict(enumerate(_LABEL_NAMES))
| # ===================================================== | |
| # 🤖 Model Manager - إدارة الموديلات | |
| # ===================================================== | |
class ModelManager:
    """Owns the tokenizer and an ensemble of ModernBERT sequence classifiers.

    Models are loaded lazily via load_models(); predictions are made by
    classify_text(), which soft-votes (averages softmax probabilities)
    across all successfully loaded models.
    """

    def __init__(self):
        # Populated by load_tokenizer()/load_models(); plain attributes so the
        # manager can be constructed cheaply before any downloads happen.
        self.tokenizer = None
        self.models = []
        self.models_loaded = False
        # Remote checkpoints holding fine-tuned weights for the 41-class head.
        self.model_urls = [
            "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12",
            "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
        ]

    def load_tokenizer(self):
        """Load the ModernBERT tokenizer with error handling.

        Returns True on success, False on failure (errors are logged,
        not raised).
        """
        try:
            logger.info("📝 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "answerdotai/ModernBERT-base",
                cache_dir=CACHE_DIR,
                use_fast=True,
                trust_remote_code=False
            )
            # Extend the backend normalizer: re-join words hyphenated across
            # line breaks, then fold newlines into single spaces, then strip.
            # A failure here is non-fatal — the stock normalizer is kept.
            try:
                newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
                # NOTE(review): the character class [--] matches only '-';
                # it was probably meant to cover unicode hyphen variants —
                # confirm against the original source. Also confirm the
                # replacement syntax (\1 vs $1) expected by tokenizers'
                # Replace with a Regex pattern.
                join_hyphen_break = Replace(Regex(r'(\w+)[--]\s*\n\s*(\w+)'), r"\1\2")
                self.tokenizer.backend_tokenizer.normalizer = Sequence([
                    self.tokenizer.backend_tokenizer.normalizer,
                    join_hyphen_break,
                    newline_to_space,
                    Strip()
                ])
            except Exception as e:
                logger.warning(f"⚠️ Could not set custom normalizer: {e}")
            logger.info("✅ Tokenizer loaded successfully")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load tokenizer: {e}")
            return False

    def load_single_model(self, model_url=None, model_path=None, model_name="Model"):
        """Load one classifier, preferring local weights over a URL.

        Args:
            model_url: optional URL to a state_dict (fetched via torch.hub).
            model_path: optional local state_dict file; takes priority.
            model_name: label used in log messages and the cached file name.

        Returns the model moved to `device` and set to eval mode, or None on
        failure. If neither source yields weights, the randomly initialized
        base model is returned (logged as such).
        """
        try:
            logger.info(f"🤖 Loading {model_name}...")
            # Base architecture with a fresh 41-label classification head.
            base_model = AutoModelForSequenceClassification.from_pretrained(
                "answerdotai/ModernBERT-base",
                num_labels=41,
                cache_dir=CACHE_DIR,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=False
            )
            # Try to load fine-tuned weights; strict=False tolerates
            # head/name mismatches between checkpoint and architecture.
            if model_path and os.path.exists(model_path):
                logger.info(f"📁 Loading from local file: {model_path}")
                state_dict = torch.load(model_path, map_location=device, weights_only=True)
                base_model.load_state_dict(state_dict, strict=False)
            elif model_url:
                logger.info(f"🌐 Downloading weights from: {model_url}")
                try:
                    state_dict = torch.hub.load_state_dict_from_url(
                        model_url,
                        map_location=device,
                        progress=True,
                        check_hash=False,
                        file_name=f"{model_name}.pt"
                    )
                    base_model.load_state_dict(state_dict, strict=False)
                except Exception as url_error:
                    # Best-effort: keep the randomly initialized model rather
                    # than failing the whole load.
                    logger.warning(f"⚠️ Could not load weights from URL: {url_error}")
                    logger.info("📊 Using model with random initialization")
            else:
                logger.info("📊 Using model with random initialization")
            # Move to the selected device and freeze for inference.
            model = base_model.to(device)
            model.eval()
            # Free the (possibly large) state_dict before returning.
            if 'state_dict' in locals():
                del state_dict
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"✅ {model_name} loaded successfully")
            return model
        except Exception as e:
            logger.error(f"❌ Failed to load {model_name}: {e}")
            return None

    def load_models(self, max_models=2):
        """Load the tokenizer plus up to `max_models` classifiers.

        Tries a local "modernbert.bin" first, then the configured URLs.
        Stops early if GPU memory exceeds ~6GB. Idempotent: returns True
        immediately if models are already loaded. Returns False if the
        tokenizer fails or no model could be loaded.
        """
        if self.models_loaded:
            logger.info("✨ Models already loaded")
            return True
        # Tokenizer is required before any model is useful.
        if not self.load_tokenizer():
            return False
        logger.info(f"🚀 Loading up to {max_models} models...")
        # Prefer a local checkpoint if one is present in the working dir.
        local_model_path = "modernbert.bin"
        if os.path.exists(local_model_path):
            model = self.load_single_model(
                model_path=local_model_path,
                model_name="Model 1 (Local)"
            )
            if model is not None:
                self.models.append(model)
        # Fill the remaining slots from the remote URLs.
        for i, url in enumerate(self.model_urls[:max_models - len(self.models)]):
            if len(self.models) >= max_models:
                break
            model = self.load_single_model(
                model_url=url,
                model_name=f"Model {len(self.models) + 1}"
            )
            if model is not None:
                self.models.append(model)
            # Check GPU memory after each load and bail out if near the cap.
            if torch.cuda.is_available():
                mem_allocated = torch.cuda.memory_allocated() / 1024**3
                mem_reserved = torch.cuda.memory_reserved() / 1024**3
                logger.info(f"💾 GPU Memory: {mem_allocated:.2f}GB allocated, {mem_reserved:.2f}GB reserved")
                # Hard ceiling of 6GB allocated.
                if mem_allocated > 6:
                    logger.warning("⚠️ Memory limit reached, stopping model loading")
                    break
        # Success means at least one model made it.
        if len(self.models) > 0:
            self.models_loaded = True
            logger.info(f"✅ Successfully loaded {len(self.models)} models")
            return True
        else:
            logger.error("❌ No models could be loaded")
            return False

    def classify_text(self, text: str) -> Dict:
        """Classify `text` and return human/AI percentages plus top models.

        Raises ValueError if no models are loaded, the cleaned text is
        empty, tokenization fails, or every model fails to predict.
        """
        if not self.models_loaded or len(self.models) == 0:
            raise ValueError("No models loaded")
        # Normalize whitespace/punctuation spacing before tokenizing.
        cleaned_text = clean_text(text)
        if not cleaned_text.strip():
            raise ValueError("Empty text after cleaning")
        # Tokenize, truncated to the 512-token model limit.
        try:
            inputs = self.tokenizer(
                cleaned_text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            ).to(device)
        except Exception as e:
            logger.error(f"Tokenization error: {e}")
            raise ValueError(f"Failed to tokenize text: {e}")
        # Collect per-model softmax distributions; a single model failing
        # does not abort the ensemble.
        all_probabilities = []
        with torch.no_grad():
            for i, model in enumerate(self.models):
                try:
                    logits = model(**inputs).logits
                    probs = torch.softmax(logits, dim=1)
                    all_probabilities.append(probs)
                except Exception as e:
                    logger.warning(f"Model {i+1} prediction failed: {e}")
                    continue
        if not all_probabilities:
            raise ValueError("All models failed to make predictions")
        # Soft voting: mean of the per-model probability distributions.
        averaged_probs = torch.mean(torch.stack(all_probabilities), dim=0)
        probabilities = averaged_probs[0]
        # Index 24 is the 'human' class (see label_mapping); everything else
        # is some AI model.
        human_prob = probabilities[24].item()
        ai_probs = probabilities.clone()
        ai_probs[24] = 0  # zero out the human probability
        ai_total_prob = ai_probs.sum().item()
        # Normalize so human + AI percentages sum to 100.
        total = human_prob + ai_total_prob
        if total > 0:
            human_percentage = (human_prob / total) * 100
            ai_percentage = (ai_total_prob / total) * 100
        else:
            human_percentage = 50
            ai_percentage = 50
        # Most likely AI model (human class already zeroed out).
        ai_model_idx = torch.argmax(ai_probs).item()
        predicted_model = label_mapping.get(ai_model_idx, "Unknown")
        # Top 5 classes overall (may include 'human').
        top_5_probs, top_5_indices = torch.topk(probabilities, 5)
        top_5_results = []
        for prob, idx in zip(top_5_probs, top_5_indices):
            top_5_results.append({
                "model": label_mapping.get(idx.item(), "Unknown"),
                "probability": round(prob.item() * 100, 2)
            })
        return {
            "human_percentage": round(human_percentage, 2),
            "ai_percentage": round(ai_percentage, 2),
            "predicted_model": predicted_model,
            "top_5_predictions": top_5_results,
            "is_human": human_percentage > ai_percentage,
            "models_used": len(all_probabilities)
        }
| # ===================================================== | |
| # 🧹 دوال التنظيف والمعالجة | |
| # ===================================================== | |
def clean_text(text: str) -> str:
    """Collapse runs of whitespace and remove spaces before punctuation."""
    collapsed = re.sub(r'\s{2,}', ' ', text)
    tidied = re.sub(r'\s+([,.;:?!])', r'\1', collapsed)
    return tidied.strip()
def split_into_paragraphs(text: str) -> List[str]:
    """Split text on blank lines into stripped, non-empty paragraphs."""
    pieces = (piece.strip() for piece in re.split(r'\n\s*\n', text.strip()))
    return [piece for piece in pieces if piece]
| # ===================================================== | |
| # 🌐 FastAPI Application | |
| # ===================================================== | |
# FastAPI application instance.
app = FastAPI(
    title="ModernBERT AI Text Detector",
    description="كشف النصوص المكتوبة بواسطة الذكاء الاصطناعي",
    version="2.0.0",
)

# Permissive CORS so the API can be called straight from a browser.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Single shared model manager for the whole process.
model_manager = ModelManager()
| # ===================================================== | |
| # 📝 نماذج البيانات (Pydantic Models) | |
| # ===================================================== | |
class TextInput(BaseModel):
    """Request body for the full /analyze endpoint."""
    # Raw input text to classify.
    text: str
    # When True, also score each paragraph individually (only applied when
    # the overall AI percentage exceeds 50).
    analyze_paragraphs: Optional[bool] = False
class SimpleTextInput(BaseModel):
    """Request body for the simplified /analyze-simple endpoint."""
    # Raw input text to classify.
    text: str
class DetectionResult(BaseModel):
    """Envelope returned by /analyze (success flag + status code + payload)."""
    # True when analysis completed; False for validation/server errors.
    success: bool
    # HTTP-style status code mirrored in the body (200, 400, 503, 500).
    code: int
    # Human-readable status message.
    message: str
    # Analysis payload; empty dict on failure.
    data: Dict
| # ===================================================== | |
| # 🎯 API Endpoints | |
| # ===================================================== | |
@app.on_event("startup")
async def startup_event():
    """Load the models when the application starts.

    Registered as a FastAPI startup hook — without the decorator the
    function was never invoked. A load failure is logged but does not
    abort startup; endpoints retry loading lazily.
    """
    logger.info("=" * 50)
    logger.info("🚀 Starting ModernBERT AI Detector...")
    logger.info(f"🐍 Python version: {sys.version}")
    logger.info(f"🔥 PyTorch version: {torch.__version__}")
    logger.info("=" * 50)
    # Model count is configurable via the MAX_MODELS env var (default 2).
    max_models = int(os.environ.get("MAX_MODELS", "2"))
    success = model_manager.load_models(max_models=max_models)
    if success:
        logger.info("✅ Application ready!")
    else:
        logger.error("⚠️ Failed to load models - API will return errors")
@app.get("/")
async def root():
    """Service landing endpoint: status, device, and available routes.

    Registered on "/" — without the decorator the handler was never
    reachable (the endpoint list in the response documents the routes).
    """
    return {
        "message": "ModernBERT AI Text Detector API",
        "status": "online" if model_manager.models_loaded else "initializing",
        "models_loaded": len(model_manager.models),
        "device": str(device),
        "endpoints": {
            "analyze": "/analyze",
            "simple": "/analyze-simple",
            "health": "/health",
            "docs": "/docs"
        }
    }
@app.get("/health")
async def health_check():
    """Health probe: model/load status plus GPU memory usage when available.

    Registered on "/health" (the path the root endpoint advertises) —
    without the decorator the handler was never reachable.
    """
    memory_info = {}
    if torch.cuda.is_available():
        memory_info = {
            "gpu_allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
            "gpu_reserved_gb": round(torch.cuda.memory_reserved() / 1024**3, 2)
        }
    return {
        "status": "healthy" if model_manager.models_loaded else "unhealthy",
        "models_loaded": len(model_manager.models),
        "device": str(device),
        "cuda_available": torch.cuda.is_available(),
        "memory_info": memory_info
    }
@app.post("/analyze")
async def analyze_text(data: TextInput):
    """Full AI-detection analysis (mirrors the original Gradio classify_text).

    Registered on "/analyze" (the path advertised by the root endpoint) —
    without the decorator the handler was never reachable. Returns a
    DetectionResult envelope; errors are reported in-band via its
    success/code fields rather than as HTTP errors.
    """
    try:
        # Validate the input text.
        text = data.text.strip()
        if not text:
            return DetectionResult(
                success=False,
                code=400,
                message="Empty input text",
                data={}
            )
        # Lazily load models if the startup hook didn't (or failed).
        if not model_manager.models_loaded:
            if not model_manager.load_models():
                return DetectionResult(
                    success=False,
                    code=503,
                    message="Models not available",
                    data={}
                )
        # Whitespace-delimited word count for the word-level stats below.
        total_words = len(text.split())
        # Whole-text classification.
        result = model_manager.classify_text(text)
        ai_percentage = result["ai_percentage"]
        human_percentage = result["human_percentage"]
        ai_words = int(total_words * (ai_percentage / 100))
        # Optional per-paragraph breakdown, only when the overall verdict
        # already leans AI.
        paragraphs_analysis = []
        if data.analyze_paragraphs and ai_percentage > 50:
            paragraphs = split_into_paragraphs(text)
            recalc_ai_words = 0
            recalc_total_words = 0
            for para in paragraphs[:10]:  # cap at 10 paragraphs
                if para.strip():
                    try:
                        para_result = model_manager.classify_text(para)
                        para_words = len(para.split())
                        recalc_total_words += para_words
                        recalc_ai_words += para_words * (para_result["ai_percentage"] / 100)
                        paragraphs_analysis.append({
                            "paragraph": para[:200] + "..." if len(para) > 200 else para,
                            "ai_generated_score": para_result["ai_percentage"] / 100,
                            "human_written_score": para_result["human_percentage"] / 100,
                            "predicted_model": para_result["predicted_model"]
                        })
                    except Exception as e:
                        logger.warning(f"Failed to analyze paragraph: {e}")
            # Re-derive the overall percentages as a word-weighted average of
            # the paragraph scores.
            if recalc_total_words > 0:
                ai_percentage = round((recalc_ai_words / recalc_total_words) * 100, 2)
                human_percentage = round(100 - ai_percentage, 2)
                ai_words = int(recalc_ai_words)
        # Human-readable verdict string.
        if ai_percentage > 50:
            feedback = "Most of Your Text is AI/GPT Generated"
        else:
            feedback = "Most of Your Text Appears Human-Written"
        # Response shape matches the original (pre-FastAPI) API contract.
        return DetectionResult(
            success=True,
            code=200,
            message="analysis completed",
            data={
                "fakePercentage": ai_percentage,
                "isHuman": human_percentage,
                "textWords": total_words,
                "aiWords": ai_words,
                "paragraphs": paragraphs_analysis,
                "predicted_model": result["predicted_model"],
                "feedback": feedback,
                "input_text": text[:500] + "..." if len(text) > 500 else text,
                "detected_language": "en",
                "top_5_predictions": result.get("top_5_predictions", []),
                "models_used": result.get("models_used", 1)
            }
        )
    except Exception as e:
        logger.error(f"Analysis error: {e}", exc_info=True)
        return DetectionResult(
            success=False,
            code=500,
            message=f"Analysis failed: {str(e)}",
            data={}
        )
@app.post("/analyze-simple")
async def analyze_simple(data: SimpleTextInput):
    """Simplified analysis — returns only the core verdict fields.

    Registered on "/analyze-simple" (the path advertised by the root
    endpoint) — without the decorator the handler was never reachable.
    Unlike /analyze, errors here are raised as proper HTTPExceptions.
    """
    try:
        text = data.text.strip()
        if not text:
            raise HTTPException(status_code=400, detail="Empty text")
        # Lazily load models if they aren't ready yet.
        if not model_manager.models_loaded:
            if not model_manager.load_models():
                raise HTTPException(status_code=503, detail="Models not available")
        result = model_manager.classify_text(text)
        return {
            "is_ai": result["ai_percentage"] > 50,
            "ai_score": result["ai_percentage"],
            "human_score": result["human_percentage"],
            # Only name a model when the verdict is AI.
            "detected_model": result["predicted_model"] if result["ai_percentage"] > 50 else None,
            "confidence": max(result["ai_percentage"], result["human_percentage"])
        }
    except HTTPException:
        # Re-raise our own HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"Simple analysis error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
| # ===================================================== | |
| # 🏃 تشغيل التطبيق | |
| # ===================================================== | |
# =====================================================
# Application entry point
# =====================================================
if __name__ == "__main__":
    import uvicorn
    # Runtime settings from the environment (with local-dev defaults).
    port = int(os.environ.get("PORT", 8000))
    host = os.environ.get("HOST", "0.0.0.0")
    workers = int(os.environ.get("WORKERS", 1))
    logger.info("=" * 50)
    logger.info(f"🌐 Starting server on {host}:{port}")
    logger.info(f"👷 Workers: {workers}")
    logger.info(f"📚 Documentation: http://{host}:{port}/docs")
    logger.info("=" * 50)
    # NOTE(review): no uvicorn.run(...) call is visible in this chunk —
    # confirm it follows below; otherwise the server is never started.