# Upstream snapshot header: abood-bak — "initial ai service-2" (commit cc0dade)
"""
Flask API with BETTER Translation Models + Redis Cache
Uses high-quality models for both directions with Redis caching
"""
from dotenv import load_dotenv
load_dotenv()
import os
import re
import time
import warnings
from functools import lru_cache
from typing import List
from flask import Flask, jsonify, request
from flask_caching import Cache
warnings.filterwarnings("ignore")
app = Flask(__name__)

# Redis cache configuration.
# REDIS_URL must be a full connection URL (redis://host:port/db, or
# rediss:// for TLS endpoints such as Upstash).
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
CACHE_DEFAULT_TIMEOUT = int(os.getenv("CACHE_DEFAULT_TIMEOUT", 3600))

cache_config = {
    "CACHE_TYPE": "RedisCache",
    "CACHE_REDIS_URL": REDIS_URL,  # MUST be rediss:// for TLS providers
    "CACHE_DEFAULT_TIMEOUT": CACHE_DEFAULT_TIMEOUT,
    "CACHE_KEY_PREFIX": "transliteration:",
}

print("📦 Redis Configuration:")
# Print only the endpoint (everything after '@') so credentials embedded
# in the URL are never written to logs.
print(f" URL: {REDIS_URL.split('@')[-1]}")
print(f" Timeout: {CACHE_DEFAULT_TIMEOUT}s")

try:
    cache = Cache(app, config=cache_config)
    # Round-trip a short-lived key to verify the connection actually works;
    # Cache() itself does not fail on an unreachable server.
    with app.app_context():
        cache.set("ping", "pong", timeout=10)
        if cache.get("ping") == "pong":
            print("✅ Upstash Redis connected successfully")
        else:
            raise RuntimeError("Ping test failed")
except Exception as e:
    print(f"⚠️ Redis connection failed: {e}")
    print(" Falling back to SimpleCache")
    # Keep cache_config in sync with the active backend: /stats and
    # /cache/info inspect cache_config["CACHE_TYPE"] to decide whether a
    # Redis client is available, so it must reflect the fallback.
    cache_config = {
        "CACHE_TYPE": "SimpleCache",
        "CACHE_DEFAULT_TIMEOUT": CACHE_DEFAULT_TIMEOUT,
        "CACHE_THRESHOLD": 10000,
    }
    cache = Cache(app, config=cache_config)
