import aiofiles
import asyncio
from datetime import datetime
import json
import os
from threading import Thread
import time
from typing import Dict, List, Any
import uuid
from flask import Blueprint, request
from server.constant.constants import (MAX_LOCAL_FILE_BATCH_LENGTH,
                                       MAX_FILE_SIZE, LOCAL_FILE_DOWNLOAD_DIR,
                                       STATIC_DIR, FILE_LOADER_EXTENSIONS,
                                       MAX_CONCURRENT_WRITES,
                                       LOCAL_FILE_PROCESS_FAILED)
from server.app.utils.decorators import token_required
from server.app.utils.sqlite_client import get_db_connection
from server.app.utils.diskcache_lock import diskcache_lock
from server.app.utils.hash import generate_md5
from server.logger.logger_config import my_logger as logger
from server.rag.index.parser.file_loader.csv_loader import AsyncCsvLoader
from server.rag.index.parser.file_loader.docx_loader import AsyncDocxLoader
from server.rag.index.parser.file_loader.epub_loader import AsyncEpubLoader
from server.rag.index.parser.file_loader.html_loader import AsyncHtmlLoader
from server.rag.index.parser.file_loader.md_loader import AsyncMdLoader
from server.rag.index.parser.file_loader.mobi_loader import AsyncMobiLoader
from server.rag.index.parser.file_loader.pdf_loader import AsyncPdfLoader
from server.rag.index.parser.file_loader.pptx_loader import AsyncPptxLoader
from server.rag.index.parser.file_loader.txt_loader import AsyncTxtLoader
from server.rag.index.parser.file_loader.xlsx_loader import AsyncXlsxLoader
from server.rag.index.parser.file_parser.markdown_parser import AsyncTextParser

URL_PREFIX = os.getenv('URL_PREFIX')

files_bp = Blueprint('files', __name__, url_prefix='/open_kf_api/files')
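
# A minimal registration sketch: these endpoints only become reachable once the
# blueprint is attached to a Flask app elsewhere in the project. The import path
# and app object below are assumptions for illustration, not part of this module:
#
#     from flask import Flask
#     from server.app.files import files_bp
#
#     app = Flask(__name__)
#     app.register_blueprint(files_bp)  # serves /open_kf_api/files/...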


async def write_file_async(file_path: str, content: bytes) -> None:
    """Asynchronously write content to a file."""
    async with aiofiles.open(file_path, 'wb') as f:
        await f.write(content)


async def parse_file_content_async(file_path: str, file_extension: str,
                                   file_md5: str,
                                   id_url_info: Dict[str, Any]) -> None:
    file_loader_obj = None
    if file_extension == ".csv":
        file_loader_obj = AsyncCsvLoader(file_path=file_path)
    elif file_extension == ".docx":
        file_loader_obj = AsyncDocxLoader(file_path=file_path)
    elif file_extension == ".epub":
        file_loader_obj = AsyncEpubLoader(file_path=file_path)
    elif file_extension == ".html":
        file_loader_obj = AsyncHtmlLoader(file_path=file_path)
    elif file_extension == ".md":
        file_loader_obj = AsyncMdLoader(file_path=file_path)
    elif file_extension == ".mobi":
        file_loader_obj = AsyncMobiLoader(file_path=file_path)
    elif file_extension == ".pdf":
        file_loader_obj = AsyncPdfLoader(file_path=file_path)
    elif file_extension == ".pptx":
        file_loader_obj = AsyncPptxLoader(file_path=file_path)
    elif file_extension == ".txt":
        file_loader_obj = AsyncTxtLoader(file_path=file_path)
    elif file_extension == ".xlsx":
        file_loader_obj = AsyncXlsxLoader(file_path=file_path)

    if file_loader_obj:
        content = await file_loader_obj.get_content()
        doc_id = id_url_info[file_md5]["id"]
        url = id_url_info[file_md5]["url"]
        text_parser_obj = AsyncTextParser()
        if content:
            await text_parser_obj.add_content(doc_id=doc_id,
                                              content=content,
                                              url=url)
        else:
            # if os.path.exists(file_path):
            #     os.remove(file_path)
            conn = None
            try:
                conn = get_db_connection()
                cur = conn.cursor()
                try:
                    with diskcache_lock.lock():
                        # Mark the record as failed and clear its content_md5.
                        cur.execute(
                            "UPDATE t_local_file_tab SET doc_status = ?, content_md5 = '' WHERE id = ?",
                            (LOCAL_FILE_PROCESS_FAILED, doc_id))
                        conn.commit()
                except Exception as e:
                    logger.error(f"Process diskcache_lock exception: {e}")
            except Exception as e:
                logger.error(f"An error occurred: {e}")
            finally:
                if conn:
                    conn.close()
    else:
        logger.error(f"file_extension: '{file_extension}' is not supported!")


async def add_files_limited_by_semaphore(
        file_data: List[Dict[str, str]],
        id_url_info: Dict[str, Any],
        max_concurrent_writes: int = 5) -> None:
    """Save files with limited concurrency using a semaphore."""
    semaphore = asyncio.Semaphore(max_concurrent_writes)

    async def semaphore_write(data, id_url_info):
        async with semaphore:
            await write_file_async(data['file_path'], data['content'])

    # Run all writes concurrently; the semaphore caps how many are in flight at once.
    await asyncio.gather(
        *(semaphore_write(data, id_url_info) for data in file_data))

    for data in file_data:
        await parse_file_content_async(data['file_path'],
                                       data['file_extension'],
                                       data['file_md5'], id_url_info)


def add_local_file_info(file_data: List[Dict[str, str]],
                        id_url_info: Dict[str, Any],
                        max_concurrent_writes: int = 5):
    """Synchronously calls the asynchronous save function to write files with limited concurrency."""
    logger.info(
        f"[DOWNLOAD FILE] add_local_file_info begin, id_url_info: {id_url_info}"
    )
    beg_time = int(time.time())
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(
        add_files_limited_by_semaphore(file_data, id_url_info,
                                       max_concurrent_writes))
    loop.close()
    timecost = int(time.time()) - beg_time
    logger.warning(
        f"[DOWNLOAD FILE] add_local_file_info end, timecost is {timecost}")


@files_bp.route('/submit_local_file_list', methods=['POST'])
@token_required
def submit_local_file_list():
    if 'file_list' not in request.files:
        return {'retcode': -20000, 'message': 'file_list is required'}

    file_list = request.files.getlist('file_list')
    if len(file_list) > MAX_LOCAL_FILE_BATCH_LENGTH:
        logger.error(
            f'Too many files uploaded, the maximum is {MAX_LOCAL_FILE_BATCH_LENGTH}!'
        )
        return {
            'retcode': -20001,
            'message':
            f'Too many files uploaded, the maximum is {MAX_LOCAL_FILE_BATCH_LENGTH}!'
        }

    file_data = []
    md5_set = set()
    conn = None
    try:
        for file_ in file_list:
            file_.seek(0, os.SEEK_END)
            file_size = file_.tell()
            file_.seek(0)
            if file_size > MAX_FILE_SIZE:
                logger.error(
                    f'File {file_.filename} exceeds the size limit of {MAX_FILE_SIZE} bytes!'
                )
                return {
                    'retcode': -20002,
                    'message':
                    f'File {file_.filename} exceeds the size limit of {MAX_FILE_SIZE} bytes!'
                }
            if file_size == 0:
                logger.error(f'File {file_.filename} is empty!')
                return {
                    'retcode': -20002,
                    'message': f'File {file_.filename} is empty!'
                }

            _, file_extension = os.path.splitext(file_.filename)
            if file_extension.lower() not in FILE_LOADER_EXTENSIONS:
                logger.error(
                    f"Unsupported file extension '{file_extension}' for {file_.filename}"
                )
                return {
                    'retcode': -20003,
                    'message':
                    f"Unsupported file extension '{file_extension}' for {file_.filename}"
                }

            file_content = file_.read()
            file_md5 = generate_md5(file_content)
            if file_md5 in md5_set:
                logger.error(
                    f'Local duplicate file detected: {file_.filename}')
                return {
                    'retcode': -20004,
                    'message':
                    f'Local duplicate file detected: {file_.filename}'
                }
            md5_set.add(file_md5)

            day_folder = datetime.now().strftime("%Y_%m_%d")
            unique_folder = str(uuid.uuid4())
            save_directory = os.path.join(STATIC_DIR, LOCAL_FILE_DOWNLOAD_DIR,
                                          day_folder, unique_folder)
            os.makedirs(save_directory, exist_ok=True)
            file_path = os.path.join(save_directory, file_.filename)
            file_url = f"{URL_PREFIX}{STATIC_DIR}/{LOCAL_FILE_DOWNLOAD_DIR}/{day_folder}/{unique_folder}/{file_.filename}"
            file_data.append({
                'filename': file_.filename,
                'file_extension': file_extension.lower(),
                'content': file_content,
                'file_md5': file_md5,
                'file_path': file_path,
                'file_url': file_url,
                'file_size': file_size
            })

        conn = get_db_connection()
        cur = conn.cursor()
        placeholders = ', '.join(['?'] * len(md5_set))
        cur.execute(
            f"SELECT content_md5 FROM t_local_file_tab WHERE doc_status = 4 and content_md5 IN ({placeholders})",
            tuple(md5_set))
        existing_md5 = {row[0] for row in cur.fetchall()}
        duplicate_files = [
            data['filename'] for data in file_data
            if data['file_md5'] in existing_md5
        ]
        if duplicate_files:
            logger.error(
                f'Duplicate files found: {", ".join(duplicate_files)}')
            return {
                'retcode': -20005,
                'message':
                f'Duplicate files found: {", ".join(duplicate_files)}'
            }

        insert_data = []
        timestamp = int(time.time())
        for data in file_data:
            insert_data.append(
                (data['file_url'], data['filename'], data['file_path'],
                 data['file_extension'], data['file_size'], data['file_md5'],
                 1, timestamp, timestamp))
        try:
            with diskcache_lock.lock():
                cur.executemany(
                    '''
                    INSERT INTO t_local_file_tab (url, origin_file_name, file_path, file_type, content_length, content_md5, doc_status, ctime, mtime)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ''', insert_data)
                conn.commit()
        except Exception as e:
            logger.error(f"Process diskcache_lock exception: {e}")
            return {
                'retcode': -30000,
                'message': f'An error occurred: {e}',
                'data': {}
            }
        cur.execute(
            f"SELECT id, url, content_md5 FROM t_local_file_tab WHERE content_md5 IN ({placeholders})",
            tuple(md5_set))
        rows = cur.fetchall()
        inserted_ids = [row["id"] for row in rows]
        id_url_info = {}
        for row in rows:
            id_url_info[row["content_md5"]] = {
                "id": row["id"],
                "url": row["url"]
            }

        # Write the files to disk and index their content in a background thread so
        # the request can return immediately.
        Thread(target=add_local_file_info,
               args=(file_data, id_url_info, MAX_CONCURRENT_WRITES)).start()

        return {
            'retcode': 0,
            'message': 'Files uploaded and metadata saved successfully',
            'data': {
                'file_id_list': inserted_ids
            }
        }
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return {
            'retcode': -30000,
            'message': f'An error occurred: {e}',
            'data': {}
        }
    finally:
        if conn:
            conn.close()
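
# A hedged usage sketch for the upload endpoint, using the `requests` library. The
# route path mirrors the function name as registered above; the host, port, and
# bearer-token header are assumptions about the surrounding deployment and the
# `token_required` scheme, not defined in this module:
#
#     import requests
#     resp = requests.post(
#         "http://<host>:<port>/open_kf_api/files/submit_local_file_list",
#         headers={"Authorization": "Bearer <token>"},
#         files=[("file_list", open("./docs/guide.pdf", "rb")),
#                ("file_list", open("./notes.md", "rb"))])
#     print(resp.json())  # e.g. {"retcode": 0, ..., "data": {"file_id_list": [...]}}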


@files_bp.route('/get_local_file_list', methods=['POST'])
@token_required
def get_local_file_list():
    data = request.json
    file_id_list = data.get('id_list', None)  # id_list is an optional parameter
    conn = None
    try:
        conn = get_db_connection()
        cur = conn.cursor()
        if file_id_list:
            placeholders = ', '.join(['?'] * len(file_id_list))
            cur.execute(
                f"SELECT id, url, origin_file_name, file_type, content_length, doc_status, ctime, mtime FROM t_local_file_tab WHERE id IN ({placeholders})",
                file_id_list)
        else:
            cur.execute(
                "SELECT id, url, origin_file_name, file_type, content_length, doc_status, ctime, mtime FROM t_local_file_tab"
            )
        rows = cur.fetchall()

        response_data = {}
        response_data['file_list'] = [dict(row) for row in rows]
        return {'retcode': 0, 'message': 'Success', 'data': response_data}
    except Exception as e:
        logger.error(f"An error occurred while fetching the local file list: {e}")
        return {
            'retcode': -30000,
            'message': f'An error occurred: {e}',
            'data': {}
        }
    finally:
        if conn:
            conn.close()
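
# A hedged usage sketch for listing file metadata. `id_list` is optional; when it is
# omitted the whole t_local_file_tab table is returned (host, port, and token header
# are assumptions, as above):
#
#     import requests
#     resp = requests.post(
#         "http://<host>:<port>/open_kf_api/files/get_local_file_list",
#         headers={"Authorization": "Bearer <token>"},
#         json={"id_list": [1, 2]})
#     print(resp.json()["data"]["file_list"])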


async def delete_local_file_info_async(file_dict: Dict[int, str]) -> None:
    text_parser = AsyncTextParser()
    for doc_id in file_dict:
        await text_parser.delete_content(doc_id)
        file_path = file_dict[doc_id]
        if os.path.exists(file_path):
            os.remove(file_path)


def delete_local_file_info(file_dict: Dict[int, str]) -> None:
    logger.info(
        f"[DOWNLOAD FILE] delete_local_file_info beg, file_dict: {file_dict}")
    beg_time = int(time.time())
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(delete_local_file_info_async(file_dict))
    loop.close()
    timecost = int(time.time()) - beg_time
    logger.warning(
        f"[DOWNLOAD FILE] delete_local_file_info end, timecost is {timecost}")


@files_bp.route('/delete_local_file_list', methods=['POST'])
@token_required
def delete_local_file_list():
    data = request.json
    file_id_list = data.get('id_list')
    if not file_id_list:
        return {
            'retcode': -20000,
            'message': 'id_list is required',
            'data': {}
        }

    conn = None
    try:
        conn = get_db_connection()
        cur = conn.cursor()
        placeholders = ', '.join(['?'] * len(file_id_list))
        cur.execute(
            f"SELECT id, file_path FROM t_local_file_tab WHERE id IN ({placeholders})",
            file_id_list)
        file_dict = {row['id']: row['file_path'] for row in cur.fetchall()}

        # Use threading to avoid blocking the Flask application
        Thread(target=delete_local_file_info, args=(file_dict, )).start()

        return {
            'retcode': 0,
            'message': 'Started deleting the local file list embeddings.',
            'data': {}
        }
    except Exception as e:
        logger.error(f"An error occurred while deleting local file list: {e}")
        return {
            'retcode': -30000,
            'message': f'An error occurred: {e}',
            'data': {}
        }
    finally:
        if conn:
            conn.close()
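
# A hedged usage sketch for the delete endpoint (host, port, and token header are
# assumptions, as above). Note that this handler removes embeddings and files on
# disk in a background thread; it does not delete the t_local_file_tab rows:
#
#     import requests
#     resp = requests.post(
#         "http://<host>:<port>/open_kf_api/files/delete_local_file_list",
#         headers={"Authorization": "Bearer <token>"},
#         json={"id_list": [1, 2]})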


@files_bp.route('/get_local_file_sub_content_list', methods=['POST'])
@token_required
def get_local_file_sub_content_list():
    data = request.json
    file_id = data.get('id')
    page = data.get('page')
    page_size = data.get('page_size')

    # Validate mandatory parameters
    if None in (file_id, page, page_size):
        return {
            'retcode': -20000,
            'message': 'Missing mandatory parameters',
            'data': {}
        }

    if not isinstance(page, int) or not isinstance(
            page_size, int) or page < 1 or page_size < 1:
        return {
            'retcode': -20001,
            'message': 'Invalid page or page_size parameters',
            'data': {}
        }

    conn = None
    try:
        conn = get_db_connection()
        cur = conn.cursor()

        # Calculate total count
        cur.execute(
            'SELECT COUNT(*) FROM t_local_file_chunk_tab WHERE file_id = ?',
            (file_id, ))
        total_count = cur.fetchone()[0]

        # Calculate the starting point for the query
        start = (page - 1) * page_size

        # Retrieve the specified page of records
        cur.execute(
            '''
            SELECT chunk_index as "index", content, content_length
            FROM t_local_file_chunk_tab
            WHERE file_id = ?
            ORDER BY chunk_index
            LIMIT ? OFFSET ?''', (file_id, page_size, start))
        rows = cur.fetchall()

        # Convert rows to dictionaries
        record_list = [dict(row) for row in rows]

        return {
            "retcode": 0,
            "message": "success",
            "data": {
                "total_count": total_count,
                "sub_content_list": record_list
            }
        }
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return {'retcode': -30001, 'message': 'Database exception', 'data': {}}
    finally:
        if conn:
            conn.close()
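
# A hedged usage sketch for paging through the parsed chunks of a single file.
# `page` is 1-based and the query offset is computed as (page - 1) * page_size
# (host, port, and token header are assumptions, as above):
#
#     import requests
#     resp = requests.post(
#         "http://<host>:<port>/open_kf_api/files/get_local_file_sub_content_list",
#         headers={"Authorization": "Bearer <token>"},
#         json={"id": 1, "page": 1, "page_size": 10})
#     print(resp.json()["data"]["total_count"])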