| """ | |
| Flask API with BETTER Translation Models + Redis Cache | |
| Uses high-quality models for both directions with Redis caching | |
| """ | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import os | |
| import re | |
| import time | |
| import warnings | |
| from functools import lru_cache | |
| from typing import List | |
| from flask import Flask, jsonify, request | |
| from flask_caching import Cache | |
| warnings.filterwarnings("ignore") | |
| app = Flask(__name__) | |
# Redis Cache configuration
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
CACHE_DEFAULT_TIMEOUT = int(os.getenv("CACHE_DEFAULT_TIMEOUT", 3600))
cache_config = {
    "CACHE_TYPE": "RedisCache",
    "CACHE_REDIS_URL": REDIS_URL,  # use rediss:// for TLS endpoints such as Upstash
    "CACHE_DEFAULT_TIMEOUT": CACHE_DEFAULT_TIMEOUT,
    "CACHE_KEY_PREFIX": "transliteration:",
}

print("📦 Redis Configuration:")
print(f"   URL: {REDIS_URL.split('@')[-1]}")  # print only the endpoint, never credentials
print(f"   Timeout: {CACHE_DEFAULT_TIMEOUT}s")
try:
    cache = Cache(app, config=cache_config)
    with app.app_context():
        cache.set("ping", "pong", timeout=10)
        if cache.get("ping") == "pong":
            print("✅ Upstash Redis connected successfully")
        else:
            raise Exception("Ping test failed")
except Exception as e:
    print(f"⚠️ Redis connection failed: {e}")
    print("   Falling back to SimpleCache")
    cache = Cache(
        app,
        config={
            "CACHE_TYPE": "SimpleCache",
            "CACHE_DEFAULT_TIMEOUT": CACHE_DEFAULT_TIMEOUT,
            "CACHE_THRESHOLD": 10000,
        },
    )
# Global variables
ar_en_model = None
en_ar_model = None
ar_en_tokenizer = None
en_ar_tokenizer = None
device = None
USE_TRANSFORMERS = True

# Which models to use
MODEL_CHOICE = os.getenv("MODEL_CHOICE", "opus-big")  # Options: "opus-big", "marefa"

# Performance metrics
metrics = {
    "cache_hits": 0,
    "cache_misses": 0,
    "model_requests": 0,
    "total_requests": 0,
    "avg_response_time": 0.0,
}

print("🚀 Starting Transliteration API with Better Models + Redis Cache...")
def load_models():
    """Load the higher-quality translation models."""
    global \
        ar_en_model, \
        en_ar_model, \
        ar_en_tokenizer, \
        en_ar_tokenizer, \
        device, \
        USE_TRANSFORMERS
    try:
        import torch
        from transformers import MarianMTModel, MarianTokenizer

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🖥️ Using device: {device}")

        if MODEL_CHOICE == "opus-big":
            # OPTION 1: OPUS Big models (best quality, recommended)
            print("📥 Loading OPUS-MT-BIG models (high quality)...")

            # Arabic to English (OPUS Big)
            print("   Loading AR→EN (opus-mt-tc-big-ar-en)...")
            ar_en_model_name = "Helsinki-NLP/opus-mt-tc-big-ar-en"
            ar_en_tokenizer = MarianTokenizer.from_pretrained(ar_en_model_name)
            ar_en_model = MarianMTModel.from_pretrained(ar_en_model_name).to(device)
            ar_en_model.eval()
            print("   ✅ AR→EN loaded")

            # English to Arabic (OPUS Big) - much better than the basic model
            print("   Loading EN→AR (opus-mt-tc-big-en-ar)...")
            en_ar_model_name = "Helsinki-NLP/opus-mt-tc-big-en-ar"
            en_ar_tokenizer = MarianTokenizer.from_pretrained(en_ar_model_name)
            en_ar_model = MarianMTModel.from_pretrained(en_ar_model_name).to(device)
            en_ar_model.eval()
            print("   ✅ EN→AR loaded")
            print("🎉 OPUS-MT-BIG models loaded successfully!")
        elif MODEL_CHOICE == "marefa":
            # OPTION 2: Marefa model (specialized for Arabic)
            print("📥 Loading Marefa models (Arabic-specialized)...")

            # Arabic to English (OPUS Big - still best for this direction)
            print("   Loading AR→EN (opus-mt-tc-big-ar-en)...")
            ar_en_model_name = "Helsinki-NLP/opus-mt-tc-big-ar-en"
            ar_en_tokenizer = MarianTokenizer.from_pretrained(ar_en_model_name)
            ar_en_model = MarianMTModel.from_pretrained(ar_en_model_name).to(device)
            ar_en_model.eval()
            print("   ✅ AR→EN loaded")

            # English to Arabic (Marefa - Arabic-specialized)
            print("   Loading EN→AR (marefa-mt-en-ar)...")
            en_ar_model_name = "marefa-nlp/marefa-mt-en-ar"
            en_ar_tokenizer = MarianTokenizer.from_pretrained(en_ar_model_name)
            en_ar_model = MarianMTModel.from_pretrained(en_ar_model_name).to(device)
            en_ar_model.eval()
            print("   ✅ EN→AR loaded (Marefa)")
            print("🎉 Marefa models loaded successfully!")
        else:
            raise ValueError(f"Unknown MODEL_CHOICE: {MODEL_CHOICE}")

        USE_TRANSFORMERS = True
        return True
    except Exception as e:
        print(f"⚠️ Error loading models: {str(e)}")
        print("💡 Make sure you have enough memory and an internet connection")
        USE_TRANSFORMERS = False
        return False
@lru_cache(maxsize=10000)
def normalize_arabic(text: str) -> str:
    """Normalize Arabic text (cached in memory)."""
    text = re.sub(r"[\u064B-\u065F]", "", text)  # strip diacritics (tashkeel)
    text = text.replace("أ", "ا").replace("إ", "ا").replace("آ", "ا").replace("ٱ", "ا")
    text = text.replace("ى", "ي")
    text = text.replace("ة", "ه")
    text = text.replace("ؤ", "و")
    text = text.replace("ئ", "ي")
    return text.strip()
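# Example (sketch): normalize_arabic("أحمدٌ") -> "احمد"
# (the tanween U+064C falls inside the stripped \u064B-\u065F range, and the
# alef-hamza is unified to a bare alef).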
@lru_cache(maxsize=10000)
def get_phonetic_variants(text: str) -> List[str]:
    """Get phonetic variants (cached in memory)."""
    variants = {text.lower()}
    base = text.lower()
    patterns = [
        (r"a", "e"),
        (r"e", "a"),
        (r"een$", "ain"),
        (r"ain$", "een"),
        (r"(.)\1", r"\1"),  # collapse doubled letters
    ]
    for pattern, replacement in patterns:
        if re.search(pattern, base):
            variant = re.sub(pattern, replacement, base)
            if variant != base and len(variant) >= 2:
                variants.add(variant)
    return list(variants)[:5]
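# Example (sketch): get_phonetic_variants("hussein") yields something like
# {"hussein", "hussain", "husein"} via the "e"->"a" swap and the "ss"->"s"
# collapse; ordering is unspecified because variants are collected in a set.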
def transliterate_with_model(text: str, from_lang: str, to_lang: str) -> List[str]:
    """Use the high-quality transformer models."""
    if not USE_TRANSFORMERS:
        return []
    try:
        import torch

        # Select model and tokenizer
        if from_lang == "ar" and to_lang == "en":
            model = ar_en_model
            tokenizer = ar_en_tokenizer
        elif from_lang == "en" and to_lang == "ar":
            model = en_ar_model
            tokenizer = en_ar_tokenizer
            # The OPUS Big EN→AR model is multi-target and needs a language token
            if MODEL_CHOICE == "opus-big":
                text = ">>ara<< " + text
        else:
            return []

        # Tokenize
        inputs = tokenizer([text], return_tensors="pt", padding=True).to(device)
        variants = set()
        with torch.no_grad():
            # Method 1: Greedy decoding
            outputs = model.generate(
                **inputs, max_length=50, num_beams=1, do_sample=False
            )
            result = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            if result:
                variants.add(result.lower() if to_lang == "en" else result)

            # Method 2: Beam search (best quality)
            outputs = model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                num_return_sequences=3,
                do_sample=False,
                early_stopping=True,
            )
            for output in outputs:
                result = tokenizer.decode(output, skip_special_tokens=True).strip()
                if result:
                    variants.add(result.lower() if to_lang == "en" else result)

            # Method 3: Top-k/top-p sampling for extra diversity
            outputs = model.generate(
                **inputs,
                max_length=50,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,
                num_return_sequences=2,
            )
            for output in outputs:
                result = tokenizer.decode(output, skip_special_tokens=True).strip()
                if result:
                    variants.add(result.lower() if to_lang == "en" else result)
        return list(variants)[:6]
    except Exception as e:
        print(f"Model error: {str(e)}")
        return []
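# Usage sketch (assumes load_models() succeeded; outputs vary by model and
# decoding run, so the values below are illustrative only):
#   transliterate_with_model("محمد", "ar", "en")
#   -> e.g. ["mohammed", "muhammad", "mohamed"]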
def add_arabic_variants(text: str) -> List[str]:
    """Add common Arabic variants."""
    variants = {text}
    # Hamza variants
    if text.startswith("ا"):
        variants.add("أ" + text[1:])
        variants.add("إ" + text[1:])
    # Final ya
    if text.endswith("ي"):
        variants.add(text[:-1] + "ى")
    # Taa marbuta
    if text.endswith("ه"):
        variants.add(text[:-1] + "ة")
    return list(variants)
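# Example: add_arabic_variants("علي") -> ["علي", "على"] (final ya swapped for
# alef maqsura); the starting-alef and taa-marbuta rules do not fire here, and
# list order is unspecified since the variants come from a set.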
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.route("/health", methods=["GET"])
def health_check():
    """Health check with Redis status."""
    redis_status = "connected"
    try:
        # Test the Redis connection with a round-trip
        cache.set("health_check", "ok", timeout=5)
        test_val = cache.get("health_check")
        if test_val != "ok":
            redis_status = "error"
    except Exception as e:
        redis_status = f"error: {str(e)}"
    return jsonify(
        {
            "status": "healthy",
            "models_loaded": ar_en_model is not None and en_ar_model is not None,
            "model_type": MODEL_CHOICE,
            "using_transformers": USE_TRANSFORMERS,
            "device": str(device) if device else "none",
            "ar_en_model": "opus-mt-tc-big-ar-en",
            "en_ar_model": "opus-mt-tc-big-en-ar"
            if MODEL_CHOICE == "opus-big"
            else "marefa-mt-en-ar",
            "cache": {
                "type": "Redis",
                "status": redis_status,
                "timeout": CACHE_DEFAULT_TIMEOUT,
            },
            "metrics": metrics,
        }
    )
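# Quick check (illustrative; assumes the server is running on the default
# port 7860 from the __main__ block below):
#   curl http://localhost:7860/health
# The response reports model status, cache status, and the metrics dict.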
@app.route("/stats", methods=["GET"])
def get_stats():
    """Performance statistics."""
    cache_total = metrics["cache_hits"] + metrics["cache_misses"]
    hit_rate = metrics["cache_hits"] / cache_total if cache_total > 0 else 0

    # Get Redis info if available
    redis_info = {}
    try:
        if cache_config.get("CACHE_TYPE") == "RedisCache":
            redis_client = cache.cache._write_client
            info = redis_client.info("stats")
            redis_info = {
                "total_connections_received": info.get("total_connections_received", 0),
                "total_commands_processed": info.get("total_commands_processed", 0),
                "keyspace_hits": info.get("keyspace_hits", 0),
                "keyspace_misses": info.get("keyspace_misses", 0),
            }
    except Exception as e:
        redis_info = {"error": str(e)}

    return jsonify(
        {
            "cache": {
                "hits": metrics["cache_hits"],
                "misses": metrics["cache_misses"],
                "hit_rate": f"{hit_rate * 100:.2f}%",
                "redis": redis_info,
            },
            "requests": {
                "total": metrics["total_requests"],
                "model_requests": metrics["model_requests"],
            },
            "performance": {
                "avg_response_time_ms": f"{metrics['avg_response_time']:.2f}",
            },
            "model_info": {
                "type": MODEL_CHOICE,
                "ar_en": "opus-mt-tc-big-ar-en",
                "en_ar": "opus-mt-tc-big-en-ar"
                if MODEL_CHOICE == "opus-big"
                else "marefa-mt-en-ar",
            },
        }
    )
@app.route("/transliterate", methods=["POST"])
def transliterate():
    """Main transliteration endpoint with Redis caching."""
    start_time = time.time()
    metrics["total_requests"] += 1
    try:
        data = request.get_json(silent=True) or {}
        text = data.get("text", "").strip()
        from_lang = data.get("from", "en")
        to_lang = data.get("to", "ar")
        if not text or len(text) < 2:
            return jsonify({"variants": []})

        # Redis cache key
        cache_key = f"{from_lang}-{to_lang}-{text}"

        # Check Redis cache (an empty list is still a valid cached result)
        cached = cache.get(cache_key)
        if cached is not None:
            metrics["cache_hits"] += 1
            return jsonify(
                {"variants": cached, "cached": True, "cache_source": "redis"}
            )

        metrics["cache_misses"] += 1
        metrics["model_requests"] += 1

        # Generate variants with the model
        variants = {text, text.lower()}
        if from_lang == "ar":
            # Normalize Arabic
            normalized = normalize_arabic(text)
            if normalized != text:
                variants.add(normalized)
            # Use the model
            model_variants = transliterate_with_model(normalized, from_lang, to_lang)
            variants.update(model_variants)
            # Add phonetic variants
            for v in list(variants):
                if re.match(r"^[a-z]+$", v):
                    variants.update(get_phonetic_variants(v))
        else:  # English to Arabic
            # Use the model
            model_variants = transliterate_with_model(text, from_lang, to_lang)
            variants.update(model_variants)
            # Add Arabic variants for each result
            for v in list(variants):
                if re.search(r"[\u0600-\u06FF]", v):
                    variants.update(add_arabic_variants(v))

        # Filter and limit
        result = [v for v in variants if v and len(v) >= 2][:6]

        # Cache result in Redis
        cache.set(cache_key, result, timeout=CACHE_DEFAULT_TIMEOUT)

        # Update the running-average response time
        response_time = (time.time() - start_time) * 1000
        metrics["avg_response_time"] = (
            metrics["avg_response_time"] * (metrics["total_requests"] - 1)
            + response_time
        ) / metrics["total_requests"]

        return jsonify(
            {
                "variants": result,
                "cached": False,
                "response_time_ms": round(response_time, 2),
            }
        )
    except Exception as e:
        return jsonify({"error": str(e), "variants": []}), 500
@app.route("/transliterate/batch", methods=["POST"])
def transliterate_batch():
    """Batch transliteration with Redis caching."""
    start_time = time.time()
    try:
        data = request.get_json(silent=True) or {}
        texts = data.get("texts", [])
        from_lang = data.get("from", "en")
        to_lang = data.get("to", "ar")

        results = {}
        cache_hits = 0
        cache_misses = 0
        for text in texts:
            cache_key = f"{from_lang}-{to_lang}-{text}"
            cached = cache.get(cache_key)
            if cached is not None:
                results[text] = cached
                metrics["cache_hits"] += 1
                cache_hits += 1
            else:
                metrics["cache_misses"] += 1
                cache_misses += 1
                variants = {text, text.lower()}
                # Use the models
                model_variants = transliterate_with_model(text, from_lang, to_lang)
                variants.update(model_variants)
                # Add variants
                if to_lang == "ar":
                    for v in list(variants):
                        if re.search(r"[\u0600-\u06FF]", v):
                            variants.update(add_arabic_variants(v))
                else:
                    for v in list(variants):
                        if re.match(r"^[a-z]+$", v):
                            variants.update(get_phonetic_variants(v))
                result = [v for v in variants if v and len(v) >= 2][:6]
                results[text] = result
                # Cache in Redis
                cache.set(cache_key, result, timeout=CACHE_DEFAULT_TIMEOUT)

        metrics["total_requests"] += len(texts)
        response_time = (time.time() - start_time) * 1000
        return jsonify(
            {
                "results": results,
                "count": len(results),
                "cache_hits": cache_hits,
                "cache_misses": cache_misses,
                "response_time_ms": round(response_time, 2),
            }
        )
    except Exception as e:
        return jsonify({"error": str(e), "results": {}}), 500
@app.route("/cache/clear", methods=["POST"])
def clear_cache():
    """Clear the Redis cache."""
    try:
        cache.clear()
        metrics["cache_hits"] = 0
        metrics["cache_misses"] = 0
        return jsonify({"status": "cache cleared", "cache_type": "redis"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/cache/warm", methods=["POST"])
def warm_cache():
    """Warm the Redis cache with common terms."""
    data = request.get_json(silent=True) or {}
    terms = data.get("terms", [])
    warmed = 0
    failed = 0
    for term in terms:
        try:
            if re.search(r"[\u0600-\u06FF]", term):
                cache_key = f"ar-en-{term}"
                if not cache.get(cache_key):
                    variants = transliterate_with_model(term, "ar", "en")
                    cache.set(cache_key, variants, timeout=CACHE_DEFAULT_TIMEOUT)
                    warmed += 1
            else:
                cache_key = f"en-ar-{term}"
                if not cache.get(cache_key):
                    variants = transliterate_with_model(term, "en", "ar")
                    cache.set(cache_key, variants, timeout=CACHE_DEFAULT_TIMEOUT)
                    warmed += 1
        except Exception as e:
            print(f"Failed to warm cache for '{term}': {e}")
            failed += 1
    return jsonify(
        {"status": "success", "warmed": warmed, "failed": failed, "cache_type": "redis"}
    )
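# Example warm-up call (illustrative):
#   curl -X POST http://localhost:7860/cache/warm \
#        -H "Content-Type: application/json" \
#        -d '{"terms": ["محمد", "ahmed"]}'
# Arabic-script terms are warmed as ar->en, everything else as en->ar.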
@app.route("/cache/info", methods=["GET"])
def cache_info():
    """Get Redis cache information."""
    try:
        if cache_config.get("CACHE_TYPE") == "RedisCache":
            redis_client = cache.cache._write_client
            info = redis_client.info()
            return jsonify(
                {
                    "cache_type": "Redis",
                    "redis_version": info.get("redis_version"),
                    "used_memory_human": info.get("used_memory_human"),
                    "connected_clients": info.get("connected_clients"),
                    "total_commands_processed": info.get("total_commands_processed"),
                    "keyspace_hits": info.get("keyspace_hits"),
                    "keyspace_misses": info.get("keyspace_misses"),
                    "keys": redis_client.dbsize(),
                }
            )
        else:
            return jsonify({"cache_type": "SimpleCache", "message": "Not using Redis"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    print("🚀 Starting Flask Transliteration API...")
    print(f"📦 Model choice: {MODEL_CHOICE}")
    print("")

    # Load the better models
    success = load_models()
    if success:
        print("\n✅ High-quality models loaded!")
        print("   AR→EN: opus-mt-tc-big-ar-en")
        if MODEL_CHOICE == "opus-big":
            print("   EN→AR: opus-mt-tc-big-en-ar (much better than the basic model!)")
        else:
            print("   EN→AR: marefa-mt-en-ar (Arabic-specialized)")
    else:
        print("\n⚠️ Models failed to load")

    port = int(os.environ.get("PORT", 7860))
    print(f"\n✅ Server starting on http://localhost:{port}")
    print("📝 Endpoints:")
    print("   POST /transliterate")
    print("   POST /transliterate/batch")
    print("   GET  /health")
    print("   GET  /stats")
    print("   GET  /cache/info")
    print("   POST /cache/clear")
    print("   POST /cache/warm")
    app.run(host="0.0.0.0", port=port, debug=False, threaded=True)
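# Local run sketch (assumptions: Python 3.9+, the usual PyPI package names):
#   pip install flask flask-caching redis torch transformers sentencepiece python-dotenv
#   REDIS_URL=redis://localhost:6379/0 python app.py
# Without a reachable Redis, the app falls back to SimpleCache automatically.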