import asyncio
import logging
import math
import os
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import psutil
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field, HttpUrl

from crawl4ai import AsyncWebCrawler, CacheMode, CrawlResult
from crawl4ai.config import MIN_WORD_THRESHOLD
from crawl4ai.extraction_strategy import (
    CosineStrategy,
    JsonCssExtractionStrategy,
    LLMExtractionStrategy,
)

__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TaskStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class CrawlerType(str, Enum):
    BASIC = "basic"
    LLM = "llm"
    COSINE = "cosine"
    JSON_CSS = "json_css"


class ExtractionConfig(BaseModel):
    type: CrawlerType
    params: Dict[str, Any] = {}
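# Illustrative ExtractionConfig payload for the "json_css" type. The selector
# values are placeholders; the schema layout follows what crawl4ai's
# JsonCssExtractionStrategy expects, so consult its documentation for the
# authoritative format:
#
#   {
#       "type": "json_css",
#       "params": {
#           "schema": {
#               "name": "Example items",
#               "baseSelector": "div.item",
#               "fields": [
#                   {"name": "title", "selector": "h2", "type": "text"}
#               ]
#           }
#       }
#   }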


class ChunkingStrategy(BaseModel):
    type: str
    params: Dict[str, Any] = {}


class ContentFilter(BaseModel):
    type: str = "bm25"
    params: Dict[str, Any] = {}


class CrawlRequest(BaseModel):
    urls: Union[HttpUrl, List[HttpUrl]]
    word_count_threshold: int = MIN_WORD_THRESHOLD
    extraction_config: Optional[ExtractionConfig] = None
    chunking_strategy: Optional[ChunkingStrategy] = None
    content_filter: Optional[ContentFilter] = None
    js_code: Optional[List[str]] = None
    wait_for: Optional[str] = None
    css_selector: Optional[str] = None
    screenshot: bool = False
    magic: bool = False
    extra: Optional[Dict[str, Any]] = {}
    session_id: Optional[str] = None
    cache_mode: Optional[CacheMode] = CacheMode.ENABLED
    priority: int = Field(default=5, ge=1, le=10)
    ttl: Optional[int] = 3600
    crawler_params: Dict[str, Any] = {}
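# A minimal CrawlRequest body for POST /crawl or /crawl_sync might look like
# the following (illustrative values; every field except "urls" is optional):
#
#   {
#       "urls": "https://example.com",
#       "priority": 8,
#       "screenshot": false
#   }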


@dataclass
class TaskInfo:
    id: str
    status: TaskStatus
    result: Optional[Union[CrawlResult, List[CrawlResult]]] = None
    error: Optional[str] = None
    # Use a default_factory so each task gets its own creation timestamp;
    # a plain `time.time()` default would be evaluated once at class definition.
    created_at: float = field(default_factory=time.time)
    ttl: int = 3600
    # Populated by CrawlerService.submit_task with the originating CrawlRequest.
    request: Optional[CrawlRequest] = None


class ResourceMonitor:
    def __init__(self, max_concurrent_tasks: int = 10):
        self.max_concurrent_tasks = max_concurrent_tasks
        self.memory_threshold = 0.85
        self.cpu_threshold = 0.90
        self._last_check = 0
        self._check_interval = 1
        self._last_available_slots = max_concurrent_tasks

    async def get_available_slots(self) -> int:
        current_time = time.time()
        if current_time - self._last_check < self._check_interval:
            return self._last_available_slots

        mem_usage = psutil.virtual_memory().percent / 100
        cpu_usage = psutil.cpu_percent() / 100

        memory_factor = max(
            0, (self.memory_threshold - mem_usage) / self.memory_threshold
        )
        cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)

        self._last_available_slots = math.floor(
            self.max_concurrent_tasks * min(memory_factor, cpu_factor)
        )
        self._last_check = current_time

        return self._last_available_slots
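# Worked example of the throttling formula above (illustrative numbers): with
# max_concurrent_tasks=10, memory at 60% and CPU at 45%,
#   memory_factor = (0.85 - 0.60) / 0.85 ≈ 0.294
#   cpu_factor    = (0.90 - 0.45) / 0.90 = 0.500
#   slots         = floor(10 * min(0.294, 0.500)) = 2
# so get_available_slots() would report 2 free slots until load changes.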


class TaskManager:
    def __init__(self, cleanup_interval: int = 300):
        self.tasks: Dict[str, TaskInfo] = {}
        self.high_priority = asyncio.PriorityQueue()
        self.low_priority = asyncio.PriorityQueue()
        self.cleanup_interval = cleanup_interval
        self.cleanup_task = None

    async def start(self):
        self.cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        if self.cleanup_task:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass

    async def add_task(self, task_id: str, priority: int, ttl: int) -> None:
        task_info = TaskInfo(id=task_id, status=TaskStatus.PENDING, ttl=ttl)
        self.tasks[task_id] = task_info
        queue = self.high_priority if priority > 5 else self.low_priority
        # Negate the priority so higher-priority requests are dequeued first.
        await queue.put((-priority, task_id))

    async def get_next_task(self) -> Optional[str]:
        try:
            # Prefer the high-priority queue, but do not block on it.
            _, task_id = await asyncio.wait_for(self.high_priority.get(), timeout=0.1)
            return task_id
        except asyncio.TimeoutError:
            try:
                # Fall back to the low-priority queue.
                _, task_id = await asyncio.wait_for(
                    self.low_priority.get(), timeout=0.1
                )
                return task_id
            except asyncio.TimeoutError:
                return None

    def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        result: Any = None,
        error: Optional[str] = None,
    ):
        if task_id in self.tasks:
            task_info = self.tasks[task_id]
            task_info.status = status
            task_info.result = result
            task_info.error = error

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        return self.tasks.get(task_id)

    async def _cleanup_loop(self):
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                current_time = time.time()
                expired_tasks = [
                    task_id
                    for task_id, task in self.tasks.items()
                    if current_time - task.created_at > task.ttl
                    and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
                ]
                for task_id in expired_tasks:
                    del self.tasks[task_id]
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")
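# Note: completed and failed tasks are retained for `ttl` seconds (3600 by
# default) and then dropped by the cleanup loop above, so clients should fetch
# results from GET /task/{task_id} within that window.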


class CrawlerPool:
    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self.active_crawlers: Dict[AsyncWebCrawler, float] = {}
        self._lock = asyncio.Lock()

    async def acquire(self, **kwargs) -> AsyncWebCrawler:
        async with self._lock:
            # Close crawlers that have been idle for more than ten minutes.
            current_time = time.time()
            inactive = [
                crawler
                for crawler, last_used in self.active_crawlers.items()
                if current_time - last_used > 600
            ]
            for crawler in inactive:
                await crawler.__aexit__(None, None, None)
                del self.active_crawlers[crawler]

            # Start a new crawler while the pool is below its maximum size.
            if len(self.active_crawlers) < self.max_size:
                crawler = AsyncWebCrawler(**kwargs)
                await crawler.__aenter__()
                self.active_crawlers[crawler] = current_time
                return crawler

            # Otherwise reuse the least recently used crawler.
            crawler = min(self.active_crawlers.items(), key=lambda x: x[1])[0]
            self.active_crawlers[crawler] = current_time
            return crawler

    async def release(self, crawler: AsyncWebCrawler):
        async with self._lock:
            if crawler in self.active_crawlers:
                self.active_crawlers[crawler] = time.time()

    async def cleanup(self):
        async with self._lock:
            for crawler in list(self.active_crawlers.keys()):
                await crawler.__aexit__(None, None, None)
            self.active_crawlers.clear()
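# The pool drives AsyncWebCrawler's async context manager by hand
# (__aenter__/__aexit__) so crawler instances can outlive a single request and
# be shared across tasks; CrawlerService.stop() calls cleanup() to close them.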


class CrawlerService:
    def __init__(self, max_concurrent_tasks: int = 10):
        self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
        self.task_manager = TaskManager()
        self.crawler_pool = CrawlerPool(max_concurrent_tasks)
        self._processing_task = None

    async def start(self):
        await self.task_manager.start()
        self._processing_task = asyncio.create_task(self._process_queue())

    async def stop(self):
        if self._processing_task:
            self._processing_task.cancel()
            try:
                await self._processing_task
            except asyncio.CancelledError:
                pass
        await self.task_manager.stop()
        await self.crawler_pool.cleanup()

    def _create_extraction_strategy(self, config: ExtractionConfig):
        if not config:
            return None

        if config.type == CrawlerType.LLM:
            return LLMExtractionStrategy(**config.params)
        elif config.type == CrawlerType.COSINE:
            return CosineStrategy(**config.params)
        elif config.type == CrawlerType.JSON_CSS:
            return JsonCssExtractionStrategy(**config.params)
        return None

    async def submit_task(self, request: CrawlRequest) -> str:
        task_id = str(uuid.uuid4())
        await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)

        # Keep the original request on the task so the queue worker can process it.
        self.task_manager.tasks[task_id].request = request

        return task_id

    async def _process_queue(self):
        while True:
            try:
                available_slots = await self.resource_monitor.get_available_slots()
                if available_slots <= 0:
                    await asyncio.sleep(1)
                    continue

                task_id = await self.task_manager.get_next_task()
                if not task_id:
                    await asyncio.sleep(1)
                    continue

                task_info = self.task_manager.get_task(task_id)
                if not task_info:
                    continue

                request = task_info.request
                self.task_manager.update_task(task_id, TaskStatus.PROCESSING)

                try:
                    crawler = await self.crawler_pool.acquire(**request.crawler_params)

                    extraction_strategy = self._create_extraction_strategy(
                        request.extraction_config
                    )

                    try:
                        if isinstance(request.urls, list):
                            results = await crawler.arun_many(
                                urls=[str(url) for url in request.urls],
                                word_count_threshold=request.word_count_threshold,
                                extraction_strategy=extraction_strategy,
                                js_code=request.js_code,
                                wait_for=request.wait_for,
                                css_selector=request.css_selector,
                                screenshot=request.screenshot,
                                magic=request.magic,
                                session_id=request.session_id,
                                cache_mode=request.cache_mode,
                                **request.extra,
                            )
                        else:
                            results = await crawler.arun(
                                url=str(request.urls),
                                word_count_threshold=request.word_count_threshold,
                                extraction_strategy=extraction_strategy,
                                js_code=request.js_code,
                                wait_for=request.wait_for,
                                css_selector=request.css_selector,
                                screenshot=request.screenshot,
                                magic=request.magic,
                                session_id=request.session_id,
                                cache_mode=request.cache_mode,
                                **request.extra,
                            )
                    finally:
                        # Always return the crawler, even if the crawl raised.
                        await self.crawler_pool.release(crawler)

                    self.task_manager.update_task(
                        task_id, TaskStatus.COMPLETED, results
                    )

                except Exception as e:
                    logger.error(f"Error processing task {task_id}: {str(e)}")
                    self.task_manager.update_task(
                        task_id, TaskStatus.FAILED, error=str(e)
                    )

            except Exception as e:
                logger.error(f"Error in queue processing: {str(e)}")
                await asyncio.sleep(1)


app = FastAPI(title="Crawl4AI API")

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000)

security = HTTPBearer()
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")


async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    if not CRAWL4AI_API_TOKEN:
        return credentials
    if credentials.credentials != CRAWL4AI_API_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid token")
    return credentials


def secure_endpoint():
    """Returns the security dependency only if CRAWL4AI_API_TOKEN is set."""
    return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
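# Authentication is opt-in: when the CRAWL4AI_API_TOKEN environment variable is
# set, the protected endpoints below expect a matching bearer token, e.g.
# (illustrative value):
#
#   Authorization: Bearer $CRAWL4AI_API_TOKEN
#
# When the variable is unset, the endpoints are left open.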


# Serve the generated MkDocs documentation if it has been built alongside the app.
if os.path.exists(os.path.join(__location__, "site")):
    app.mount(
        "/mkdocs",
        StaticFiles(directory=os.path.join(__location__, "site"), html=True),
        name="mkdocs",
    )
    site_templates = Jinja2Templates(directory=os.path.join(__location__, "site"))

crawler_service = CrawlerService()


@app.on_event("startup")
async def startup_event():
    await crawler_service.start()


@app.on_event("shutdown")
async def shutdown_event():
    await crawler_service.stop()


@app.get("/")
async def root():
    return RedirectResponse(url="/docs")


@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]:
    task_id = await crawler_service.submit_task(request)
    return {"task_id": task_id}
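# A typical client flow, sketched with the standard library only (URL and base
# address are illustrative; the polling interval mirrors /crawl_sync):
#
#   import json, time, urllib.request
#
#   def submit_and_wait(url="https://example.com", base="http://localhost:11235"):
#       req = urllib.request.Request(
#           f"{base}/crawl",
#           data=json.dumps({"urls": url}).encode(),
#           headers={"Content-Type": "application/json"},
#       )
#       task_id = json.load(urllib.request.urlopen(req))["task_id"]
#       while True:
#           status = json.load(urllib.request.urlopen(f"{base}/task/{task_id}"))
#           if status["status"] in ("completed", "failed"):
#               return status
#           time.sleep(1)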


@app.get(
    "/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def get_task_status(task_id: str):
    task_info = crawler_service.task_manager.get_task(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")

    response = {
        "status": task_info.status,
        "created_at": task_info.created_at,
    }

    if task_info.status == TaskStatus.COMPLETED:
        # Multi-URL tasks return a list of results, single-URL tasks a single one.
        if isinstance(task_info.result, list):
            response["results"] = [result.dict() for result in task_info.result]
        else:
            response["result"] = task_info.result.dict()
    elif task_info.status == TaskStatus.FAILED:
        response["error"] = task_info.error

    return response


@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
    task_id = await crawler_service.submit_task(request)

    # Poll for completion for up to 60 seconds.
    for _ in range(60):
        task_info = crawler_service.task_manager.get_task(task_id)
        if not task_info:
            raise HTTPException(status_code=404, detail="Task not found")

        if task_info.status == TaskStatus.COMPLETED:
            # Multi-URL tasks return a list of results, single-URL tasks a single one.
            if isinstance(task_info.result, list):
                return {
                    "status": task_info.status,
                    "results": [result.dict() for result in task_info.result],
                }
            return {"status": task_info.status, "result": task_info.result.dict()}

        if task_info.status == TaskStatus.FAILED:
            raise HTTPException(status_code=500, detail=task_info.error)

        await asyncio.sleep(1)

    # Give up if the task has not finished within the polling window.
    raise HTTPException(status_code=408, detail="Task timed out")


@app.post(
    "/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
    """Crawl synchronously in the request handler, bypassing the task queue."""
    logger.info("Received request to crawl directly.")
    try:
        logger.debug("Acquiring crawler from the crawler pool.")
        crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
        logger.debug("Crawler acquired successfully.")

        logger.debug("Creating extraction strategy based on the request configuration.")
        extraction_strategy = crawler_service._create_extraction_strategy(
            request.extraction_config
        )
        logger.debug("Extraction strategy created successfully.")

        try:
            if isinstance(request.urls, list):
                logger.info("Processing multiple URLs.")
                results = await crawler.arun_many(
                    urls=[str(url) for url in request.urls],
                    word_count_threshold=request.word_count_threshold,
                    extraction_strategy=extraction_strategy,
                    js_code=request.js_code,
                    wait_for=request.wait_for,
                    css_selector=request.css_selector,
                    screenshot=request.screenshot,
                    magic=request.magic,
                    cache_mode=request.cache_mode,
                    session_id=request.session_id,
                    **request.extra,
                )
                logger.info("Crawling completed for multiple URLs.")
                return {"results": [result.dict() for result in results]}
            else:
                logger.info("Processing a single URL.")
                result = await crawler.arun(
                    url=str(request.urls),
                    word_count_threshold=request.word_count_threshold,
                    extraction_strategy=extraction_strategy,
                    js_code=request.js_code,
                    wait_for=request.wait_for,
                    css_selector=request.css_selector,
                    screenshot=request.screenshot,
                    magic=request.magic,
                    cache_mode=request.cache_mode,
                    session_id=request.session_id,
                    **request.extra,
                )

                logger.info("Crawling completed for a single URL.")
                return {"result": result.dict()}
        finally:
            logger.debug("Releasing crawler back to the pool.")
            await crawler_service.crawler_pool.release(crawler)
            logger.debug("Crawler released successfully.")
    except Exception as e:
        logger.error(f"Error in direct crawl: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    available_slots = await crawler_service.resource_monitor.get_available_slots()
    memory = psutil.virtual_memory()
    return {
        "status": "healthy",
        "available_slots": available_slots,
        "memory_usage": memory.percent,
        "cpu_usage": psutil.cpu_percent(),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=11235)