# Spaces: Runtime error
| import fastapi as api | |
| from typing import Annotated | |
| from fastapi.security import OAuth2PasswordBearer, OAuth2AuthorizationCodeBearer, OAuth2PasswordRequestForm | |
| from model.document import Document, PlainTextDocument, JsonDocument | |
| import sys | |
| from model.user import User | |
| from fastapi import FastAPI, File, UploadFile | |
| from di import initialize_di_for_app | |
| import gradio as gr | |
| import os | |
| import json | |
# Resolve the shared services (settings, storage, embedding, index) through
# the project's DI helper.
SETTINGS, STORAGE, EMBEDDING, INDEX = initialize_di_for_app()

# Load the single persisted user record. exit_event writes it back with
# pydantic v2's USER.model_dump_json(), so read it with the matching v2 API
# (`parse_raw` is the deprecated v1-style spelling).
user_json_str = STORAGE.load('user.json')
USER = User.model_validate_json(user_json_str)

# Bearer-token scheme for endpoints that depend on get_current_user.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token")

app = api.FastAPI()
# Force the emitted OpenAPI version string to 3.0.0.
# NOTE(review): presumably for consumers that only accept OpenAPI 3.0 — confirm.
app.openapi_version = "3.0.0"

# Users known to get_current_user; currently only the loaded USER.
users = [USER]
async def get_current_user(token: str = api.Depends(oauth2_scheme)):
    '''
    Get current user

    Resolve a bearer token to a user in `users`. The token issued by `login`
    is the user name itself, so resolution is a plain name comparison.

    Raises:
        HTTPException: 401 when no user matches the token. The response
            carries `WWW-Authenticate: Bearer`, as a 401 for the bearer
            scheme must.
    '''
    # NOTE(review): the token is just the user name and therefore not a secret.
    for user in users:
        if user.user_name == token:
            return user
    raise api.HTTPException(
        status_code=401,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
async def login(form_data: Annotated[OAuth2PasswordRequestForm, api.Depends()]):
    '''
    Login to get a token
    '''
    # The access token is the submitted user name; get_current_user matches
    # it back against `users`.
    # NOTE(review): form_data.password is never checked, so anyone can obtain
    # a token for any user name — add real credential verification.
    # `token_type` is a required member of an OAuth2 token response
    # (RFC 6749 §5.1).
    return {"access_token": form_data.username, "token_type": "bearer"}
def create_upload_file(file: UploadFile = api.File(...)) -> Document:
    '''
    Upload a file

    Store an uploaded .txt file for USER under '<user_name>-<file name>' and
    wrap it in a PlainTextDocument with status 'uploading'.

    Raises:
        HTTPException: 400 if the file is not a .txt file or is not valid
            UTF-8 text.
    '''
    # The filename is client-controlled and becomes part of the storage key;
    # keep only its final path component. It may also be missing (None).
    file_name = os.path.basename(file.filename or '')
    # Validate first so unsupported uploads never reach storage.
    if not file_name.endswith('.txt'):
        raise api.HTTPException(status_code=400, detail="File type not supported")
    # UploadFile.read() is a coroutine and cannot be awaited from this sync
    # function; read the underlying file object synchronously instead.
    # Decode to str, matching how gradio_upload_document stores text files.
    try:
        content = file.file.read().decode('utf-8')
    except UnicodeDecodeError as err:
        raise api.HTTPException(status_code=400, detail="File must be UTF-8 encoded text") from err
    fileUrl = f'{USER.user_name}-{file_name}'
    STORAGE.save(fileUrl, content)
    return PlainTextDocument(
        name=file_name,
        status='uploading',
        url=fileUrl,
        embedding=EMBEDDING,
        storage=STORAGE,
    )
### /api/v1/.well-known
#### Get /openapi.json
# Get the openapi json file
async def get_openapi():
    '''
    Return a trimmed OpenAPI document that only describes /api/v1/search.

    The title is replaced with 'DocumentSearch', and USER's documents (as
    [name, description] pairs) are listed in both the top-level description
    and the description of GET /api/v1/search. Every other path and all
    components are removed.
    '''
    import copy

    # get a list of document names + description
    document_list = [[doc.name, doc.description] for doc in USER.documents]
    # app.openapi() returns FastAPI's cached schema dict. A shallow .copy()
    # shares the nested 'info' and 'paths' dicts, so the edits below would
    # mutate the cache and the search description would grow on every
    # request. Work on an independent deep copy instead.
    openapi = copy.deepcopy(app.openapi())
    openapi['info']['title'] = 'DocumentSearch'
    description = f'''Search documents with a query.
## Documents
{document_list}
'''
    openapi['info']['description'] = description
    # update description in /api/v1/search (tolerate a route without one)
    search_get = openapi['paths']['/api/v1/search']['get']
    search_get['description'] = search_get.get('description', '') + f'''
Available documents:
{document_list}
'''
    # filter out unnecessary endpoints
    openapi['paths'] = {
        '/api/v1/search': openapi['paths']['/api/v1/search'],
    }
    # remove components
    openapi['components'] = {}
    return openapi
### /api/v1/document
#### Get /list
# Get the list of documents
# async def get_document_list(user: Annotated[User, api.Depends(get_current_user)]) -> list[Document]:
async def get_document_list() -> list[Document]:
    '''
    Get the list of documents
    '''
    # Hands back USER's own list object (not a copy).
    registered = USER.documents
    return registered
#### Post /upload
# Upload a document
# def upload_document(user: Annotated[User, api.Depends(get_current_user)], document: Annotated[Document, api.Depends(create_upload_file)]):
def upload_document(document: Annotated[Document, api.Depends(create_upload_file)]):
    '''
    Upload a document
    '''
    # `document` comes from create_upload_file (already saved to storage).
    # Index it for USER, then mark it done and register it.
    document.status = 'processing'
    # Same two-argument call as the Gradio upload path; no progress reporter
    # exists in this request context (the name `progress` was undefined here).
    INDEX.load_or_update_document(USER, document)
    document.status = 'done'
    USER.documents.append(document)
#### Get /delete
# Delete a document
# async def delete_document(user: Annotated[User, api.Depends(get_current_user)], document_name: str):
async def delete_document(document_name: str):
    '''
    Delete a document
    '''
    # Find the first document with the requested name, or fail with 404.
    target = next((doc for doc in USER.documents if doc.name == document_name), None)
    if target is None:
        raise api.HTTPException(status_code=404, detail="Document not found")
    # Remove the stored file, then the index entry, then the registration.
    STORAGE.delete(target.url)
    INDEX.remove_document(USER, target)
    USER.documents.remove(target)
# Query the index
def search(
    # user: Annotated[User, api.Depends(get_current_user)],
    query: str,
    document_name: str = None,
    top_k: int = 10,
    threshold: float = 0.5):
    '''
    Search documents with a query. It will return [top_k] results with a score higher than [threshold].
    query: the query string, required
    document_name: the document name, optional. You can provide this parameter to search in a specific document.
    top_k: the number of results to return, optional. Default to 10.
    threshold: the threshold of the results, optional. Default to 0.5.
    '''
    # The docstring above is kept verbatim: get_openapi() appends to this
    # endpoint's published description (presumably derived from it by FastAPI).
    # No (or empty) document name: search across the whole index.
    if not document_name:
        return INDEX.query_index(USER, query, top_k, threshold)
    # Otherwise restrict the search to the first document with that name.
    selected = next((doc for doc in USER.documents if doc.name == document_name), None)
    if selected is None:
        raise api.HTTPException(status_code=404, detail="Document not found")
    return INDEX.query_document(USER, selected, query, top_k, threshold)
def receive_signal(signalNumber, frame):
    """Signal handler: print which signal arrived, then exit via sys.exit()."""
    # `frame` is required by the signal-handler signature but unused here.
    report = ('Received:', signalNumber)
    print(*report)
    sys.exit()
async def startup_event():
    """Startup hook: install receive_signal as the SIGINT handler."""
    from signal import SIGINT, signal as install_handler
    install_handler(SIGINT, receive_signal)
    # startup tasks
def exit_event():
    """Shutdown hook: persist USER as JSON to 'user.json', then print 'exit'."""
    serialized_user = USER.model_dump_json()
    STORAGE.save('user.json', serialized_user)
    print('exit')
# Module-level alias for USER (same object); the Gradio upload handler below
# refers to the user by this name.
user = USER
def gradio_upload_document(file):
    '''
    Gradio handler: store an uploaded .txt or .json file for USER, index it,
    and register it in USER.documents.

    `file` is the object Gradio passes for a gr.File input; only its `.name`
    (the temporary file path) is used.

    Returns:
        The status message 'uploaded <file name>'.

    Raises:
        gr.Error: for unsupported file types, so Gradio shows the message to
            the user.
    '''
    file_temp_path = file.name
    file_name = os.path.basename(file_temp_path)
    # Choose the document class before touching storage, so unsupported
    # files are rejected without being saved.
    document_classes = (('.txt', PlainTextDocument), ('.json', JsonDocument))
    document_cls = next(
        (cls for suffix, cls in document_classes if file_name.endswith(suffix)),
        None,
    )
    if document_cls is None:
        raise gr.Error("File type not supported")
    # load file into storage under '<user_name>-<file name>'
    fileUrl = f'{USER.user_name}-{file_name}'
    with open(file_temp_path, 'r', encoding='utf-8') as f:
        STORAGE.save(fileUrl, f.read())
    doc = document_cls(
        name=file_name,
        status='uploading',
        url=fileUrl,
        embedding=EMBEDDING,
        storage=STORAGE,
    )
    doc.status = 'processing'
    INDEX.load_or_update_document(user, doc)
    doc.status = 'done'
    USER.documents.append(doc)
    return f'uploaded {file_name}'
def gradio_query(query: str, document_name: str = None, top_k: int = 10, threshold: float = 0.5):
    '''
    Gradio handler for the search tab: run `search` and return the hits as
    a JSON string (indent=4).

    Raises:
        gr.Error: when `search` rejects the request (e.g. unknown document),
            so the reason is shown in the UI.
    '''
    # gr.Number values may arrive as floats (e.g. 10.0) or as None when the
    # box is cleared; normalise them to the documented types and defaults.
    top_k = 10 if top_k is None else int(top_k)
    threshold = 0.5 if threshold is None else float(threshold)
    try:
        results = search(query, document_name, top_k, threshold)
    except api.HTTPException as err:
        # search raises (never returns) HTTPException, e.g. 404.
        raise gr.Error(str(err.detail)) from err
    # convert to json string
    records = [record.model_dump(mode='json') for record in results]
    return json.dumps(records, indent=4)
# Gradio UI: title, user / OpenAPI info, and a search tab wired to gradio_query.
with gr.Blocks() as ui:
    # A Markdown heading needs a space after '#'; "#llm-memory" rendered as
    # plain text rather than a title.
    gr.Markdown("# llm-memory")
    with gr.Column():
        gr.Markdown(
"""
## LLM Memory
""")
        with gr.Row():
            user_name = gr.Label(label="User name", value=USER.user_name)
            # url to .well-known/openapi.json
            gr.Label(label=".well-known/openapi.json", value="/api/v1/.well-known/openapi.json")
        # Disabled tabs (document list and uploads), kept for reference.
        # with gr.Tab("available documents"):
        #     available_documents = gr.Label(label="available documents", value="available documents")
        #     refresh_btn = gr.Button(label="refresh", type="button")
        #     refresh_btn.click(lambda: '\r\n'.join([doc.name for doc in USER.documents]), None, available_documents)
        #     documents = USER.documents
        #     for document in documents:
        #         gr.Label(label=document.name, value=document.name)
        # with gr.Tab("upload document"):
        #     with gr.Tab("upload .txt document"):
        #         file = gr.File(label="upload document", type="file", file_types=[".txt"])
        #         output = gr.Label(label="output", value="output")
        #         upload_btn = gr.Button("upload document", type="button")
        #         upload_btn.click(gradio_upload_document, file, output)
        #     with gr.Tab("upload .json document"):
        #         gr.Markdown(
        # """
        # The json document should be a list of objects, each object should have a `content` field. If you want to add more fields, you can add them in the `meta_data` field.
        # For example:
        # ```json
        # [
        #     {
        #         "content": "hello world",
        #         "meta_data": {
        #             "title": "hello world",
        #             "author": "llm-memory"
        #         }
        #     },
        #     {
        #         "content": "hello world",
        #         "meta_data": {
        #             "title": "hello world",
        #             "author": "llm-memory"
        #         }
        #     }
        # ]
        # ```
        # ## Note
        # - The `meta_data` should be a dict which both keys and values are strings.
        # """)
        #         file = gr.File(label="upload document", type="file", file_types=[".json"])
        #         output = gr.Label(label="output", value="output")
        #         upload_btn = gr.Button("upload document", type="button")
        #         upload_btn.click(gradio_upload_document, file, output)
        with gr.Tab("search"):
            query = gr.Textbox(label="search", placeholder="Query")
            # Choices are computed once, when the UI is built; nothing here
            # refreshes them after later uploads or deletions.
            document = gr.Dropdown(label="document", choices=[None] + [doc.name for doc in USER.documents], placeholder="document, optional")
            top_k = gr.Number(label="top_k", placeholder="top_k, optional", value=10)
            threshold = gr.Number(label="threshold", placeholder="threshold, optional", value=0.5)
            output = gr.Code(label="output", language="json", value="output")
            query_btn = gr.Button("Query")
            # Also exposed as the named Gradio API endpoint "search".
            query_btn.click(gradio_query, [query, document, top_k, threshold], output, api_name="search")
# Wrap the Blocks UI as an ASGI app and mount it at "/" on the FastAPI app.
# NOTE(review): gr.routes.App.create_app appears to be an internal Gradio
# API; the public helper is gr.mount_gradio_app(app, ui, path="/") — confirm
# against the pinned Gradio version.
gradio_app = gr.routes.App.create_app(ui)
app.mount("/", gradio_app)
# NOTE(review): ui.launch() presumably starts Gradio's own server for `ui`
# alone, so the FastAPI routes on `app` would not be served by it, and the
# call likely blocks at import time if `app` is meant to be served elsewhere
# (e.g. by uvicorn) — confirm the intended entry point.
ui.launch()