# OmniContent — api/main.py
from hashlib import sha256
import ipaddress
from typing import Any, Literal
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from celery.result import AsyncResult
from api.config import settings
from tasks import generate_content_task, celery_app
import redis
app = FastAPI(title="OmniContent API", version="1.0.0")

# CORS configuration.
# NOTE(security): the wildcard "*" origin was removed. Combining "*" with
# allow_credentials=True makes Starlette's CORSMiddleware echo back *any*
# request Origin together with Access-Control-Allow-Credentials, which lets
# arbitrary third-party sites issue credentialed requests. Only explicitly
# trusted frontends are listed here.
origins = [
    "http://localhost:3000",
    "https://omnicontent-web.vercel.app",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis connection used by the guest rate limiter.
# redis.from_url() is lazy (no I/O happens until the first command), so this
# try/except effectively guards against a malformed REDIS_URL; if it fails,
# redis_client stays None and the limiter helpers below degrade gracefully.
redis_client = None
try:
    redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
except Exception as e:
    print(f"Redis Connection warning: {e}")
# Atomic fixed-window rate limiter, executed server-side in Redis so the
# read-check-increment sequence cannot race between API workers.
# KEYS[1] = counter key, ARGV[1] = max requests, ARGV[2] = window seconds.
# Returns {allowed (1/0), current count, seconds until the window resets};
# negative TTLs (missing key / no expiry) are clamped to 0.
LIMITER_LUA = """
local key = KEYS[1]
local max_limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = tonumber(redis.call("GET", key) or "0")
if current >= max_limit then
local ttl = tonumber(redis.call("TTL", key) or 0)
if ttl < 0 then ttl = 0 end
return {0, current, ttl}
end
current = tonumber(redis.call("INCR", key))
if current == 1 then
redis.call("EXPIRE", key, window)
end
local ttl = tonumber(redis.call("TTL", key) or 0)
if ttl < 0 then ttl = 0 end
return {1, current, ttl}
"""
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    topic: str
    # Defaults mirror the product's primary vertical and voice.
    domain: str = "fintech"
    tone: str = "professional"
def _resolve_tier(request: Request) -> Literal["openai", "groq"]:
    """Map the ``x-omni-tier`` header to a known tier.

    Anything other than (case-insensitive, whitespace-tolerant) "openai"
    falls back to the default "groq" tier.
    """
    raw = request.headers.get("x-omni-tier", "groq")
    if raw.strip().lower() == "openai":
        return "openai"
    return "groq"
def _hash_identifier(raw: str) -> str:
    """Return a salted SHA-256 digest of *raw*, truncated to 32 hex chars."""
    salted = f"{settings.RATE_LIMIT_SALT}:{raw}".encode("utf-8")
    digest = sha256(salted).hexdigest()
    return digest[:32]
def _extract_client_ip(request: Request) -> str:
    """Best-effort client IP resolution.

    Prefers the first syntactically valid entry of X-Forwarded-For (note:
    the header is client-supplied and spoofable unless a trusted proxy sets
    it), then the socket peer address, then a loopback placeholder.
    """
    forwarded = request.headers.get("x-forwarded-for", "")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        try:
            ipaddress.ip_address(first_hop)
        except ValueError:
            pass
        else:
            return first_hop
    client = request.client
    if client and client.host:
        return client.host
    return "127.0.0.1"
def _extract_visitor_id(request: Request) -> str | None:
    """Return the ``x-omni-fingerprint`` header value, or None when absent/blank."""
    fingerprint = request.headers.get("x-omni-fingerprint", "").strip()
    return fingerprint or None
def _build_usage(mode: str, limit: int, used: int, reset_in_seconds: int) -> dict[str, Any]:
remaining = max(limit - used, 0)
return {
"mode": mode,
"limit": limit,
"used": used,
"remaining": remaining,
"reset_in_seconds": max(reset_in_seconds, 0),
}
def _read_usage(limit_key: str, mode: str, limit: int) -> dict[str, Any]:
    """Report current usage for *limit_key* without consuming a unit."""
    if redis_client is None:
        # Limiter offline: report a clean slate instead of failing the read.
        return _build_usage(mode, limit, 0, 0)
    used = int(redis_client.get(limit_key) or 0)
    ttl = max(int(redis_client.ttl(limit_key) or 0), 0)
    return _build_usage(mode, limit, used, ttl)
def _increment_usage(limit_key: str, mode: str, limit: int, window_seconds: int) -> tuple[bool, dict[str, Any]]:
    """Atomically consume one unit against *limit_key* via the Lua limiter.

    Returns (allowed, usage payload). Raises HTTPException(503) when the
    Redis client was never constructed.
    """
    if redis_client is None:
        raise HTTPException(status_code=503, detail="Rate limiter unavailable.")
    reply = redis_client.eval(LIMITER_LUA, 1, limit_key, limit, window_seconds)
    allowed_flag, used_count, ttl_seconds = reply
    usage = _build_usage(mode, limit, int(used_count), int(ttl_seconds))
    return bool(int(allowed_flag)), usage
def _raise_limit_error(code: str, usage: dict[str, Any], message: str) -> None:
    """Abort with a 429 carrying the usage payload and a Retry-After header."""
    detail: dict[str, Any] = {"code": code, "message": message}
    detail.update(usage)
    raise HTTPException(
        status_code=429,
        detail=detail,
        headers={"Retry-After": str(usage["reset_in_seconds"])},
    )
def _check_ip_abuse(ip_hash: str, visitor_hash: str | None) -> None:
    """Block IPs that rotate through too many distinct fingerprints.

    Tracks the set of fingerprint hashes seen per IP inside a rolling
    window; raises a 429 once the unique-device count exceeds the limit.
    No-op when Redis is unavailable or the caller sent no fingerprint.
    """
    if redis_client is None or visitor_hash is None:
        return
    burn_key = f"ip_burn:{ip_hash}"
    pipe = redis_client.pipeline()
    pipe.sadd(burn_key, visitor_hash)
    pipe.scard(burn_key)
    pipe.ttl(burn_key)
    _, raw_count, raw_ttl = pipe.execute()
    unique_count = int(raw_count or 0)
    ttl = int(raw_ttl or 0)
    if ttl <= 0:
        # Fresh (or expiry-less) set: (re)start the observation window.
        redis_client.expire(burn_key, settings.IP_BURN_WINDOW_SECONDS)
        ttl = settings.IP_BURN_WINDOW_SECONDS
    if unique_count <= settings.IP_BURN_UNIQUE_LIMIT:
        return
    _raise_limit_error(
        code="IP_ABUSE_BLOCKED",
        usage={
            "mode": "fingerprint",
            "limit": settings.FINGERPRINT_DAILY_LIMIT,
            "used": settings.FINGERPRINT_DAILY_LIMIT,
            "remaining": 0,
            "reset_in_seconds": max(ttl, 0),
        },
        message="Too many unique devices from this IP. Please login to continue.",
    )
def _resolve_identity(request: Request) -> tuple[Literal["fingerprint", "ip_fallback"], str, int, str, str | None]:
    """Pick the rate-limit identity for the caller.

    Prefers the browser fingerprint when one is supplied; otherwise falls
    back to the client IP. Returns (mode, redis limit key, daily limit,
    ip hash, fingerprint hash or None).
    """
    ip_hash = _hash_identifier(_extract_client_ip(request))
    visitor_id = _extract_visitor_id(request)
    if visitor_id is not None:
        visitor_hash = _hash_identifier(visitor_id)
        return (
            "fingerprint",
            f"f_limit:{visitor_hash}",
            settings.FINGERPRINT_DAILY_LIMIT,
            ip_hash,
            visitor_hash,
        )
    return (
        "ip_fallback",
        f"ip_limit:{ip_hash}",
        settings.IP_FALLBACK_DAILY_LIMIT,
        ip_hash,
        None,
    )
def _get_guest_usage(request: Request, consume: bool) -> dict[str, Any]:
    """Read or consume one unit of the caller's guest quota.

    With ``consume=True`` this also runs the IP-abuse check first and
    raises a 429 once the daily limit is exhausted; with ``consume=False``
    it only reports current usage.
    """
    mode, limit_key, limit, ip_hash, visitor_hash = _resolve_identity(request)
    if not consume:
        return _read_usage(limit_key=limit_key, mode=mode, limit=limit)
    _check_ip_abuse(ip_hash=ip_hash, visitor_hash=visitor_hash)
    allowed, usage = _increment_usage(
        limit_key=limit_key,
        mode=mode,
        limit=limit,
        window_seconds=settings.DAILY_WINDOW_SECONDS,
    )
    if not allowed:
        _raise_limit_error(
            code="GUEST_LIMIT_REACHED",
            usage=usage,
            message="Daily guest limit reached. Please login to continue.",
        )
    return usage
@app.get("/api/health")
def health_check():
    """Liveness probe for the API process."""
    return dict(status="ok", message="OmniContent API is running")
@app.get("/guest-usage")
def guest_usage(request: Request):
    """Report the caller's current guest quota without consuming a unit."""
    return _get_guest_usage(request, consume=False)
@app.post("/generate", status_code=202)
def generate_content(request_body: GenerateRequest, request: Request):
    """
    Triggers the content generation workflow asynchronously.

    Guest (groq) callers consume one unit of their daily quota; openai-tier
    callers bypass the limiter and get grounded generation.
    """
    tier = _resolve_tier(request)
    grounded = tier == "openai"
    if grounded:
        # Authenticated tier bypasses the guest limiter entirely.
        zeroed = dict.fromkeys(("limit", "used", "remaining", "reset_in_seconds"), 0)
        usage = {"mode": "authenticated", **zeroed}
    else:
        usage = _get_guest_usage(request=request, consume=True)
    task = generate_content_task.delay(
        request_body.topic, request_body.domain, request_body.tone, tier, grounded
    )
    return {
        "task_id": task.id,
        "status_url": f"/status/{task.id}",
        "usage": usage,
        "tier": tier,
    }
@app.get("/status/{task_id}")
def get_task_status(task_id: str):
    """
    Checks the status of the generation task.

    Returns the task id, its Celery state, a stringified result (once the
    task is ready), plus either a progress "latest_thought" (while
    pending/started) or an "error" message (on failure).
    """
    task_result = AsyncResult(task_id, app=celery_app)
    try:
        # Snapshot state/result/info ONCE. Each attribute access can hit the
        # result backend, and the original re-read .status/.result after this
        # point, so a racing worker could yield a self-inconsistent response
        # (e.g. status "PENDING" alongside an error). Reading .info here also
        # keeps it under the ValueError guard below.
        status = task_result.status
        info = task_result.info
        result = task_result.result if task_result.ready() else None
    except ValueError as exc:
        # Handles legacy/corrupted Celery backend entries that are missing
        # proper exception metadata.
        return {
            "task_id": task_id,
            "status": "FAILURE",
            "result": None,
            "latest_thought": "Task metadata was invalid. Please retry generation.",
            "error": str(exc),
        }
    response = {
        "task_id": task_id,
        "status": status,
        "result": str(result) if result is not None else None,
    }
    print(f"DEBUG: Task {task_id} - Status: {response['status']}, Result: {response['result']}")
    if status in ("PENDING", "STARTED"):
        # While running, the worker publishes progress via the task meta dict.
        if info and isinstance(info, dict):
            response["latest_thought"] = info.get("latest_thought", "Initializing...")
        else:
            response["latest_thought"] = "Waiting for worker..."
    elif status == "FAILURE":
        # For failed tasks the snapshotted result holds the raised exception.
        response["error"] = str(result)
    return response