Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from engine import run_from_api_engine | |
| from ingest import run_from_api_ingest | |
| import uvicorn | |
| import warnings | |
| from typing import Annotated | |
| import json | |
| import argparse | |
| from dotenv import load_dotenv | |
| import box | |
| import yaml | |
| import os | |
| from rich import print | |
| warnings.filterwarnings("ignore", category=DeprecationWarning) | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| # Import config vars | |
| with open('config.yml', 'r', encoding='utf8') as ymlfile: | |
| cfg = box.Box(yaml.safe_load(ymlfile)) | |
| # add asyncio to the pipeline | |
| app = FastAPI(openapi_url="/api/v1/sparrow-llm/openapi.json", docs_url="/api/v1/sparrow-llm/docs") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| allow_credentials=True | |
| ) | |
| def root(): | |
| return {"message": "Sparrow LLM API"} | |
| async def inference( | |
| fields: Annotated[str, Form()], | |
| agent: Annotated[str, Form()], | |
| types: Annotated[str, Form()] = None, | |
| keywords: Annotated[str, Form()] = None, | |
| index_name: Annotated[str, Form()] = None, | |
| options: Annotated[str, Form()] = None, | |
| group_by_rows: Annotated[bool, Form()] = True, | |
| update_targets: Annotated[bool, Form()] = True, | |
| debug: Annotated[bool, Form()] = False, | |
| sparrow_key: Annotated[str, Form()] = None, | |
| file: UploadFile = File(None) | |
| ): | |
| protected_access = cfg.PROTECTED_ACCESS | |
| if protected_access: | |
| # Retrieve all environment variables that start with 'SPARROW_KEY_' | |
| sparrow_keys = {key: value for key, value in os.environ.items() if key.startswith('SPARROW_KEY_')} | |
| # Check if the provided sparrow_key matches any of the environment variables | |
| if sparrow_key not in sparrow_keys.values(): | |
| raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.") | |
| query = 'retrieve ' + fields | |
| query_types = types | |
| query_inputs_arr = [param.strip() for param in fields.split(',')] if query_types else [] | |
| query_types_arr = [param.strip() for param in query_types.split(',')] if query_types else [] | |
| keywords_arr = [param.strip() for param in keywords.split(',')] if keywords is not None else None | |
| options_arr = [param.strip() for param in options.split(',')] if options is not None else None | |
| if not query_types: | |
| query = fields | |
| try: | |
| answer = await run_from_api_engine(agent, query_inputs_arr, query_types_arr, keywords_arr, query, index_name, | |
| options_arr, file, group_by_rows, update_targets, debug) | |
| except ValueError as e: | |
| raise HTTPException(status_code=418, detail=str(e)) | |
| try: | |
| if isinstance(answer, (str, bytes, bytearray)): | |
| answer = json.loads(answer) | |
| except json.JSONDecodeError as e: | |
| raise HTTPException(status_code=418, detail=answer) | |
| if debug: | |
| print(f"\nJSON response:\n") | |
| print(answer) | |
| return {"message": answer} | |
| async def ingest( | |
| agent: Annotated[str, Form()], | |
| index_name: Annotated[str, Form()], | |
| sparrow_key: Annotated[str, Form()] = None, | |
| file: UploadFile = File() | |
| ): | |
| protected_access = cfg.PROTECTED_ACCESS | |
| if protected_access: | |
| # Retrieve all environment variables that start with 'SPARROW_KEY_' | |
| sparrow_keys = {key: value for key, value in os.environ.items() if key.startswith('SPARROW_KEY_')} | |
| # Check if the provided sparrow_key matches any of the environment variables | |
| if sparrow_key not in sparrow_keys.values(): | |
| raise HTTPException(status_code=403, detail="Protected access. Agent not allowed.") | |
| try: | |
| answer = await run_from_api_ingest(agent, index_name, file, False) | |
| except ValueError as e: | |
| raise HTTPException(status_code=418, detail=str(e)) | |
| if isinstance(answer, (str, bytes, bytearray)): | |
| answer = json.loads(answer) | |
| return {"message": answer} | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Run FastAPI App") | |
| parser.add_argument("-p", "--port", type=int, default=8000, help="Port to run the FastAPI app on") | |
| args = parser.parse_args() | |
| uvicorn.run("api:app", host="0.0.0.0", port=args.port, reload=True) | |
| # run the app with: python api.py --port 8000 | |
| # go to http://127.0.0.1:8000/api/v1/sparrow-llm/docs to see the Swagger UI | |