# Global model state — populated by load_models() at startup.
ar_en_model = None  # Arabic→English MarianMT model
en_ar_model = None  # English→Arabic MarianMT model
ar_en_tokenizer = None  # tokenizer paired with ar_en_model
en_ar_tokenizer = None  # tokenizer paired with en_ar_model
device = None  # torch.device ("cuda" or "cpu"), assigned in load_models()
USE_TRANSFORMERS = True  # set to False by load_models() on failure
# Which models to use
MODEL_CHOICE = os.getenv("MODEL_CHOICE", "opus-big")  # Options: "opus-big", "marefa"
# Performance metrics (in-process counters, reported by /health and /stats)
metrics = {
    "cache_hits": 0,
    "cache_misses": 0,
    "model_requests": 0,  # requests that missed the cache and ran the model
    "total_requests": 0,
    "avg_response_time": 0.0,  # running mean, milliseconds
}
print("🚀 Starting Transliteration API with Better Models + Redis Cache...")
def load_models():
    """Load the MarianMT translation models selected by MODEL_CHOICE.

    Populates the module-level model/tokenizer globals and `device`.
    Returns True on success.  On any failure — including an unrecognized
    MODEL_CHOICE — sets USE_TRANSFORMERS to False and returns False so the
    API degrades gracefully instead of serving with half-loaded models.
    """
    global \
        ar_en_model, \
        en_ar_model, \
        ar_en_tokenizer, \
        en_ar_tokenizer, \
        device, \
        USE_TRANSFORMERS, \
        MODEL_CHOICE
    try:
        import torch
        from transformers import MarianMTModel, MarianTokenizer

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🖥️ Using device: {device}")

        if MODEL_CHOICE == "opus-big":
            # OPTION 1: OPUS Big models (best overall quality).
            print("📥 Loading OPUS-MT-BIG models (high quality)...")
            # Arabic → English (OPUS Big)
            print(" Loading AR→EN (opus-mt-tc-big-ar-en)...")
            ar_en_model_name = "Helsinki-NLP/opus-mt-tc-big-ar-en"
            ar_en_tokenizer = MarianTokenizer.from_pretrained(ar_en_model_name)
            ar_en_model = MarianMTModel.from_pretrained(ar_en_model_name).to(device)
            ar_en_model.eval()
            print(" ✅ AR→EN loaded")
            # English → Arabic (OPUS Big)
            print(" Loading EN→AR (opus-mt-tc-big-en-ar)...")
            en_ar_model_name = "Helsinki-NLP/opus-mt-tc-big-en-ar"
            en_ar_tokenizer = MarianTokenizer.from_pretrained(en_ar_model_name)
            en_ar_model = MarianMTModel.from_pretrained(en_ar_model_name).to(device)
            en_ar_model.eval()
            print(" ✅ EN→AR loaded")
            print("🎉 OPUS-MT-BIG models loaded successfully!")
        elif MODEL_CHOICE == "marefa":
            # OPTION 2: Marefa EN→AR (Arabic-specialized); AR→EN still uses
            # OPUS Big, which remains the best model for that direction.
            print("📥 Loading Marefa models (Arabic-specialized)...")
            print(" Loading AR→EN (opus-mt-tc-big-ar-en)...")
            ar_en_model_name = "Helsinki-NLP/opus-mt-tc-big-ar-en"
            ar_en_tokenizer = MarianTokenizer.from_pretrained(ar_en_model_name)
            ar_en_model = MarianMTModel.from_pretrained(ar_en_model_name).to(device)
            ar_en_model.eval()
            print(" ✅ AR→EN loaded")
            print(" Loading EN→AR (marefa-mt-en-ar)...")
            en_ar_model_name = "marefa-nlp/marefa-mt-en-ar"
            en_ar_tokenizer = MarianTokenizer.from_pretrained(en_ar_model_name)
            en_ar_model = MarianMTModel.from_pretrained(en_ar_model_name).to(device)
            en_ar_model.eval()
            print(" ✅ EN→AR loaded (Marefa)")
            print("🎉 Marefa models loaded successfully!")
        else:
            # Previously an unknown MODEL_CHOICE fell through and returned
            # True with all models still None; fail loudly instead.
            raise ValueError(f"Unknown MODEL_CHOICE: {MODEL_CHOICE!r}")
        USE_TRANSFORMERS = True
        return True
    except Exception as e:
        print(f"⚠️ Error loading models: {str(e)}")
        print("💡 Make sure you have enough memory and internet connection")
        USE_TRANSFORMERS = False
        return False
# Translation table for one-pass normalization of Arabic letter variants:
# hamza-seated alifs → bare alif, alif maqsura → ya, taa marbuta → ha,
# hamza on waw/ya → waw/ya.  Built once at import time; a single
# str.translate pass replaces the previous chain of str.replace calls.
_AR_NORMALIZE_TABLE = str.maketrans(
    {
        "أ": "ا",
        "إ": "ا",
        "آ": "ا",
        "ٱ": "ا",
        "ى": "ي",
        "ة": "ه",
        "ؤ": "و",
        "ئ": "ي",
    }
)


@lru_cache(maxsize=10000)
def normalize_arabic(text: str) -> str:
    """Normalize Arabic text for matching (cached in memory).

    Strips diacritics (tashkeel, U+064B–U+065F), unifies alif/hamza
    forms and other common letter variants, then trims whitespace.
    """
    text = re.sub(r"[\u064B-\u065F]", "", text)
    return text.translate(_AR_NORMALIZE_TABLE).strip()
@lru_cache(maxsize=10000)
def get_phonetic_variants(text: str) -> List[str]:
    """Return up to 5 phonetic spelling variants of a Latin-script word.

    Applies a fixed set of substitution patterns (a/e vowel swaps,
    een/ain suffix swaps, collapsing doubled letters).  The result is
    cached in memory.

    Fix: the original collected variants in a set and returned
    ``list(variants)[:5]``, so the cap discarded entries in hash order —
    nondeterministic across processes.  This version preserves insertion
    order (input first, then pattern order) so output is stable.
    """
    base = text.lower()
    variants = [base]
    patterns = [
        (r"a", "e"),
        (r"e", "a"),
        (r"een$", "ain"),
        (r"ain$", "een"),
        (r"(.)\1", r"\1"),  # collapse doubled letters, e.g. "ss" -> "s"
    ]
    for pattern, replacement in patterns:
        if re.search(pattern, base):
            variant = re.sub(pattern, replacement, base)
            # Keep only genuine, reasonably long, not-yet-seen variants.
            if variant != base and len(variant) >= 2 and variant not in variants:
                variants.append(variant)
    return variants[:5]
