Spaces:
Sleeping
Sleeping
| """ | |
| Celery task workers for async inference and batch processing. | |
| Handles long-running DICOM processing jobs with progress tracking via Redis. | |
| Run worker with: celery -A tasks worker --loglevel=info | |
| """ | |
| import logging | |
| import os | |
| import shutil | |
| import datetime | |
| import ssl | |
| import sys | |
| import traceback | |
| from pathlib import Path | |
| from typing import Any | |
| from zoneinfo import ZoneInfo | |
# Ensure the app directory is in the Python path so imports work in worker processes
# (Celery worker subprocesses do not necessarily inherit the parent's sys.path).
APP_DIR = Path(__file__).parent.absolute()
if str(APP_DIR) not in sys.path:
    sys.path.insert(0, str(APP_DIR))
# Load a local .env file when python-dotenv is available; optional in production,
# so a missing dependency is deliberately ignored.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
from celery import Celery, current_task
# Module-level logger, one per module per the standard logging convention.
logger = logging.getLogger(__name__)
| IST = ZoneInfo("Asia/Kolkata") | |
| def _now_ist() -> datetime.datetime: | |
| return datetime.datetime.now(IST).replace(tzinfo=None) | |
| def _env_int(name: str, default: int | None = None, *, minimum: int | None = None) -> int | None: | |
| raw = os.environ.get(name) | |
| if raw is None: | |
| return default | |
| try: | |
| value = int(raw) | |
| if minimum is not None and value < minimum: | |
| return default | |
| return value | |
| except ValueError: | |
| return default | |
# Extract Redis URL from environment (defaults to a local, unauthenticated Redis).
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
# Initialize Celery app; the same Redis instance serves as broker and result backend.
celery_app = Celery(
    "ich_tasks",
    broker=REDIS_URL,
    backend=REDIS_URL,
)
# Configure Celery with SSL support for Upstash Redis (the "rediss://" scheme).
# NOTE(review): CERT_NONE disables server-certificate verification for both the
# broker and the result backend — confirm this is acceptable for the deployment.
ssl_config = None
redis_backend_ssl = None
if REDIS_URL.startswith("rediss://"):
    ssl_config = {"ssl_cert_reqs": ssl.CERT_NONE}
    redis_backend_ssl = {"ssl_cert_reqs": ssl.CERT_NONE}
celery_app.conf.update(
    broker_use_ssl=ssl_config,
    redis_backend_use_ssl=redis_backend_ssl,
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    task_track_started=True,
    task_time_limit=3600,  # 1 hour hard limit
    task_soft_time_limit=3300,  # 55 min soft limit
    result_expires=86400,  # 24 hours
    worker_prefetch_multiplier=1,  # Prevent long-running tasks from getting stuck behind each other
    task_acks_late=True,  # Only acknowledge task after it completely finishes
)
# Optional env-driven worker tuning; values that are unset, non-numeric, or < 1
# are ignored (see _env_int) so a bad env var cannot break worker startup.
extra_conf: dict[str, Any] = {}
worker_concurrency = _env_int("ICH_CELERY_CONCURRENCY", None, minimum=1)
worker_prefetch = _env_int("ICH_CELERY_PREFETCH_MULTIPLIER", None, minimum=1)
if worker_concurrency is not None:
    extra_conf["worker_concurrency"] = worker_concurrency
if worker_prefetch is not None:
    extra_conf["worker_prefetch_multiplier"] = worker_prefetch
if extra_conf:
    celery_app.conf.update(**extra_conf)
