"""OmniContent API: FastAPI front end for asynchronous content generation.

Exposes endpoints to enqueue Celery generation tasks, poll their status, and
enforce per-guest daily rate limits backed by Redis.
"""

import ipaddress
import logging
from hashlib import sha256
from typing import Any, Literal

import redis
from celery.result import AsyncResult
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from api.config import settings
from tasks import celery_app, generate_content_task

logger = logging.getLogger(__name__)

app = FastAPI(title="OmniContent API", version="1.0.0")

# CORS Configuration.
# FIX: the list previously also contained "*". The CORS spec forbids a
# wildcard Access-Control-Allow-Origin combined with credentials
# (allow_credentials=True), so browsers reject that combination; only the
# explicit origins are kept.
origins = [
    "http://localhost:3000",
    "https://omnicontent-web.vercel.app",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis connection.
# FIX: redis.from_url() is lazy and opens no socket, so the original
# try/except could never detect an unreachable server. ping() forces a
# round-trip; on failure redis_client stays None and the limiter helpers
# degrade accordingly (read returns zeros, consume returns 503).
redis_client = None
try:
    redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
    redis_client.ping()
except Exception as e:
    redis_client = None
    logger.warning("Redis Connection warning: %s", e)

# Atomic fixed-window counter. Returns {allowed(0|1), current_count, ttl}.
# Running GET/INCR/EXPIRE as one Lua script avoids the check-then-act race
# a client-side GET + INCR pipeline would have.
LIMITER_LUA = """
local key = KEYS[1]
local max_limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(redis.call("GET", key) or "0")
if current >= max_limit then
    local ttl = tonumber(redis.call("TTL", key) or 0)
    if ttl < 0 then ttl = 0 end
    return {0, current, ttl}
end
current = tonumber(redis.call("INCR", key))
if current == 1 then
    redis.call("EXPIRE", key, window)
end
local ttl = tonumber(redis.call("TTL", key) or 0)
if ttl < 0 then ttl = 0 end
return {1, current, ttl}
"""


class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    topic: str
    domain: str = "fintech"
    tone: str = "professional"


def _resolve_tier(request: Request) -> Literal["openai", "groq"]:
    """Read the x-omni-tier header; anything other than "openai" maps to "groq"."""
    tier = request.headers.get("x-omni-tier", "groq").strip().lower()
    return "openai" if tier == "openai" else "groq"


def _hash_identifier(raw: str) -> str:
    """Salted SHA-256 digest (truncated to 32 hex chars) of an identifier.

    Keeps raw IPs / browser fingerprints out of Redis keys.
    """
    return sha256(f"{settings.RATE_LIMIT_SALT}:{raw}".encode("utf-8")).hexdigest()[:32]


def _extract_client_ip(request: Request) -> str:
    """Best-effort client IP: first valid X-Forwarded-For hop, else socket peer."""
    forwarded_for = request.headers.get("x-forwarded-for", "")
    if forwarded_for:
        candidate = forwarded_for.split(",")[0].strip()
        try:
            # Validate so a spoofed/garbage header cannot poison the limiter key.
            ipaddress.ip_address(candidate)
            return candidate
        except ValueError:
            pass
    if request.client and request.client.host:
        return request.client.host
    return "127.0.0.1"


def _extract_visitor_id(request: Request) -> str | None:
    """Browser fingerprint from the x-omni-fingerprint header, or None if absent."""
    visitor_id = request.headers.get("x-omni-fingerprint", "").strip()
    return visitor_id if visitor_id else None


def _build_usage(mode: str, limit: int, used: int, reset_in_seconds: int) -> dict[str, Any]:
    """Normalize usage counters into the JSON shape returned to clients."""
    remaining = max(limit - used, 0)
    return {
        "mode": mode,
        "limit": limit,
        "used": used,
        "remaining": remaining,
        "reset_in_seconds": max(reset_in_seconds, 0),
    }


def _read_usage(limit_key: str, mode: str, limit: int) -> dict[str, Any]:
    """Read current usage without consuming quota (zeros when Redis is down)."""
    if redis_client is None:
        return _build_usage(mode, limit, 0, 0)
    used = int(redis_client.get(limit_key) or 0)
    ttl = int(redis_client.ttl(limit_key) or 0)
    ttl = max(ttl, 0)  # TTL is -1/-2 for non-expiring or missing keys
    return _build_usage(mode, limit, used, ttl)


def _increment_usage(
    limit_key: str, mode: str, limit: int, window_seconds: int
) -> tuple[bool, dict[str, Any]]:
    """Atomically consume one unit of quota; returns (allowed, usage dict).

    Raises HTTP 503 when Redis is unavailable, since consuming quota
    without a limiter would be unenforceable.
    """
    if redis_client is None:
        raise HTTPException(status_code=503, detail="Rate limiter unavailable.")
    allowed, used, ttl = redis_client.eval(LIMITER_LUA, 1, limit_key, limit, window_seconds)
    usage = _build_usage(mode, limit, int(used), int(ttl))
    return bool(int(allowed)), usage


def _raise_limit_error(code: str, usage: dict[str, Any], message: str) -> None:
    """Raise a 429 carrying a machine-readable code, usage block and Retry-After."""
    headers = {"Retry-After": str(usage["reset_in_seconds"])}
    raise HTTPException(
        status_code=429,
        detail={
            "code": code,
            "message": message,
            **usage,
        },
        headers=headers,
    )


def _check_ip_abuse(ip_hash: str, visitor_hash: str | None) -> None:
    """Block IPs that rotate through many fingerprints to dodge per-device limits.

    Tracks the set of fingerprints seen per IP within IP_BURN_WINDOW_SECONDS;
    exceeding IP_BURN_UNIQUE_LIMIT raises a 429. No-op without Redis or a
    fingerprint.
    """
    if redis_client is None or visitor_hash is None:
        return
    key = f"ip_burn:{ip_hash}"
    pipe = redis_client.pipeline()
    pipe.sadd(key, visitor_hash)
    pipe.scard(key)
    pipe.ttl(key)
    _, unique_count, ttl = pipe.execute()
    unique_count = int(unique_count or 0)
    ttl = int(ttl or 0)
    if ttl <= 0:
        # Fresh (or non-expiring) key: start the burn window now.
        redis_client.expire(key, settings.IP_BURN_WINDOW_SECONDS)
        ttl = settings.IP_BURN_WINDOW_SECONDS
    if unique_count > settings.IP_BURN_UNIQUE_LIMIT:
        usage = {
            "mode": "fingerprint",
            "limit": settings.FINGERPRINT_DAILY_LIMIT,
            "used": settings.FINGERPRINT_DAILY_LIMIT,
            "remaining": 0,
            "reset_in_seconds": max(ttl, 0),
        }
        _raise_limit_error(
            code="IP_ABUSE_BLOCKED",
            usage=usage,
            message="Too many unique devices from this IP. Please login to continue.",
        )


def _resolve_identity(
    request: Request,
) -> tuple[Literal["fingerprint", "ip_fallback"], str, int, str, str | None]:
    """Pick the limiter identity: fingerprint when present, else IP fallback.

    Returns (mode, redis limit key, daily limit, ip_hash, visitor_hash-or-None).
    """
    client_ip = _extract_client_ip(request)
    visitor_id = _extract_visitor_id(request)
    ip_hash = _hash_identifier(client_ip)
    visitor_hash = _hash_identifier(visitor_id) if visitor_id else None
    if visitor_hash:
        mode: Literal["fingerprint", "ip_fallback"] = "fingerprint"
        return mode, f"f_limit:{visitor_hash}", settings.FINGERPRINT_DAILY_LIMIT, ip_hash, visitor_hash
    mode = "ip_fallback"
    return mode, f"ip_limit:{ip_hash}", settings.IP_FALLBACK_DAILY_LIMIT, ip_hash, None


def _get_guest_usage(request: Request, consume: bool) -> dict[str, Any]:
    """Report (consume=False) or consume (consume=True) guest quota.

    Raises 429 when consumption is denied by the daily limit or the
    IP-abuse guard, and 503 when consuming with Redis unavailable.
    """
    mode, limit_key, limit, ip_hash, visitor_hash = _resolve_identity(request)
    if consume:
        _check_ip_abuse(ip_hash=ip_hash, visitor_hash=visitor_hash)
        allowed, usage = _increment_usage(
            limit_key=limit_key,
            mode=mode,
            limit=limit,
            window_seconds=settings.DAILY_WINDOW_SECONDS,
        )
        if not allowed:
            _raise_limit_error(
                code="GUEST_LIMIT_REACHED",
                usage=usage,
                message="Daily guest limit reached. Please login to continue.",
            )
        return usage
    return _read_usage(limit_key=limit_key, mode=mode, limit=limit)


@app.get("/api/health")
def health_check():
    """Liveness probe."""
    return {"status": "ok", "message": "OmniContent API is running"}


@app.get("/guest-usage")
def guest_usage(request: Request):
    """Return the caller's current guest quota without consuming any of it."""
    return _get_guest_usage(request=request, consume=False)


@app.post("/generate", status_code=202)
def generate_content(request_body: GenerateRequest, request: Request):
    """
    Triggers the content generation workflow asynchronously.
    """
    tier = _resolve_tier(request)
    grounded = tier == "openai"
    if tier == "openai":
        # Authenticated tier is not guest-rate-limited here; placeholder block.
        usage = {
            "mode": "authenticated",
            "limit": 0,
            "used": 0,
            "remaining": 0,
            "reset_in_seconds": 0,
        }
    else:
        usage = _get_guest_usage(request=request, consume=True)
    task = generate_content_task.delay(
        request_body.topic,
        request_body.domain,
        request_body.tone,
        tier,
        grounded,
    )
    return {
        "task_id": task.id,
        "status_url": f"/status/{task.id}",
        "usage": usage,
        "tier": tier,
    }


@app.get("/status/{task_id}")
def get_task_status(task_id: str):
    """
    Checks the status of the generation task.
    """
    task_result = AsyncResult(task_id, app=celery_app)
    try:
        status = task_result.status
        result = task_result.result if task_result.ready() else None
    except ValueError as exc:
        # Handles legacy/corrupted Celery backend entries that are missing
        # proper exception metadata.
        return {
            "task_id": task_id,
            "status": "FAILURE",
            "result": None,
            "latest_thought": "Task metadata was invalid. Please retry generation.",
            "error": str(exc),
        }
    normalized_result = str(result) if result is not None else None
    response = {
        "task_id": task_id,
        "status": status,
        "result": normalized_result,
    }
    # FIX: use the snapshot taken inside the try block instead of re-querying
    # task_result.status/.result afterwards — a second backend read could race
    # or re-raise the ValueError handled above. Also log via logging rather
    # than a bare DEBUG print.
    logger.debug("Task %s - Status: %s, Result: %s", task_id, status, normalized_result)
    if status in ["PENDING", "STARTED"]:
        info = task_result.info
        if info and isinstance(info, dict):
            response["latest_thought"] = info.get("latest_thought", "Initializing...")
        else:
            response["latest_thought"] = "Waiting for worker..."
    elif status == "FAILURE":
        response["error"] = str(result)
    return response