def transliterate_with_model(text: str, from_lang: str, to_lang: str) -> List[str]:
    """Generate up to 6 translation variants of `text` with the loaded models.

    Pools the outputs of three decoding strategies (greedy, beam search,
    top-k/top-p sampling) into a de-duplicated set.  Returns [] when the
    models are unavailable, the language pair is unsupported, or
    generation raises.

    NOTE(review): Method 3 uses sampling (do_sample=True), so repeated
    calls with the same input can return different variant sets; the
    final list(set)[:6] is also hash-ordered, not ranked by quality.
    """
    global ar_en_model, en_ar_model, ar_en_tokenizer, en_ar_tokenizer, device
    if not USE_TRANSFORMERS:
        return []
    try:
        import torch
        # Select model and tokenizer for the requested direction.
        if from_lang == "ar" and to_lang == "en":
            model = ar_en_model
            tokenizer = ar_en_tokenizer
        elif from_lang == "en" and to_lang == "ar":
            model = en_ar_model
            tokenizer = en_ar_tokenizer
            # OPUS Big EN→AR is multi-target: it needs an explicit
            # target-language token prepended to the input text.
            if MODEL_CHOICE == "opus-big":
                text = ">>ara<< " + text
        else:
            # Unsupported language pair.
            return []
        # Tokenize (batch of one) and move tensors to the model's device.
        inputs = tokenizer([text], return_tensors="pt", padding=True).to(device)
        variants = set()
        with torch.no_grad():
            # Method 1: Greedy decoding — one fast hypothesis.
            outputs = model.generate(
                **inputs, max_length=50, num_beams=1, do_sample=False
            )
            result = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            if result:
                # English outputs are lowercased; Arabic kept as-is.
                variants.add(result.lower() if to_lang == "en" else result)
            # Method 2: Beam search — best quality, keep top 3 hypotheses.
            outputs = model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                num_return_sequences=3,
                do_sample=False,
                early_stopping=True,
            )
            for output in outputs:
                result = tokenizer.decode(output, skip_special_tokens=True).strip()
                if result:
                    variants.add(result.lower() if to_lang == "en" else result)
            # Method 3: top-k/top-p sampling for extra diversity
            # (nondeterministic by design).
            outputs = model.generate(
                **inputs,
                max_length=50,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,
                num_return_sequences=2,
            )
            for output in outputs:
                result = tokenizer.decode(output, skip_special_tokens=True).strip()
                if result:
                    variants.add(result.lower() if to_lang == "en" else result)
        return list(variants)[:6]
    except Exception as e:
        print(f"Model error: {str(e)}")
        return []
def add_arabic_variants(text: str) -> List[str]:
    """Return the input plus common Arabic orthographic variants of it."""
    results = {text}
    tail = text[1:]
    # Initial bare alif also matches its hamza-seated forms.
    if text.startswith("ا"):
        results.update({"أ" + tail, "إ" + tail})
    # Final ya is often written as alif maqsura.
    if text.endswith("ي"):
        results.add(text[:-1] + "ى")
    # Final ha is often a taa marbuta.
    if text.endswith("ه"):
        results.add(text[:-1] + "ة")
    return list(results)