| def _iter_batches(items: list[str], batch_size: int) -> list[list[str]]: | |
| return [items[i:i + batch_size] for i in range(0, len(items), batch_size)] | |
@celery_app.task(bind=True, name="tasks.process_dicom_batch")
def process_dicom_batch(
    self,
    batch_id: str,
    dcm_paths: list[str],
    user_id: int,
    temp_dir: str | None = None,
) -> dict[str, Any]:
    """
    Process a batch of DICOM files asynchronously with progress tracking.

    Registered as a bound Celery task (bind=True) — the body calls
    ``self.update_state`` to publish PROGRESS metadata that the frontend polls.

    Args:
        batch_id: Unique identifier for this batch job
        dcm_paths: List of DICOM file paths to process
        user_id: User ID for audit and data isolation
        temp_dir: Optional temporary directory to clean up after

    Returns:
        Dictionary with final batch status and results matching frontend
        expectations (same shape as the PROGRESS meta, status="completed").

    Raises:
        Exception: re-raised after audit logging if the batch fails wholesale.
    """
    # Import here to avoid circular imports. Add diagnostics to help debug
    # ModuleNotFoundError issues when Celery workers can't find `app_new`.
    try:
        # Ensure APP_DIR is present in sys.path for worker subprocesses
        if str(APP_DIR) not in sys.path:
            sys.path.insert(0, str(APP_DIR))
            logger.info(f"Inserted APP_DIR into sys.path: {APP_DIR}")
        else:
            logger.info(f"APP_DIR already in sys.path: {APP_DIR}")
        logger.info(f"tasks.py APP_DIR={APP_DIR}")
        logger.info(f"sys.path (first 10): {sys.path[:10]}")
        # List files in the app dir for visibility
        try:
            files = [p.name for p in Path(APP_DIR).iterdir() if p.exists()]
            logger.info(f"APP_DIR contents: {files[:50]}")
        except Exception as _e:
            logger.warning(f"Could not list APP_DIR contents: {_e}")
        from app_new import app, _run_inference_on_dcm
        from auth_utils import log_audit
        from models import ScreeningUpload, db
    except Exception:
        logger.error("Failed importing application modules inside Celery worker:\n" + traceback.format_exc())
        raise

    total = len(dcm_paths)
    succeeded_ids: list[str] = []
    failed_ids: list[str] = []
    started_at = _now_ist().isoformat()
    logger.info(f"Batch {batch_id} starting: {total} files for user {user_id}")

    def _progress_meta(processed: int, current_file: str) -> dict[str, Any]:
        # Single source of truth for the PROGRESS payload (matches the
        # in-memory _BATCHES format the frontend expects).
        return {
            "batch_id": batch_id,
            "user_id": user_id,
            "status": "running",
            "total": total,
            "processed": processed,
            "succeeded": len(succeeded_ids),
            "failed_ids": list(failed_ids),
            "image_ids": list(succeeded_ids),
            "current_file": current_file,
            "started_at": started_at,
            "finished_at": None,
            "error": None,
            "temp_dir": temp_dir,
        }

    def _task_revoked() -> bool:
        # Check if the task was revoked (attribute name varies across Celery versions).
        ctx = current_task.request
        return bool(getattr(ctx, "is_revoked", False)) or bool(getattr(ctx, "revoked", False))

    def _mark_failed(upload_record) -> None:
        # Best-effort "failed" status flip after an error: roll back the broken
        # transaction first, and never let the bookkeeping commit itself raise.
        db.session.rollback()
        upload_record.processing_status = "failed"
        try:
            db.session.commit()
        except Exception:
            db.session.rollback()

    try:
        with app.app_context():
            # Prefer GPU batched inference when the app exposes it and there is
            # more than one file; otherwise fall back to the per-file path.
            use_gpu_batch = False
            batch_size = 1
            _infer_images_batch = None
            _persist_inference_result = None
            try:
                from app_new import (
                    GPU_BATCH_SIZE,
                    _gpu_batch_ready,
                    _infer_images_batch,
                    _persist_inference_result,
                )
                use_gpu_batch = _gpu_batch_ready() and total > 1
                batch_size = max(1, GPU_BATCH_SIZE)
            except Exception:
                # Older app_new without GPU batching — silently use the fallback.
                use_gpu_batch = False

            if use_gpu_batch and _infer_images_batch and _persist_inference_result:
                logger.info(
                    "GPU batch inference enabled (size=%s); per-image traces are skipped.",
                    batch_size,
                )
                processed = 0
                revoked = False
                for chunk in _iter_batches(dcm_paths, batch_size):
                    if revoked:
                        break
                    paths = [Path(p) for p in chunk]
                    # Create one ScreeningUpload row per file before inference so
                    # a crash mid-chunk still leaves an auditable record.
                    upload_records: list[ScreeningUpload] = []
                    for path in paths:
                        if _task_revoked():
                            logger.info(f"Batch {batch_id} revoked, stopping")
                            revoked = True
                            break
                        upload_record = ScreeningUpload(
                            user_id=user_id,
                            file_name=path.name,
                            original_filename=path.name,
                            file_size=path.stat().st_size if path.exists() else None,
                            file_path=str(path),
                            processing_status="processing",
                        )
                        db.session.add(upload_record)
                        db.session.commit()
                        upload_records.append(upload_record)
                    if revoked:
                        break
                    try:
                        batch_results = _infer_images_batch(paths)
                    except Exception as exc:
                        logger.error(
                            f"Batch {batch_id}: GPU batch inference failed — {exc}",
                            exc_info=True,
                        )
                        # Whole chunk failed: mark every record failed and advance
                        # the progress counter one file at a time.
                        for path, upload_record in zip(paths, upload_records, strict=False):
                            image_id = path.stem
                            _mark_failed(upload_record)
                            failed_ids.append(image_id)
                            processed += 1
                            self.update_state(state="PROGRESS", meta=_progress_meta(processed, ""))
                        continue
                    for (path, upload_record), (img_rgb, inference) in zip(
                        zip(paths, upload_records, strict=False),
                        batch_results,
                        strict=False,
                    ):
                        image_id = path.stem
                        self.update_state(state="PROGRESS", meta=_progress_meta(processed, image_id))
                        try:
                            report = _persist_inference_result(
                                image_id,
                                user_id,
                                upload_record.id,
                                img_rgb,
                                inference,
                            )
                            if report:
                                upload_record.processing_status = "completed"
                                db.session.commit()
                                succeeded_ids.append(image_id)
                            else:
                                upload_record.processing_status = "failed"
                                db.session.commit()
                                failed_ids.append(image_id)
                        except Exception as exc:
                            logger.error(f"Batch {batch_id}: failed {image_id} — {exc}")
                            _mark_failed(upload_record)
                            failed_ids.append(image_id)
                        processed += 1
                        self.update_state(state="PROGRESS", meta=_progress_meta(processed, ""))
            else:
                for i, path_str in enumerate(dcm_paths, 1):
                    # Check if task was revoked (compat across Celery versions)
                    if _task_revoked():
                        logger.info(f"Batch {batch_id} revoked, stopping")
                        break
                    path = Path(path_str)
                    image_id = path.stem
                    upload_record = ScreeningUpload(
                        user_id=user_id,
                        file_name=path.name,
                        original_filename=path.name,
                        file_size=path.stat().st_size if path.exists() else None,
                        file_path=str(path),
                        processing_status="processing",
                    )
                    db.session.add(upload_record)
                    db.session.commit()
                    # Update Celery task state with progress before inference starts.
                    self.update_state(state="PROGRESS", meta=_progress_meta(i - 1, image_id))
                    try:
                        report, _ = _run_inference_on_dcm(path, user_id, upload_record.id)
                        if report:
                            upload_record.processing_status = "completed"
                            db.session.commit()
                            succeeded_ids.append(image_id)
                        else:
                            upload_record.processing_status = "failed"
                            db.session.commit()
                            failed_ids.append(image_id)
                    except Exception as e:
                        logger.error(f"Batch {batch_id}: failed {image_id} — {e}")
                        _mark_failed(upload_record)
                        failed_ids.append(image_id)
                    # Update after processing each file
                    self.update_state(state="PROGRESS", meta=_progress_meta(i, ""))

        # Cleanup temporary directory if provided
        if temp_dir and Path(temp_dir).exists():
            try:
                shutil.rmtree(temp_dir, ignore_errors=True)
                logger.info(f"Cleaned up temp_dir: {temp_dir}")
            except Exception as e:
                logger.warning(f"Failed to clean temp_dir {temp_dir}: {e}")

        # Log final audit result
        with app.app_context():
            audit_status = "success" if len(failed_ids) == 0 else "partial"
            log_audit(
                "batch_processing_completed",
                user_id=user_id,
                details=f"batch_id={batch_id}, processed={total}, succeeded={len(succeeded_ids)}, failed={len(failed_ids)}",
                status=audit_status,
            )

        # Return final result matching _BATCHES format for frontend compatibility
        result = {
            "batch_id": batch_id,
            "user_id": user_id,
            "status": "completed",
            "total": total,
            "processed": total,
            "succeeded": len(succeeded_ids),
            "failed_ids": list(failed_ids),
            "image_ids": list(succeeded_ids),
            "current_file": "",
            "started_at": started_at,
            "finished_at": _now_ist().isoformat(),
            "error": None,
            "temp_dir": temp_dir,
        }
        logger.info(
            f"Batch {batch_id} complete: {len(succeeded_ids)}/{total} succeeded, "
            f"{len(failed_ids)} failed"
        )
        return result
    except Exception as e:
        logger.error(f"Batch {batch_id} error: {e}", exc_info=True)
        with app.app_context():
            log_audit(
                "batch_processing_failed",
                user_id=user_id,
                details=f"batch_id={batch_id}, error={str(e)}",
                status="failure",
            )
        raise
