Update main.py
Browse files
main.py
CHANGED
|
@@ -1,17 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import nest_asyncio
|
| 3 |
from crawl4ai import AsyncWebCrawler
|
| 4 |
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy, LLMExtractionStrategy
|
| 5 |
import json
|
| 6 |
import time
|
| 7 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
nest_asyncio.apply()
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
async def simple_crawl():
    """Crawl the hard-coded NBC News business page and return the crawl result.

    Prints the length of the page's markdown rendering as a quick sanity
    signal, then returns the full crawl4ai result object to the caller.
    """
    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(url="https://www.nbcnews.com/business")
        print(len(result.markdown))
        return result
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from fastapi import FastAPI, HTTPException
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
from typing import List, Optional
|
| 5 |
import asyncio
|
| 6 |
import nest_asyncio
|
| 7 |
from crawl4ai import AsyncWebCrawler
|
| 8 |
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy, LLMExtractionStrategy
|
| 9 |
import json
|
| 10 |
import time
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
# --- Module-level initialization (runs once at import time) ---

load_dotenv()  # Load environment variables (e.g. OPENAI_API_KEY) from a .env file

# FastAPI application instance that the route decorators below attach to.
app = FastAPI()

# Patch the current event loop so nested run_until_complete() calls
# (used internally by crawl4ai) work inside an already-running loop.
nest_asyncio.apply()
|
| 18 |
|
| 19 |
+
class CrawlerInput(BaseModel):
    """Request body for POST /crawl.

    `columns` and `descriptions` are parallel lists: one description per
    requested column. The endpoint rejects requests where their lengths
    differ.
    """

    url: str = Field(..., description="URL to crawl")
    columns: List[str] = Field(..., description="List of required columns")
    descriptions: List[str] = Field(..., description="Descriptions for each column")
|
| 23 |
+
|
| 24 |
+
class CrawlerOutput(BaseModel):
    """Response body for POST /crawl: one dict per extracted record."""

    data: List[dict]
|
| 26 |
+
|
| 27 |
async def simple_crawl(url: str = "https://www.nbcnews.com/business"):
    """Crawl *url* and return the crawl4ai result object.

    Backward compatible generalization: with no argument it crawls the
    original hard-coded NBC News business page.

    Args:
        url: Page to crawl; defaults to the demo NBC News business page.

    Returns:
        The crawl4ai result object (its ``markdown`` attribute holds the
        rendered page text).
    """
    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(url=url)
        # Quick sanity/progress signal: how much markdown was extracted.
        print(len(result.markdown))
        return result
|
| 32 |
+
|
| 33 |
+
@app.post("/crawl", response_model=CrawlerOutput)
async def crawl(input: CrawlerInput):
    """Crawl ``input.url`` and LLM-extract the requested columns.

    Args:
        input: URL plus parallel ``columns``/``descriptions`` lists that are
            folded into the LLM extraction instruction.

    Returns:
        CrawlerOutput whose ``data`` is the list of extracted records.

    Raises:
        HTTPException 400: ``columns`` and ``descriptions`` lengths differ.
        HTTPException 502: the crawl produced no parseable extraction output.
    """
    # Enforce the parallel-list invariant before doing any expensive work.
    if len(input.columns) != len(input.descriptions):
        raise HTTPException(status_code=400, detail="Number of columns must match number of descriptions")

    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(
            url=input.url,
            extraction_strategy=LLMExtractionStrategy(
                provider="openai/gpt-3.5-turbo",
                api_token=os.getenv('OPENAI_API_KEY'),
                extraction_type="schema",
                verbose=True,
                instruction=f"Extract the following information: {', '.join(input.columns)}. Descriptions: {', '.join(input.descriptions)}"
            )
        )

    # Defensive parse: extracted_content can be None or non-JSON when the
    # crawl or the LLM extraction fails; surface that as an explicit gateway
    # error instead of an opaque 500 from json.loads(None).
    if not result.extracted_content:
        raise HTTPException(status_code=502, detail="Crawler returned no extracted content")
    try:
        extracted_data = json.loads(result.extracted_content)
    except (json.JSONDecodeError, TypeError) as exc:
        raise HTTPException(status_code=502, detail="Extraction output was not valid JSON") from exc

    # response_model expects a list of dicts; wrap a single-object result so
    # validation does not reject an otherwise well-formed extraction.
    if isinstance(extracted_data, dict):
        extracted_data = [extracted_data]
    return CrawlerOutput(data=extracted_data)
|
| 52 |
+
|
| 53 |
+
@app.get("/test")
async def test():
    """Smoke-test endpoint: crawl the fixed demo page and return its markdown."""
    crawl_result = await simple_crawl()
    markdown_text = crawl_result.markdown
    return {"markdown": markdown_text}
|
| 57 |
+
|
| 58 |
+
# Script entry point: serve the API on all interfaces, port 8000.
if __name__ == "__main__":
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|