# ============================================================================
# API ENDPOINTS
# ============================================================================
@app.route("/health", methods=["GET"])
def health_check():
"""Health check with Redis status"""
redis_status = "connected"
try:
# Test Redis connection
cache.set("health_check", "ok", timeout=5)
test_val = cache.get("health_check")
if test_val != "ok":
redis_status = "error"
except Exception as e:
redis_status = f"error: {str(e)}"
return jsonify(
{
"status": "healthy",
"models_loaded": ar_en_model is not None and en_ar_model is not None,
"model_type": MODEL_CHOICE,
"using_transformers": USE_TRANSFORMERS,
"device": str(device) if device else "none",
"ar_en_model": "opus-mt-tc-big-ar-en",
"en_ar_model": "opus-mt-tc-big-en-ar"
if MODEL_CHOICE == "opus-big"
else "marefa-mt-en-ar",
"cache": {
"type": "Redis",
"status": redis_status,
"timeout": CACHE_DEFAULT_TIMEOUT,
},
"metrics": metrics,
}
)
@app.route("/stats", methods=["GET"])
def get_stats():
"""Performance statistics"""
cache_total = metrics["cache_hits"] + metrics["cache_misses"]
hit_rate = metrics["cache_hits"] / cache_total if cache_total > 0 else 0
# Get Redis info if available
redis_info = {}
try:
if cache_config.get("CACHE_TYPE") == "RedisCache":
redis_client = cache.cache._write_client
info = redis_client.info("stats")
redis_info = {
"total_connections_received": info.get("total_connections_received", 0),
"total_commands_processed": info.get("total_commands_processed", 0),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
}
except Exception as e:
redis_info = {"error": str(e)}
return jsonify(
{
"cache": {
"hits": metrics["cache_hits"],
"misses": metrics["cache_misses"],
"hit_rate": f"{hit_rate * 100:.2f}%",
"redis": redis_info,
},
"requests": {
"total": metrics["total_requests"],
"model_requests": metrics["model_requests"],
},
"performance": {
"avg_response_time_ms": f"{metrics['avg_response_time']:.2f}",
},
"model_info": {
"type": MODEL_CHOICE,
"ar_en": "opus-mt-tc-big-ar-en",
"en_ar": "opus-mt-tc-big-en-ar"
if MODEL_CHOICE == "opus-big"
else "marefa-mt-en-ar",
},
}
)
@app.route("/transliterate", methods=["POST"])
def transliterate():
"""Main transliteration endpoint with Redis caching"""
start_time = time.time()
metrics["total_requests"] += 1
try:
data = request.get_json()
text = data.get("text", "").strip()
from_lang = data.get("from", "en")
to_lang = data.get("to", "ar")
if not text or len(text) < 2:
return jsonify({"variants": []})
# Redis cache key
cache_key = f"{from_lang}-{to_lang}-{text}"
# Check Redis cache
cached = cache.get(cache_key)
if cached:
metrics["cache_hits"] += 1
return jsonify(
{"variants": cached, "cached": True, "cache_source": "redis"}
)
metrics["cache_misses"] += 1
metrics["model_requests"] += 1
# Generate variants with model
variants = set([text, text.lower()])
if from_lang == "ar":
# Normalize Arabic
normalized = normalize_arabic(text)
if normalized != text:
variants.add(normalized)
# Use model
model_variants = transliterate_with_model(normalized, from_lang, to_lang)
variants.update(model_variants)
# Add phonetic variants
for v in list(variants):
if re.match(r"^[a-z]+$", v):
variants.update(get_phonetic_variants(v))
else: # English to Arabic
# Use model
model_variants = transliterate_with_model(text, from_lang, to_lang)
variants.update(model_variants)
# Add Arabic variants for each result
for v in list(variants):
if re.search(r"[\u0600-\u06FF]", v):
variants.update(add_arabic_variants(v))
# Filter and limit
result = [v for v in variants if v and len(v) >= 2][:6]
# Cache result in Redis
cache.set(cache_key, result, timeout=CACHE_DEFAULT_TIMEOUT)
# Update metrics
response_time = (time.time() - start_time) * 1000
metrics["avg_response_time"] = (
metrics["avg_response_time"] * (metrics["total_requests"] - 1)
+ response_time
) / metrics["total_requests"]
return jsonify(
{
"variants": result,
"cached": False,
"response_time_ms": round(response_time, 2),
}
)
except Exception as e:
return jsonify({"error": str(e), "variants": []}), 500
@app.route("/transliterate/batch", methods=["POST"])
def transliterate_batch():
"""Batch transliteration with Redis caching"""
start_time = time.time()
try:
data = request.get_json()
texts = data.get("texts", [])
from_lang = data.get("from", "en")
to_lang = data.get("to", "ar")
results = {}
cache_hits = 0
cache_misses = 0
for text in texts:
cache_key = f"{from_lang}-{to_lang}-{text}"
cached = cache.get(cache_key)
if cached:
results[text] = cached
metrics["cache_hits"] += 1
cache_hits += 1
else:
metrics["cache_misses"] += 1
cache_misses += 1
variants = set([text, text.lower()])
# Use models
model_variants = transliterate_with_model(text, from_lang, to_lang)
variants.update(model_variants)
# Add variants
if to_lang == "ar":
for v in list(variants):
if re.search(r"[\u0600-\u06FF]", v):
variants.update(add_arabic_variants(v))
else:
for v in list(variants):
if re.match(r"^[a-z]+$", v):
variants.update(get_phonetic_variants(v))
result = [v for v in variants if v and len(v) >= 2][:6]
results[text] = result
# Cache in Redis
cache.set(cache_key, result, timeout=CACHE_DEFAULT_TIMEOUT)
metrics["total_requests"] += len(texts)
response_time = (time.time() - start_time) * 1000
return jsonify(
{
"results": results,
"count": len(results),
"cache_hits": cache_hits,
"cache_misses": cache_misses,
"response_time_ms": round(response_time, 2),
}
)
except Exception as e:
return jsonify({"error": str(e), "results": {}}), 500
@app.route("/cache/clear", methods=["POST"])
def clear_cache():
"""Clear Redis cache"""
try:
cache.clear()
metrics["cache_hits"] = 0
metrics["cache_misses"] = 0
return jsonify({"status": "cache cleared", "cache_type": "redis"})
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/cache/warm", methods=["POST"])
def warm_cache():
"""Warm Redis cache with common terms"""
data = request.get_json()
terms = data.get("terms", [])
warmed = 0
failed = 0
for term in terms:
try:
if re.search(r"[\u0600-\u06FF]", term):
cache_key = f"ar-en-{term}"
if not cache.get(cache_key):
variants = transliterate_with_model(term, "ar", "en")
cache.set(cache_key, variants, timeout=CACHE_DEFAULT_TIMEOUT)
warmed += 1
else:
cache_key = f"en-ar-{term}"
if not cache.get(cache_key):
variants = transliterate_with_model(term, "en", "ar")
cache.set(cache_key, variants, timeout=CACHE_DEFAULT_TIMEOUT)
warmed += 1
except Exception as e:
print(f"Failed to warm cache for '{term}': {e}")
failed += 1
return jsonify(
{"status": "success", "warmed": warmed, "failed": failed, "cache_type": "redis"}
)
@app.route("/cache/info", methods=["GET"])
def cache_info():
"""Get Redis cache information"""
try:
if cache_config.get("CACHE_TYPE") == "RedisCache":
redis_client = cache.cache._write_client
info = redis_client.info()
return jsonify(
{
"cache_type": "Redis",
"redis_version": info.get("redis_version"),
"used_memory_human": info.get("used_memory_human"),
"connected_clients": info.get("connected_clients"),
"total_commands_processed": info.get("total_commands_processed"),
"keyspace_hits": info.get("keyspace_hits"),
"keyspace_misses": info.get("keyspace_misses"),
"keys": redis_client.dbsize(),
}
)
else:
return jsonify({"cache_type": "SimpleCache", "message": "Not using Redis"})
except Exception as e:
return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
print("🚀 Starting Flask Transliteration API...")
print(f"📦 Model choice: {MODEL_CHOICE}")
print("")
# Load better models
success = load_models()
if success:
print("\n✅ High-quality models loaded!")
print(" AR→EN: opus-mt-tc-big-ar-en")
if MODEL_CHOICE == "opus-big":
print(" EN→AR: opus-mt-tc-big-en-ar (MUCH better than basic!)")
else:
print(" EN→AR: marefa-mt-en-ar (Arabic-specialized)")
else:
print("\n⚠️ Models failed to load")
print("\n✅ Server starting on http://localhost:5000")
print("📝 Endpoints:")
print(" POST /transliterate")
print(" POST /transliterate/batch")
print(" GET /health")
print(" GET /stats")
print(" GET /cache/info")
print(" POST /cache/clear")
print(" POST /cache/warm")
port = int(os.environ.get("PORT", 7860))
app.run(host="0.0.0.0", port=port, debug=False, threaded=True)