# OmniContent API — FastAPI service entry point.
# (Deployment-dashboard residue "Spaces: Running Running" removed from paste.)
| from hashlib import sha256 | |
| import ipaddress | |
| from typing import Any, Literal | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from celery.result import AsyncResult | |
| from api.config import settings | |
| from tasks import generate_content_task, celery_app | |
| import redis | |
app = FastAPI(title="OmniContent API", version="1.0.0")

# CORS configuration.
# FIX: the wildcard "*" entry was removed from the allow-list. Combining a
# wildcard origin with allow_credentials=True either gets rejected by
# browsers (Access-Control-Allow-Origin: * is invalid with credentials) or
# makes the middleware echo arbitrary origins, letting ANY site issue
# credentialed requests. Only the known frontends are allowed.
origins = [
    "http://localhost:3000",
    "https://omnicontent-web.vercel.app",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis connection (best-effort: redis.from_url parses the URL lazily and
# does not connect eagerly, but guard against malformed settings so the API
# can still boot and report degraded limiter behavior).
redis_client = None
try:
    redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
except Exception as e:
    print(f"Redis Connection warning: {e}")
# Atomic fixed-window rate limiter executed server-side in Redis, so the
# read-check-increment sequence cannot race between concurrent requests.
#   KEYS[1]  counter key            ARGV[1]  max requests per window
#   ARGV[2]  window length (secs)
# Returns {allowed (1/0), current count, seconds until the window resets};
# negative TTLs (missing key / no expiry) are clamped to 0.
LIMITER_LUA = """
local key = KEYS[1]
local max_limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(redis.call("GET", key) or "0")
if current >= max_limit then
local ttl = tonumber(redis.call("TTL", key) or 0)
if ttl < 0 then ttl = 0 end
return {0, current, ttl}
end
current = tonumber(redis.call("INCR", key))
if current == 1 then
redis.call("EXPIRE", key, window)
end
local ttl = tonumber(redis.call("TTL", key) or 0)
if ttl < 0 then ttl = 0 end
return {1, current, ttl}
"""
class GenerateRequest(BaseModel):
    """Request payload for the content-generation endpoint."""

    topic: str  # subject to generate content about (required)
    domain: str = "fintech"  # vertical the copy should target
    tone: str = "professional"  # desired writing tone
def _resolve_tier(request: Request) -> Literal["openai", "groq"]:
    """Map the x-omni-tier header onto a known tier; anything else is groq."""
    requested = request.headers.get("x-omni-tier", "groq").strip().lower()
    if requested == "openai":
        return "openai"
    return "groq"
def _hash_identifier(raw: str) -> str:
    """Return a salted, truncated SHA-256 digest of *raw* for use in Redis keys."""
    salted = f"{settings.RATE_LIMIT_SALT}:{raw}".encode("utf-8")
    digest = sha256(salted).hexdigest()
    return digest[:32]
def _extract_client_ip(request: Request) -> str:
    """Best-effort client IP: first valid X-Forwarded-For hop, else socket peer."""
    forwarded_for = request.headers.get("x-forwarded-for", "")
    if forwarded_for:
        first_hop = forwarded_for.split(",")[0].strip()
        try:
            ipaddress.ip_address(first_hop)
        except ValueError:
            # Garbage/spoofed header value — fall through to the socket peer.
            pass
        else:
            return first_hop
    client = request.client
    if client and client.host:
        return client.host
    return "127.0.0.1"
def _extract_visitor_id(request: Request) -> str | None:
    """Return the client fingerprint header value, or None when absent/blank."""
    fingerprint = request.headers.get("x-omni-fingerprint", "").strip()
    return fingerprint or None
| def _build_usage(mode: str, limit: int, used: int, reset_in_seconds: int) -> dict[str, Any]: | |
| remaining = max(limit - used, 0) | |
| return { | |
| "mode": mode, | |
| "limit": limit, | |
| "used": used, | |
| "remaining": remaining, | |
| "reset_in_seconds": max(reset_in_seconds, 0), | |
| } | |
def _read_usage(limit_key: str, mode: str, limit: int) -> dict[str, Any]:
    """Peek at the current counter without consuming a request slot."""
    if redis_client is None:
        # No limiter backend available — report a pristine window.
        return _build_usage(mode, limit, 0, 0)
    raw_used = redis_client.get(limit_key)
    raw_ttl = redis_client.ttl(limit_key)
    used = int(raw_used) if raw_used else 0
    ttl = max(int(raw_ttl or 0), 0)
    return _build_usage(mode, limit, used, ttl)
def _increment_usage(limit_key: str, mode: str, limit: int, window_seconds: int) -> tuple[bool, dict[str, Any]]:
    """Atomically consume one slot via the Lua limiter.

    Returns (allowed, usage snapshot). Raises a 503 when Redis is down,
    since consuming without a limiter would be unbounded.
    """
    if redis_client is None:
        raise HTTPException(status_code=503, detail="Rate limiter unavailable.")
    allowed_flag, used_count, ttl_seconds = redis_client.eval(
        LIMITER_LUA, 1, limit_key, limit, window_seconds
    )
    usage_payload = _build_usage(mode, limit, int(used_count), int(ttl_seconds))
    return bool(int(allowed_flag)), usage_payload
def _raise_limit_error(code: str, usage: dict[str, Any], message: str) -> None:
    """Abort with a 429 carrying the usage snapshot and a Retry-After header."""
    raise HTTPException(
        status_code=429,
        detail={"code": code, "message": message, **usage},
        headers={"Retry-After": str(usage["reset_in_seconds"])},
    )
def _check_ip_abuse(ip_hash: str, visitor_hash: str | None) -> None:
    """Block IPs that rotate through too many distinct fingerprints.

    Maintains a per-IP set of fingerprint hashes inside a rolling window and
    raises a 429 once the unique-device count exceeds the configured limit.
    No-op when Redis is unavailable or the request carries no fingerprint.
    """
    if redis_client is None or visitor_hash is None:
        return

    burn_key = f"ip_burn:{ip_hash}"
    pipeline = redis_client.pipeline()
    pipeline.sadd(burn_key, visitor_hash)
    pipeline.scard(burn_key)
    pipeline.ttl(burn_key)
    _, raw_count, raw_ttl = pipeline.execute()

    unique_devices = int(raw_count or 0)
    ttl_seconds = int(raw_ttl or 0)
    if ttl_seconds <= 0:
        # Key was just created (or lost its expiry): start a fresh window.
        redis_client.expire(burn_key, settings.IP_BURN_WINDOW_SECONDS)
        ttl_seconds = settings.IP_BURN_WINDOW_SECONDS

    if unique_devices > settings.IP_BURN_UNIQUE_LIMIT:
        _raise_limit_error(
            code="IP_ABUSE_BLOCKED",
            usage={
                "mode": "fingerprint",
                "limit": settings.FINGERPRINT_DAILY_LIMIT,
                "used": settings.FINGERPRINT_DAILY_LIMIT,
                "remaining": 0,
                "reset_in_seconds": max(ttl_seconds, 0),
            },
            message="Too many unique devices from this IP. Please login to continue.",
        )
def _resolve_identity(request: Request) -> tuple[Literal["fingerprint", "ip_fallback"], str, int, str, str | None]:
    """Decide which rate-limit bucket applies to this request.

    Returns (mode, redis key, daily limit, ip hash, fingerprint hash | None).
    Fingerprinted clients get a per-device bucket; anonymous clients share a
    per-IP fallback bucket.
    """
    ip_hash = _hash_identifier(_extract_client_ip(request))
    visitor_id = _extract_visitor_id(request)
    if visitor_id is not None:
        visitor_hash = _hash_identifier(visitor_id)
        return (
            "fingerprint",
            f"f_limit:{visitor_hash}",
            settings.FINGERPRINT_DAILY_LIMIT,
            ip_hash,
            visitor_hash,
        )
    return (
        "ip_fallback",
        f"ip_limit:{ip_hash}",
        settings.IP_FALLBACK_DAILY_LIMIT,
        ip_hash,
        None,
    )
def _get_guest_usage(request: Request, consume: bool) -> dict[str, Any]:
    """Report guest quota usage; when *consume* is True, also spend one slot.

    Raises a 429 when the IP is flagged for fingerprint rotation or the
    daily guest limit is exhausted, and a 503 when the limiter backend is
    down while consuming.
    """
    mode, limit_key, limit, ip_hash, visitor_hash = _resolve_identity(request)

    if not consume:
        return _read_usage(limit_key=limit_key, mode=mode, limit=limit)

    _check_ip_abuse(ip_hash=ip_hash, visitor_hash=visitor_hash)
    allowed, usage = _increment_usage(
        limit_key=limit_key,
        mode=mode,
        limit=limit,
        window_seconds=settings.DAILY_WINDOW_SECONDS,
    )
    if not allowed:
        _raise_limit_error(
            code="GUEST_LIMIT_REACHED",
            usage=usage,
            message="Daily guest limit reached. Please login to continue.",
        )
    return usage
def health_check():
    """Liveness probe: reports that the API process is serving requests."""
    payload = {"status": "ok", "message": "OmniContent API is running"}
    return payload
def guest_usage(request: Request):
    """Read-only view of the caller's guest quota; never consumes a slot."""
    usage = _get_guest_usage(request=request, consume=False)
    return usage
def generate_content(request_body: GenerateRequest, request: Request):
    """Enqueue the async content-generation workflow and return its task handle.

    OpenAI-tier callers skip the guest limiter and get a placeholder usage
    payload; all other callers consume one guest slot, which may raise 429.
    NOTE(review): no route decorator is visible in this chunk — presumably
    lost in formatting or registered elsewhere; confirm the endpoint wiring.
    """
    tier = _resolve_tier(request)
    # Grounding (web-augmented generation) is only enabled on the openai tier.
    grounded = tier == "openai"

    if grounded:
        usage = {
            "mode": "authenticated",
            "limit": 0,
            "used": 0,
            "remaining": 0,
            "reset_in_seconds": 0,
        }
    else:
        usage = _get_guest_usage(request=request, consume=True)

    task = generate_content_task.delay(
        request_body.topic,
        request_body.domain,
        request_body.tone,
        tier,
        grounded,
    )
    return {
        "task_id": task.id,
        "status_url": f"/status/{task.id}",
        "usage": usage,
        "tier": tier,
    }
def get_task_status(task_id: str):
    """Return the Celery state of a generation task plus progress metadata.

    The response always contains task_id/status/result; `latest_thought` is
    added while the task is queued or running, `error` on failure.
    """
    task_result = AsyncResult(task_id, app=celery_app)
    try:
        status = task_result.status
        result = task_result.result if task_result.ready() else None
    except ValueError as exc:
        # Handles legacy/corrupted Celery backend entries that are missing
        # proper exception metadata.
        return {
            "task_id": task_id,
            "status": "FAILURE",
            "result": None,
            "latest_thought": "Task metadata was invalid. Please retry generation.",
            "error": str(exc),
        }

    response = {
        "task_id": task_id,
        "status": status,
        # Stringify so exception objects stored by failed tasks stay JSON-safe.
        "result": str(result) if result is not None else None,
    }

    # FIX: use the status/result captured inside the try block. The original
    # re-read task_result.status/.result below, which hits the backend again
    # and — for .result on a corrupted entry — can raise the very ValueError
    # the try above guards against. (A leftover DEBUG print was removed.)
    if status in ("PENDING", "STARTED"):
        # Workers publish progress via update_state(meta={"latest_thought": ...}).
        info = task_result.info
        if info and isinstance(info, dict):
            response["latest_thought"] = info.get("latest_thought", "Initializing...")
        else:
            response["latest_thought"] = "Waiting for worker..."
    elif status == "FAILURE":
        response["error"] = str(result)
    return response