@celery_app.task(name="tasks.health_check")
def health_check() -> str:
    """Simple health check task for monitoring.

    Registered with an explicit name so monitors can dispatch
    "tasks.health_check" regardless of how the worker imports this module.
    """
    return "Celery worker is healthy"
@celery_app.task(name="tasks.cleanup_expired_otps")
def cleanup_expired_otps():
    """Periodic task to delete expired OTPs from the database.

    Registered under the explicit name "tasks.cleanup_expired_otps" so the
    beat_schedule entry that dispatches this name resolves to a known task.
    """
    # Imported lazily to avoid circular imports at worker boot.
    from app_new import app
    from models import db, PendingOtp, now_ist

    with app.app_context():
        try:
            # Bulk-delete every OTP whose expiry is in the past (IST clock).
            deleted = PendingOtp.query.filter(PendingOtp.expires_at < now_ist()).delete()
            db.session.commit()
            if deleted > 0:
                logger.info("Cleaned up %d expired OTP rows.", deleted)
        except Exception as exc:
            # Roll back so the session stays usable for the next beat tick.
            db.session.rollback()
            logger.error("Error cleaning up OTPs: %s", exc)
# Periodic (celery beat) schedule: prune expired OTP rows every 15 minutes.
# NOTE(review): 'tasks.cleanup_expired_otps' must match a registered task name —
# confirm cleanup_expired_otps is decorated/registered under this name.
celery_app.conf.beat_schedule = {
    'cleanup-expired-otps-every-15-mins': {
        'task': 'tasks.cleanup_expired_otps',
        'schedule': 900.0,  # 15 minutes in seconds
    },
}