import os
import time
import uuid
import asyncio
import aiohttp
import torch
import json
import logging
import threading
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class TranslationRequest(BaseModel):
    text: str
    source_lang: str
    target_lang: str
    auto_charge: bool = False


class JobStatus(Enum):
    PENDING = "pending"
    ASSIGNED = "assigned"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
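

# Job lifecycle: PENDING -> ASSIGNED -> PROCESSING -> COMPLETED or FAILED.
# A job that exhausts max_retries ends up FAILED; CANCELLED is reachable only
# from PENDING or ASSIGNED (see TranslationQueue.cancel_job below).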


class ServerStatus(Enum):
    AVAILABLE = "available"
    BUSY = "busy"
    OFFLINE = "offline"
    ERROR = "error"


@dataclass
class TranslationJob:
    job_id: str
    request_id: str
    text: str
    source_lang: str
    target_lang: str
    priority: int = 0
    auto_charge: bool = False
    notification_url: Optional[str] = None
    created_at: float = field(default_factory=time.time)
    assigned_at: Optional[float] = None
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    assigned_server: Optional[str] = None
    status: JobStatus = JobStatus.PENDING
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ServerInfo:
    id: str
    url: str
    status: ServerStatus = ServerStatus.OFFLINE
    last_ping: float = 0
    current_jobs: int = 0
    max_concurrent_jobs: int = 1
    response_time: float = 0
    error_count: int = 0
    total_requests: int = 0
    last_error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class ServerRegistry:
    def __init__(self, health_check_interval: int = 30):
        self.servers: Dict[str, ServerInfo] = {}
        self.health_check_interval = health_check_interval
        self.lock = threading.Lock()
        self.health_monitor_task = None
        self.running = False

    def register_server(self, server_id: str, url: str, max_concurrent_jobs: int = 1):
        """Register a new translation server."""
        with self.lock:
            self.servers[server_id] = ServerInfo(
                id=server_id,
                url=url,
                max_concurrent_jobs=max_concurrent_jobs
            )
        logger.info(f"Registered server {server_id} at {url}")

    def unregister_server(self, server_id: str):
        """Remove a server from the registry."""
        with self.lock:
            if server_id in self.servers:
                del self.servers[server_id]
                logger.info(f"Unregistered server {server_id}")

    def get_available_server(self) -> Optional[ServerInfo]:
        """Get the best available server for processing."""
        with self.lock:
            available_servers = [
                server for server in self.servers.values()
                if server.status == ServerStatus.AVAILABLE
                and server.current_jobs < server.max_concurrent_jobs
            ]

            if not available_servers:
                return None

            # Prefer the least-loaded server, breaking ties on response time.
            available_servers.sort(key=lambda s: (s.current_jobs, s.response_time))
            return available_servers[0]

    def mark_server_busy(self, server_id: str):
        """Increment a server's job count, flipping it to BUSY at capacity."""
        with self.lock:
            if server_id in self.servers:
                self.servers[server_id].current_jobs += 1
                if self.servers[server_id].current_jobs >= self.servers[server_id].max_concurrent_jobs:
                    self.servers[server_id].status = ServerStatus.BUSY

    def mark_server_available(self, server_id: str):
        """Decrement a server's job count, flipping it back to AVAILABLE below capacity."""
        with self.lock:
            if server_id in self.servers:
                self.servers[server_id].current_jobs = max(0, self.servers[server_id].current_jobs - 1)
                if self.servers[server_id].current_jobs < self.servers[server_id].max_concurrent_jobs:
                    self.servers[server_id].status = ServerStatus.AVAILABLE

    def get_server_stats(self) -> Dict[str, Any]:
        """Get statistics about all registered servers."""
        with self.lock:
            stats = {
                'total_servers': len(self.servers),
                'available_servers': len([s for s in self.servers.values() if s.status == ServerStatus.AVAILABLE]),
                'busy_servers': len([s for s in self.servers.values() if s.status == ServerStatus.BUSY]),
                'offline_servers': len([s for s in self.servers.values() if s.status == ServerStatus.OFFLINE]),
                'servers': {
                    server_id: {
                        'status': server.status.value,
                        'current_jobs': server.current_jobs,
                        'max_jobs': server.max_concurrent_jobs,
                        'response_time': server.response_time,
                        'total_requests': server.total_requests,
                        'error_count': server.error_count,
                        'last_ping': server.last_ping
                    }
                    for server_id, server in self.servers.items()
                }
            }
            return stats

    async def check_server_health(self, server: ServerInfo) -> bool:
        """Check whether a server is healthy via its /api/health endpoint."""
        try:
            start_time = time.time()
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session:
                async with session.get(f"{server.url}/api/health") as response:
                    response_time = time.time() - start_time

                    if response.status == 200:
                        data = await response.json()
                        with self.lock:
                            server.last_ping = time.time()
                            server.response_time = response_time
                            server.error_count = 0
                            server.last_error = None

                            if data.get('status') == 'healthy':
                                if server.current_jobs < server.max_concurrent_jobs:
                                    server.status = ServerStatus.AVAILABLE
                                else:
                                    server.status = ServerStatus.BUSY
                            else:
                                server.status = ServerStatus.ERROR

                        return True
                    else:
                        raise Exception(f"HTTP {response.status}")

        except Exception as e:
            with self.lock:
                server.status = ServerStatus.OFFLINE
                server.error_count += 1
                server.last_error = str(e)
            logger.error(f"Health check failed for server {server.id}: {e}")
            return False

    async def health_monitor(self):
        """Continuously monitor server health."""
        while self.running:
            try:
                servers_to_check = list(self.servers.values())

                health_tasks = [
                    self.check_server_health(server)
                    for server in servers_to_check
                ]

                await asyncio.gather(*health_tasks, return_exceptions=True)

            except Exception as e:
                logger.error(f"Error in health monitor: {e}")

            await asyncio.sleep(self.health_check_interval)

    def start_health_monitoring(self):
        """Start the health monitoring task."""
        if not self.running:
            self.running = True
            loop = asyncio.get_event_loop()
            self.health_monitor_task = loop.create_task(self.health_monitor())
            logger.info("Started server health monitoring")

    def stop_health_monitoring(self):
        """Stop the health monitoring task."""
        self.running = False
        if self.health_monitor_task:
            self.health_monitor_task.cancel()
        logger.info("Stopped server health monitoring")


class TranslationQueue:
    def __init__(self, max_queue_size: int = 1000):
        self.pending_jobs: asyncio.Queue = asyncio.Queue(maxsize=max_queue_size)
        self.active_jobs: Dict[str, TranslationJob] = {}
        self.completed_jobs: Dict[str, TranslationJob] = {}
        self.failed_jobs: Dict[str, TranslationJob] = {}

        self.lock = asyncio.Lock()
        self.processor_task: Optional[asyncio.Task] = None
        self.running = False

        self.total_jobs = 0
        self.processed_jobs = 0
        self.failed_job_count = 0

    async def add_job(self,
                      text: str,
                      source_lang: str,
                      target_lang: str,
                      request_id: Optional[str] = None,
                      priority: int = 0,
                      auto_charge: bool = False,
                      notification_url: Optional[str] = None) -> str:
        """Add a new translation job to the queue."""

        if not request_id:
            request_id = str(uuid.uuid4())

        job_id = f"job_{int(time.time())}_{str(uuid.uuid4())[:8]}"

        job = TranslationJob(
            job_id=job_id,
            request_id=request_id,
            text=text,
            source_lang=source_lang,
            target_lang=target_lang,
            priority=priority,
            auto_charge=auto_charge,
            notification_url=notification_url
        )

        try:
            # put_nowait so a full queue raises QueueFull immediately instead of
            # blocking the caller (a plain "await put()" would never raise it).
            self.pending_jobs.put_nowait(job)

            async with self.lock:
                self.total_jobs += 1

            logger.info(f"Added job {job_id} to queue (request_id: {request_id})")
            return job_id

        except asyncio.QueueFull:
            logger.error(f"Queue is full, cannot add job {job_id}")
            raise Exception("Translation queue is full, please try again later")

    async def get_job_status(self, job_id: str) -> Optional[Dict[str, Any]]:
        """Get the status of a specific job"""
        async with self.lock:
            if job_id in self.active_jobs:
                job = self.active_jobs[job_id]
                return {
                    "job_id": job_id,
                    "request_id": job.request_id,
                    "status": job.status.value,
                    "assigned_server": job.assigned_server,
                    "created_at": job.created_at,
                    "assigned_at": job.assigned_at,
                    "started_at": job.started_at,
                    "processing_time": time.time() - job.started_at if job.started_at else 0,
                    "retry_count": job.retry_count
                }

            if job_id in self.completed_jobs:
                job = self.completed_jobs[job_id]
                return {
                    "job_id": job_id,
                    "request_id": job.request_id,
                    "status": job.status.value,
                    "assigned_server": job.assigned_server,
                    "created_at": job.created_at,
                    "completed_at": job.completed_at,
                    "processing_time": job.completed_at - job.started_at if job.started_at and job.completed_at else 0,
                    "result": job.result,
                    "retry_count": job.retry_count
                }

            if job_id in self.failed_jobs:
                job = self.failed_jobs[job_id]
                return {
                    "job_id": job_id,
                    "request_id": job.request_id,
                    "status": job.status.value,
                    "error": job.error,
                    "created_at": job.created_at,
                    "failed_at": job.completed_at,
                    "retry_count": job.retry_count
                }

            return None

    async def get_job_by_request_id(self, request_id: str) -> Optional[Dict[str, Any]]:
        """Get job status by request_id"""
        async with self.lock:
            all_jobs = {**self.active_jobs, **self.completed_jobs, **self.failed_jobs}
            job_id = next(
                (job.job_id for job in all_jobs.values() if job.request_id == request_id),
                None
            )

        # Resolve the status after releasing the lock: asyncio.Lock is not
        # reentrant, so calling get_job_status() while still holding it would deadlock.
        if job_id is not None:
            return await self.get_job_status(job_id)

        return None

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a pending or active job"""
        async with self.lock:
            if job_id in self.active_jobs:
                job = self.active_jobs[job_id]
                if job.status in [JobStatus.PENDING, JobStatus.ASSIGNED]:
                    job.status = JobStatus.CANCELLED
                    job.completed_at = time.time()

                    self.failed_jobs[job_id] = job
                    del self.active_jobs[job_id]

                    if job.assigned_server:
                        server_registry.mark_server_available(job.assigned_server)

                    logger.info(f"Cancelled job {job_id}")
                    return True

            return False

    async def get_queue_stats(self) -> Dict[str, Any]:
        """Get queue statistics"""
        async with self.lock:
            pending_count = self.pending_jobs.qsize()
            active_count = len(self.active_jobs)
            completed_count = len(self.completed_jobs)
            failed_count = len(self.failed_jobs)

            return {
                "pending_jobs": pending_count,
                "active_jobs": active_count,
                "completed_jobs": completed_count,
                "failed_jobs": failed_count,
                "total_jobs": self.total_jobs,
                "processed_jobs": self.processed_jobs,
                "success_rate": (self.processed_jobs / max(1, self.total_jobs)) * 100,
                "queue_utilization": (pending_count / self.pending_jobs.maxsize) * 100
            }

    async def send_translation_request(self, server_url: str, job: TranslationJob) -> Dict[str, Any]:
        """Send translation request to a specific server"""
        try:
            payload = {
                "text": job.text,
                "source_lang": job.source_lang,
                "target_lang": job.target_lang,
                "request_id": job.request_id,
                "auto_charge": job.auto_charge,
                "notification_url": job.notification_url
            }

            timeout = aiohttp.ClientTimeout(total=300)

            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    f"{server_url}/api/translate/heavy",
                    json=payload,
                    headers={"Content-Type": "application/json"}
                ) as response:

                    if response.status == 200:
                        result = await response.json()
                        logger.info(f"Successfully submitted job {job.job_id} to server {server_url}")
                        return result
                    else:
                        error_text = await response.text()
                        raise Exception(f"Server returned {response.status}: {error_text}")

        except Exception as e:
            logger.error(f"Failed to send job {job.job_id} to server {server_url}: {e}")
            raise e

    async def process_queue(self):
        """Main queue processor - assigns jobs to available servers"""
        logger.info("Started queue processor")

        while self.running:
            try:
                try:
                    job = await asyncio.wait_for(self.pending_jobs.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue

                available_server = server_registry.get_available_server()

                if not available_server:
                    # No capacity anywhere: requeue the job and back off briefly.
                    await self.pending_jobs.put(job)
                    logger.warning(f"No available servers for job {job.job_id}, requeueing")
                    await asyncio.sleep(2)
                    continue

                async with self.lock:
                    job.assigned_server = available_server.id
                    job.assigned_at = time.time()
                    job.status = JobStatus.ASSIGNED
                    self.active_jobs[job.job_id] = job

                server_registry.mark_server_busy(available_server.id)

                try:
                    job.status = JobStatus.PROCESSING
                    job.started_at = time.time()

                    result = await self.send_translation_request(available_server.url, job)

                    # Successful submissions are completed asynchronously; the
                    # job-completion webhook moves the job out of active_jobs
                    # and frees the server.
                    logger.info(f"Job {job.job_id} submitted to server {available_server.id}")

                except Exception as e:
                    async with self.lock:
                        job.retry_count += 1
                        job.error = str(e)

                        if job.retry_count < job.max_retries:
                            job.status = JobStatus.PENDING
                            job.assigned_server = None
                            job.assigned_at = None
                            job.started_at = None

                            await self.pending_jobs.put(job)
                            del self.active_jobs[job.job_id]

                            logger.warning(f"Job {job.job_id} failed, retrying ({job.retry_count}/{job.max_retries})")
                        else:
                            job.status = JobStatus.FAILED
                            job.completed_at = time.time()

                            self.failed_jobs[job.job_id] = job
                            self.failed_job_count += 1
                            del self.active_jobs[job.job_id]

                            logger.error(f"Job {job.job_id} failed permanently after {job.retry_count} retries")

                    # The submission never reached the server, so release its slot here.
                    server_registry.mark_server_available(available_server.id)

            except Exception as e:
                logger.error(f"Error in queue processor: {e}")
                await asyncio.sleep(1)

    def start_processing(self):
        """Start the queue processor"""
        if not self.running:
            self.running = True
            self.processor_task = asyncio.create_task(self.process_queue())
            logger.info("Started queue processing")

    def stop_processing(self):
        """Stop the queue processor"""
        self.running = False
        if self.processor_task:
            self.processor_task.cancel()
        logger.info("Stopped queue processing")


# Module-level singletons shared by the FastAPI endpoints below.
server_registry = ServerRegistry()
translation_queue = TranslationQueue()

LOAD_BALANCER_ENABLED = os.getenv("LOAD_BALANCER_ENABLED", "false").lower() == "true"
SERVER_ID = os.getenv("SERVER_ID", f"server_{int(time.time())}")
CURRENT_SERVER_URL = os.getenv("CURRENT_SERVER_URL", "http://localhost:7860")
PEER_SERVERS = os.getenv("PEER_SERVERS", "").split(",") if os.getenv("PEER_SERVERS") else []
MODEL_NAME = os.getenv("MODEL_NAME", "default_model")
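
# Example environment (illustrative values only; adjust to your deployment):
#   LOAD_BALANCER_ENABLED=true
#   SERVER_ID=translation-node-1
#   CURRENT_SERVER_URL=http://10.0.0.11:7860
#   PEER_SERVERS=http://10.0.0.12:7860,http://10.0.0.13:7860
#   MODEL_NAME=default_model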


app = FastAPI(title="Enhanced Translation Service with Load Balancer")

# In-memory progress/result store for locally processed requests, keyed by request_id.
translations = {}

# Translation backend; expected to be set elsewhere to an object exposing
# translate_text(), completed_translations, translation_sessions and total_requests.
translator = None


async def estimate_queue_wait_time() -> int:
    """Estimate wait time in seconds based on queue size and server availability"""
    try:
        queue_stats = await translation_queue.get_queue_stats()
        server_stats = server_registry.get_server_stats()

        pending_jobs = queue_stats['pending_jobs']
        available_servers = server_stats['available_servers']

        if available_servers == 0:
            return 300

        # Assume roughly 30 seconds per pending job, split across available
        # servers, capped at 30 minutes.
        estimated_seconds = (pending_jobs * 30) // max(1, available_servers)
        return min(estimated_seconds, 1800)

    except Exception:
        return 120


async def send_completion_notification(notification_url: str, request_id: str,
                                       translated_text: str, result: dict,
                                       character_count: int, translation_length: int,
                                       source_lang: str, target_lang: str, auto_charge: bool):
    """Send completion notification with enhanced data"""
    try:
        payload = {
            "request_id": request_id,
            "status": "completed",
            "translated_text": translated_text,
            "processing_time": result['processing_time'],
            "character_count": character_count,
            "translation_length": translation_length,
            "source_lang": source_lang,
            "target_lang": target_lang,
            "from_cache": result.get('from_cache', False),
            "chunks_count": result.get('chunks_count', 1),
            "auto_charge": auto_charge,
            "server_id": SERVER_ID,
            "completed_at": datetime.now().isoformat()
        }

        timeout = aiohttp.ClientTimeout(total=45)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                notification_url,
                json=payload,
                headers={
                    'Content-Type': 'application/json',
                    'User-Agent': 'MLT-Server/2.0'
                }
            ) as response:

                if response.status == 200:
                    logger.info(f"Notification sent successfully for {request_id}")
                    return True
                else:
                    logger.warning(f"Notification failed with status {response.status} for {request_id}")
                    return False

    except Exception as e:
        logger.error(f"Failed to send notification for {request_id}: {e}")
        return False


async def run_enhanced_translation_job(request_id: str, text: str, source_lang: str,
                                       target_lang: str, notification_url: Optional[str],
                                       auto_charge: bool = False):
    """Enhanced translation job runner with load balancer integration"""
    try:
        start_time = time.time()

        # Simulated progress updates while the translation runs.
        for i in range(1, 10):
            await asyncio.sleep(2)
            if request_id in translations:
                translations[request_id]["progress"] = i * 10
                translations[request_id]["elapsed_time"] = time.time() - start_time

        # Use the configured translator backend; fall back to the internal
        # placeholder if no backend has been attached.
        if translator is not None:
            result = translator.translate_text(text, source_lang, target_lang)
        else:
            result = perform_translation_internal(text, source_lang, target_lang)

        translated_text = result['translated_text']
        processing_time = time.time() - start_time

        translations[request_id] = {
            "status": "completed",
            "progress": 100,
            "elapsed_time": processing_time,
            "message": "Translation completed successfully",
            "result": translated_text,
            "server_id": SERVER_ID,
            "processing_time": result['processing_time'],
            "from_cache": result.get('from_cache', False)
        }

        if translator is not None:
            translator.completed_translations[request_id] = {
                'result': result,
                'completed_at': time.time(),
                'character_count': len(text),
                'translation_length': len(translated_text),
                'server_id': SERVER_ID
            }

        if LOAD_BALANCER_ENABLED:
            server_registry.mark_server_available(SERVER_ID)

        if notification_url:
            await send_completion_notification(
                notification_url, request_id, translated_text, result,
                len(text), len(translated_text), source_lang, target_lang, auto_charge
            )

        logger.info(f"Translation job {request_id} completed successfully on server {SERVER_ID}")

    except Exception as e:
        logger.error(f"Error in translation job {request_id}: {e}")

        if request_id in translations:
            translations[request_id] = {
                "status": "failed",
                "message": f"Translation failed: {str(e)}",
                "server_id": SERVER_ID,
                "elapsed_time": time.time() - start_time if 'start_time' in locals() else 0
            }

        if LOAD_BALANCER_ENABLED:
            server_registry.mark_server_available(SERVER_ID)


@app.on_event("startup")
async def startup_event():
    """Initialize load balancer on startup"""
    if LOAD_BALANCER_ENABLED:
        server_registry.register_server(SERVER_ID, CURRENT_SERVER_URL, max_concurrent_jobs=1)

        for i, peer_url in enumerate(PEER_SERVERS):
            if peer_url.strip():
                peer_id = f"peer_server_{i}"
                server_registry.register_server(peer_id, peer_url.strip(), max_concurrent_jobs=1)

        server_registry.start_health_monitoring()
        translation_queue.start_processing()

        logger.info(f"Load balancer initialized with {len(PEER_SERVERS)} peer servers")


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup load balancer on shutdown"""
    if LOAD_BALANCER_ENABLED:
        server_registry.stop_health_monitoring()
        translation_queue.stop_processing()
        logger.info("Load balancer shutdown complete")


@app.post("/api/translate/heavy")
async def heavy_translate_enhanced(request: Request):
    """Enhanced heavy translation with load balancer support"""
    try:
        data = await request.json()

        request_id = data.get("request_id")
        if not request_id:
            request_id = str(uuid.uuid4())

        text = data.get("text")
        source_lang = data.get("source_lang")
        target_lang = data.get("target_lang")
        auto_charge = data.get("auto_charge", False)
        notification_url = data.get("notification_url")

        if not all([text, source_lang, target_lang]):
            raise HTTPException(status_code=400, detail="Missing required fields: text, source_lang, target_lang")

        if LOAD_BALANCER_ENABLED:
            local_server = server_registry.servers.get(SERVER_ID)

            # If this server is already at capacity, try to route the request
            # to a peer; otherwise queue it.
            if (local_server and
                    local_server.current_jobs >= local_server.max_concurrent_jobs):

                available_server = server_registry.get_available_server()

                if available_server and available_server.id != SERVER_ID:
                    try:
                        async with aiohttp.ClientSession() as session:
                            async with session.post(
                                f"{available_server.url}/api/translate/heavy",
                                json=data,
                                timeout=aiohttp.ClientTimeout(total=10)
                            ) as response:
                                if response.status == 200:
                                    result = await response.json()
                                    logger.info(f"Routed request {request_id} to server {available_server.id}")
                                    return result
                                else:
                                    logger.warning(f"Failed to route to {available_server.id}: {response.status}")
                    except Exception as e:
                        logger.error(f"Error routing to {available_server.id}: {e}")

                job_id = await translation_queue.add_job(
                    text=text,
                    source_lang=source_lang,
                    target_lang=target_lang,
                    request_id=request_id,
                    auto_charge=auto_charge,
                    notification_url=notification_url
                )

                return {
                    "success": True,
                    "request_id": request_id,
                    "job_id": job_id,
                    "message": "Server busy, request queued for processing",
                    "processing_mode": "queued"
                }

        translations[request_id] = {
            "status": "processing",
            "progress": 0,
            "elapsed_time": 0,
            "message": "Translation in progress...",
            "server_id": SERVER_ID
        }

        if LOAD_BALANCER_ENABLED:
            server_registry.mark_server_busy(SERVER_ID)

        asyncio.create_task(
            run_enhanced_translation_job(
                request_id, text, source_lang, target_lang,
                notification_url, auto_charge
            )
        )

        return {
            "success": True,
            "request_id": request_id,
            "message": "Translation started on current server",
            "processing_mode": "local",
            "server_id": SERVER_ID
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in heavy_translate_enhanced: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/webhook/job-completion")
async def job_completion_webhook(data: dict):
    """Webhook endpoint for receiving job completion notifications from peer servers"""
    try:
        job_id = data.get('job_id')
        request_id = data.get('request_id')
        status = data.get('status')
        result = data.get('result')
        server_id = data.get('server_id')

        if not all([job_id, request_id, status]):
            raise HTTPException(status_code=400, detail="Missing required fields")

        async with translation_queue.lock:
            if job_id in translation_queue.active_jobs:
                job = translation_queue.active_jobs[job_id]

                if status == 'completed':
                    job.status = JobStatus.COMPLETED
                    job.completed_at = time.time()
                    job.result = result

                    translation_queue.completed_jobs[job_id] = job
                    del translation_queue.active_jobs[job_id]
                    translation_queue.processed_jobs += 1

                    logger.info(f"Job {job_id} completed on server {server_id}")

                elif status == 'failed':
                    job.status = JobStatus.FAILED
                    job.completed_at = time.time()
                    job.error = data.get('error', 'Unknown error')

                    translation_queue.failed_jobs[job_id] = job
                    del translation_queue.active_jobs[job_id]
                    translation_queue.failed_job_count += 1

                    logger.error(f"Job {job_id} failed on server {server_id}")

                if job.assigned_server:
                    server_registry.mark_server_available(job.assigned_server)

        return {
            "success": True,
            "message": f"Job {job_id} status updated to {status}"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in job completion webhook: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/enhanced-status")
async def enhanced_server_status():
    """Get enhanced server status including load balancer information"""
    try:
        base_stats = {
            "server_id": SERVER_ID,
            "server_url": CURRENT_SERVER_URL,
            "load_balancer_enabled": LOAD_BALANCER_ENABLED,
            "model": MODEL_NAME,
            "device": str(translator.device) if translator else "unknown",
            "gpu_available": torch.cuda.is_available(),
        }

        if LOAD_BALANCER_ENABLED:
            server_stats = server_registry.get_server_stats()
            queue_stats = await translation_queue.get_queue_stats()

            base_stats.update({
                "server_registry": server_stats,
                "queue_stats": queue_stats,
                "peer_servers": len(PEER_SERVERS)
            })
        else:
            base_stats.update({
                "active_sessions": len(translator.translation_sessions) if translator else 0,
                "completed_translations": len(translator.completed_translations) if translator else 0,
                "total_requests": translator.total_requests if translator else 0
            })

        return {
            "success": True,
            **base_stats,
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/translate/distributed")
async def distributed_translate(request: TranslationRequest):
    """
    Distributed translation endpoint - routes requests to available servers
    """
    try:
        if not LOAD_BALANCER_ENABLED:
            # Load balancer disabled: translate directly on this server.
            return await translate_text_api(request)

        local_server = server_registry.servers.get(SERVER_ID)

        if (local_server and
                local_server.status == ServerStatus.AVAILABLE and
                local_server.current_jobs < local_server.max_concurrent_jobs):

            # Process locally while holding a slot on this server.
            server_registry.mark_server_busy(SERVER_ID)
            try:
                result = perform_translation_internal(
                    request.text,
                    request.source_lang,
                    request.target_lang
                )

                return {
                    "success": True,
                    "processed_by": SERVER_ID,
                    "processing_mode": "local",
                    "translated_text": result['translated_text'],
                    "processing_time": result['processing_time'],
                    "chunks_count": result['chunks_count'],
                    "from_cache": result.get('from_cache', False),
                    "character_count": len(request.text),
                    "translation_length": len(result['translated_text'])
                }
            finally:
                server_registry.mark_server_available(SERVER_ID)

        else:
            # No local capacity: queue the job for whichever server frees up first.
            job_id = await translation_queue.add_job(
                text=request.text,
                source_lang=request.source_lang,
                target_lang=request.target_lang,
                auto_charge=request.auto_charge
            )

            return {
                "success": True,
                "processing_mode": "queued",
                "job_id": job_id,
                "message": "Request queued for processing on available server",
                "estimated_wait_time": await estimate_queue_wait_time()
            }

    except Exception as e:
        logger.error(f"Error in distributed translation: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/translate/queue")
async def queue_translate(request: TranslationRequest):
    """
    Force translation through the queue system
    """
    try:
        job_id = await translation_queue.add_job(
            text=request.text,
            source_lang=request.source_lang,
            target_lang=request.target_lang,
            auto_charge=request.auto_charge
        )

        return {
            "success": True,
            "job_id": job_id,
            "message": "Translation request added to queue",
            "estimated_wait_time": await estimate_queue_wait_time()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/job/{job_id}/status")
async def get_job_status(job_id: str):
    """Get status of a queued translation job"""
    try:
        status = await translation_queue.get_job_status(job_id)

        if not status:
            raise HTTPException(status_code=404, detail="Job not found")

        return {
            "success": True,
            **status
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/request/{request_id}/status")
async def get_request_status(request_id: str):
    """Get status by request_id (WordPress compatibility)"""
    try:
        status = await translation_queue.get_job_by_request_id(request_id)

        if not status:
            raise HTTPException(status_code=404, detail="Request not found")

        return {
            "success": True,
            **status
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/job/{job_id}/cancel")
async def cancel_job(job_id: str):
    """Cancel a queued translation job"""
    try:
        cancelled = await translation_queue.cancel_job(job_id)

        if cancelled:
            return {
                "success": True,
                "message": f"Job {job_id} cancelled successfully"
            }
        else:
            raise HTTPException(status_code=404, detail="Job not found or cannot be cancelled")

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/load-balancer/status")
async def load_balancer_status():
    """Get load balancer status"""
    try:
        server_stats = server_registry.get_server_stats()
        queue_stats = await translation_queue.get_queue_stats()

        return {
            "success": True,
            "load_balancer_enabled": LOAD_BALANCER_ENABLED,
            "server_registry": server_stats,
            "queue_stats": queue_stats,
            "total_servers": len(server_registry.servers),
            "available_servers": len([s for s in server_registry.servers.values() if s.status == ServerStatus.AVAILABLE])
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/load-balancer/register")
async def register_server(server_data: dict):
    """Register a new server with the load balancer"""
    try:
        server_id = server_data.get("server_id")
        url = server_data.get("url")
        max_concurrent_jobs = server_data.get("max_concurrent_jobs", 1)

        if not all([server_id, url]):
            raise HTTPException(status_code=400, detail="Missing server_id or url")

        server_registry.register_server(server_id, url, max_concurrent_jobs)

        return {
            "success": True,
            "message": f"Server {server_id} registered successfully"
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/load-balancer/unregister")
async def unregister_server(server_data: dict):
    """Unregister a server from the load balancer"""
    try:
        server_id = server_data.get("server_id")

        if not server_id:
            raise HTTPException(status_code=400, detail="Missing server_id")

        server_registry.unregister_server(server_id)

        return {
            "success": True,
            "message": f"Server {server_id} unregistered successfully"
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def perform_translation_internal(text: str, source_lang: str, target_lang: str) -> Dict[str, Any]:
    """Internal translation function - replace with your actual implementation"""
    start_time = time.time()

    # Placeholder: simulate a small amount of work.
    time.sleep(0.1)

    translated_text = f"[TRANSLATED] {text} [{source_lang}->{target_lang}]"

    return {
        "translated_text": translated_text,
        "processing_time": time.time() - start_time,
        "chunks_count": 1,
        "from_cache": False
    }


async def translate_text_api(request: TranslationRequest):
    """Fallback translation API - replace with your actual implementation"""
    try:
        result = perform_translation_internal(
            request.text,
            request.source_lang,
            request.target_lang
        )

        return {
            "success": True,
            "translated_text": result['translated_text'],
            "processing_time": result['processing_time'],
            "chunks_count": result['chunks_count'],
            "from_cache": result.get('from_cache', False)
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )
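
# Example client flow (illustrative; paths match the routes defined above):
#   POST /api/translate/distributed  {"text": "...", "source_lang": "en", "target_lang": "fr"}
#       -> immediate result ("processing_mode": "local") or
#          {"processing_mode": "queued", "job_id": "..."}
#   GET  /api/job/{job_id}/status        poll a queued job
#   POST /api/job/{job_id}/cancel        cancel a pending/assigned job
#   GET  /api/load-balancer/status       registry and queue statistics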