Spaces:
Sleeping
Sleeping
| from fastapi import UploadFile, File, Form, HTTPException, APIRouter | |
| from typing import List, Optional, Dict, Tuple | |
| import lancedb | |
| from lancedb.pydantic import LanceModel, Vector | |
| from lancedb.embeddings import get_registry | |
| import pandas as pd | |
| from utils import process_pdf_to_chunks | |
| import hashlib | |
| import uuid | |
| import json | |
| from datetime import datetime | |
| from pydantic import BaseModel | |
| import logging | |
# Every endpoint in this module is mounted under "/rag" and grouped under the
# "rag" tag in the generated OpenAPI docs.
router = APIRouter(tags=["rag"], prefix="/rag")

# File-backed LanceDB database; each user collection is one table in it.
db = lancedb.connect("/tmp/db")

# Sentence-transformers embedding function, run on CPU. DocumentChunk declares
# its `text`/`vector` fields from it, so LanceDB can compute vectors from text.
model = (
    get_registry()
    .get("sentence-transformers")
    .create(device="cpu", name="Snowflake/snowflake-arctic-embed-xs")
)
def get_user_collection(user_id: str, collection_name: str) -> str:
    """Return the LanceDB table name backing *collection_name* for *user_id*.

    The table name is the user id and the collection name joined by one
    underscore.

    NOTE(review): the mapping is not injective. User ``"a"`` with collection
    ``"b_c"`` and user ``"a_b"`` with collection ``"c"`` both map to
    ``"a_b_c"``, so the name alone does not prove ownership.
    """
    return "{}_{}".format(user_id, collection_name)
class DocumentChunk(LanceModel):
    """LanceDB row schema: one chunk of one uploaded document.

    ``text`` is declared as the embedding source field and ``vector`` as the
    embedding vector field of ``model``. This lets upload_files add rows that
    carry no ``vector`` column, and lets query_collection search with a plain
    string query.

    NOTE(review): keep field order stable; presumably it fixes the column
    order of tables created from this schema.
    """
    # Chunk text; the input the embedding model encodes.
    text: str = model.SourceField()
    # Embedding of `text`; its length is the model's output dimension.
    vector: Vector(model.ndims()) = model.VectorField()
    # Document id from process_file (truncated SHA-256 of the content for .txt files).
    document_id: str
    # Position of the chunk within its document (metadata location "chunk_index").
    chunk_index: int
    # Original upload file name.
    file_name: str
    # Lower-cased text after the last "." in the file name, e.g. "pdf" or "txt".
    file_type: str
    # Creation timestamp string (datetime.now().isoformat() for .txt files).
    created_date: str
    # Full table name, i.e. "<user_id>_<collection name>".
    collection_id: str
    # Owner of the chunk; query_collection and get_document filter on it.
    user_id: str
    # The chunk's complete metadata dict, serialized with json.dumps.
    metadata_json: str
    # Character offsets of the chunk (metadata location "char_start"/"char_end").
    char_start: int
    char_end: int
    # Pages the chunk spans (metadata location "pages"; [1] for .txt files).
    page_numbers: List[int]
    # Image references for the chunk; format defined by process_pdf_to_chunks (TODO confirm). [] when absent.
    images: List[str]
class QueryInput(BaseModel):
    """Request body for a semantic search over one user collection."""
    # Collection name as the user sees it, without the "<user_id>_" table prefix.
    collection_id: str
    # Free-text query; it is embedded and matched against stored chunk vectors.
    query: str
    # Maximum number of results to return.
    top_k: Optional[int] = 3
    # Owner of the collection; also used to filter the rows that are searched.
    user_id: str
class SearchResult(BaseModel):
    """One search hit produced by query_collection."""
    # Text of the matching chunk.
    text: str
    # LanceDB's `_distance` for the hit (smaller means closer to the query).
    distance: float
    # Chunk metadata, decoded from the stored `metadata_json` column.
    metadata: Dict
class SearchResponse(BaseModel):
    """Response body of query_collection: hits in the order LanceDB returned them."""
    results: List[SearchResult]
async def process_file(file: UploadFile, collection_id: str, user_id: str) -> Tuple[List[dict], str]:
    """Read one uploaded file and split it into chunk dicts.

    The file type is the lower-cased text after the last ``.`` in the file
    name.

    * ``pdf``: delegated to ``process_pdf_to_chunks``, which also supplies
      the document id.
    * ``txt``: the whole file becomes a single chunk, decoded as UTF-8. A
      leading byte-order mark is dropped.

    Any other type (including a missing file name) yields no chunks and an
    empty document id. A warning is logged so the skip is visible.

    Args:
        file: The uploaded file. Its whole content is read into memory.
        collection_id: Target collection. Currently unused; kept for API
            compatibility.
        user_id: Owner, recorded in the metadata of ``txt`` chunks.

    Returns:
        ``(chunks, doc_id)``, where each chunk is
        ``{"text": str, "metadata": dict}``.

    Raises:
        UnicodeDecodeError: If a ``txt`` file is not valid UTF-8.
    """
    content = await file.read()
    # filename can be missing on an upload; treat that as "no extension"
    # instead of crashing on None.split().
    file_type = (file.filename or "").split('.')[-1].lower()

    if file_type == 'pdf':
        chunks, doc_id = process_pdf_to_chunks(
            pdf_content=content,
            file_name=file.filename
        )
        return chunks, doc_id

    if file_type == 'txt':
        # Content-addressed id. 16 hex chars (64 bits) make accidental
        # collisions between different files negligible. With a very short
        # prefix, collisions are likely and get_document would merge the
        # colliding documents' chunks.
        doc_id = hashlib.sha256(content).hexdigest()[:16]
        # "utf-8-sig" strips a leading BOM so it does not end up in the stored
        # text or shift the character offsets.
        text_content = content.decode('utf-8-sig')
        chunks = [{
            "text": text_content,
            "metadata": {
                "created_date": datetime.now().isoformat(),
                "file_name": file.filename,
                "document_id": doc_id,
                "user_id": user_id,
                "location": {
                    "chunk_index": 0,
                    "char_start": 0,
                    "char_end": len(text_content),
                    "pages": [1],
                    "total_chunks": 1
                },
                "images": []
            }
        }]
        return chunks, doc_id

    logging.warning(f"Unsupported file type '{file_type}' for file {file.filename}; skipping")
    return [], ""
def _chunk_to_row(chunk: dict, file_type: str, collection_id: str, user_id: str) -> dict:
    """Flatten one chunk dict produced by process_file into a DocumentChunk row."""
    metadata = chunk["metadata"]
    location = metadata["location"]
    return {
        "text": chunk["text"],
        "document_id": metadata["document_id"],
        "chunk_index": location["chunk_index"],
        "file_name": metadata["file_name"],
        "file_type": file_type,
        "created_date": metadata["created_date"],
        "collection_id": collection_id,
        "user_id": user_id,
        "metadata_json": json.dumps(metadata),
        "char_start": location["char_start"],
        "char_end": location["char_end"],
        "page_numbers": location["pages"],
        "images": metadata.get("images", []),
    }


def _open_or_create_table(collection_id: str):
    """Open the table *collection_id*, creating it from DocumentChunk if it cannot be opened.

    Raises:
        HTTPException: 500 if the table can be neither opened nor created.
    """
    try:
        return db.open_table(collection_id)
    except Exception as e:
        # Expected for a brand-new collection, so not logged as an error.
        logging.info(f"Table {collection_id} could not be opened ({str(e)}); creating it")
    try:
        # TODO: a full-text index on `text` (table.create_fts_index) would
        # enable hybrid search; it is intentionally not created yet.
        return db.create_table(
            collection_id,
            schema=DocumentChunk,
            mode="create"
        )
    except Exception as e:
        logging.error(f"Error creating table: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error creating database table: {str(e)}"
        ) from e


async def upload_files(
    files: List[UploadFile] = File(...),
    collection_name: Optional[str] = Form(None),
    user_id: str = Form(...)
):
    """Chunk the uploaded files and append them to one of the user's collections.

    Args:
        files: Uploaded files. Only ``.pdf`` and ``.txt`` produce chunks;
            other types are skipped by process_file.
        collection_name: Collection to add to. If omitted, a new name
            ``col_<8 hex chars>`` is generated.
        user_id: Owner; also stored on every row.

    Returns:
        A dict with a summary message, the full ``collection_id`` (table
        name), the number of chunks stored, ``user_id``, and ``document_ids``
        (a map of document id to file name).

    Raises:
        HTTPException: 400 if a file fails to process or no file yields any
            content. 500 if the table cannot be created, the rows cannot be
            added, or anything else fails.

    NOTE(review): table names are "<user_id>_<collection>", which can collide
    across users (see get_user_collection). Uploads may then append to a table
    that already holds another user's rows.
    """
    try:
        collection_id = get_user_collection(
            user_id,
            collection_name if collection_name else f"col_{uuid.uuid4().hex[:8]}"
        )
        all_chunks = []
        doc_ids = {}
        for file in files:
            try:
                chunks, doc_id = await process_file(file, collection_id, user_id)
                file_type = (file.filename or "").split('.')[-1].lower()
                all_chunks.extend(
                    _chunk_to_row(chunk, file_type, collection_id, user_id)
                    for chunk in chunks
                )
                # Unsupported files come back with an empty id; don't report it.
                if doc_id:
                    doc_ids[doc_id] = file.filename
            except Exception as e:
                logging.error(f"Error processing file {file.filename}: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing file {file.filename}: {str(e)}"
                ) from e

        # Fail fast before touching the database. An empty DataFrame cannot be
        # added, and creating the table first would leave an empty one behind.
        if not all_chunks:
            raise HTTPException(
                status_code=400,
                detail="No content could be extracted from the uploaded files"
            )

        table = _open_or_create_table(collection_id)
        try:
            table.add(data=pd.DataFrame(all_chunks))
        except Exception as e:
            logging.error(f"Error adding data to table: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error adding data to database: {str(e)}"
            ) from e

        return {
            "message": f"Successfully processed {len(files)} files",
            "collection_id": collection_id,
            "total_chunks": len(all_chunks),
            "user_id": user_id,
            "document_ids": doc_ids
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during file upload: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e
async def get_document(
    collection_id: str,
    document_id: str,
    user_id: str
):
    """Return every stored chunk of one document, ordered by chunk index.

    Args:
        collection_id: Collection name without the user prefix.
        document_id: Id of the document, as reported by the upload endpoint.
        user_id: Owner. Only rows whose ``user_id`` column matches are
            returned.

    Returns:
        ``{"document_id": ..., "file_name": ..., "chunks": [{"text": ...,
        "metadata": dict}, ...]}``.

    Raises:
        HTTPException: 404 if the collection or the document is not found.
            500 on any other failure while reading the table.
    """
    try:
        # Same naming scheme as every other endpoint, via the shared helper.
        table = db.open_table(get_user_collection(user_id, collection_id))
    except Exception as e:
        logging.error(f"Error opening table: {str(e)}")
        raise HTTPException(
            status_code=404,
            detail=f"Collection not found: {str(e)}"
        ) from e
    try:
        chunks = table.to_pandas()
        # Filtering on user_id as well guards against table-name collisions
        # between users.
        doc_chunks = chunks[
            (chunks['document_id'] == document_id) &
            (chunks['user_id'] == user_id)
        ].sort_values('chunk_index')
        if doc_chunks.empty:
            raise HTTPException(
                status_code=404,
                detail=f"Document {document_id} not found in collection {collection_id}"
            )
        return {
            "document_id": document_id,
            "file_name": doc_chunks.iloc[0]['file_name'],
            "chunks": [
                {
                    "text": row['text'],
                    "metadata": json.loads(row['metadata_json'])
                }
                for _, row in doc_chunks.iterrows()
            ]
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Error retrieving document: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving document: {str(e)}"
        ) from e
async def query_collection(input_data: QueryInput):
    """Vector-search one of the user's collections.

    Args:
        input_data: Collection name (without user prefix), query text, number
            of results, and owning user id.

    Returns:
        SearchResponse with up to ``top_k`` hits, in LanceDB's result order.

    Raises:
        HTTPException: 400 for a non-positive ``top_k`` or a ``user_id``
            containing characters that cannot be embedded safely in the
            filter. 404 if the collection does not exist. 500 on search or
            other unexpected failures.
    """
    try:
        # top_k is Optional: an explicit null falls back to the model default
        # instead of reaching .limit(None), which fails for vector search.
        top_k = 3 if input_data.top_k is None else input_data.top_k
        if top_k < 1:
            raise HTTPException(
                status_code=400,
                detail="top_k must be a positive integer"
            )
        # user_id is interpolated into a SQL filter string below. A quote or
        # backslash could terminate or escape the string literal and inject
        # filter logic, so reject them outright.
        if "'" in input_data.user_id or "\\" in input_data.user_id:
            raise HTTPException(status_code=400, detail="Invalid user_id")

        collection_id = get_user_collection(input_data.user_id, input_data.collection_id)
        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error(f"Error opening table: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection not found: {str(e)}"
            ) from e
        try:
            results = (
                table.search(input_data.query)
                .where(f"user_id = '{input_data.user_id}'")
                .limit(top_k)
                .to_list()
            )
        except Exception as e:
            logging.error(f"Error searching collection: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error searching collection: {str(e)}"
            ) from e
        return SearchResponse(results=[
            SearchResult(
                text=r['text'],
                distance=float(r['_distance']),
                metadata=json.loads(r['metadata_json'])
            )
            for r in results
        ])
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e
async def list_collections(user_id: str):
    """List the user's collections and the distinct documents in each.

    A table belongs to the user when its name starts with ``"<user_id>_"``.
    Because such names can collide across users (user ``"a"`` vs ``"a_b"``),
    only rows whose ``user_id`` column matches are reported. A non-empty table
    with no such rows is skipped entirely.

    Args:
        user_id: Owner whose collections are listed.

    Returns:
        ``{"collections": [{"collection_id": ..., "documents": [{"document_id",
        "file_name", "created_date"}, ...]}, ...]}``. ``collection_id`` is the
        name without the user prefix. Tables that fail to load are logged and
        skipped.

    Raises:
        HTTPException: 500 if the table names cannot be listed.
    """
    prefix = f"{user_id}_"
    try:
        user_collections = [
            name for name in db.table_names()
            if name.startswith(prefix)
        ]
        collections_info = []
        for collection_name in user_collections:
            try:
                df = db.open_table(collection_name).to_pandas()
                if not df.empty:
                    df = df[df['user_id'] == user_id]
                    if df.empty:
                        # The prefix matched only through a name collision
                        # with another user's table; don't expose it.
                        continue
                # One entry per document: chunks of a document share these fields.
                documents = df.groupby('document_id').agg({
                    'file_name': 'first',
                    'created_date': 'first'
                }).reset_index()
                collections_info.append({
                    # Slice off only the leading prefix. str.replace would
                    # also delete later occurrences inside the collection name.
                    "collection_id": collection_name[len(prefix):],
                    "documents": [
                        {
                            "document_id": row['document_id'],
                            "file_name": row['file_name'],
                            "created_date": row['created_date']
                        }
                        for _, row in documents.iterrows()
                    ]
                })
            except Exception as e:
                logging.error(f"Error processing collection {collection_name}: {str(e)}")
                continue
        return {"collections": collections_info}
    except Exception as e:
        logging.error(f"Error listing collections for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def delete_collection(collection_id: str, user_id: str):
    """Drop one of the user's collections, i.e. the whole backing LanceDB table.

    Args:
        collection_id: Collection name without the user prefix.
        user_id: User requesting the deletion.

    Returns:
        ``{"message": ..., "collection_id": collection_id}`` on success.

    Raises:
        HTTPException: 404 if the table does not exist. 403 if it holds rows
            written by a different user. 500 on any other failure.
    """
    try:
        full_collection_id = get_user_collection(user_id, collection_id)
        # Check if collection exists
        try:
            table = db.open_table(full_collection_id)
        except Exception as e:
            logging.error(f"Collection not found: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection {collection_id} not found"
            ) from e
        # Verify ownership from the stored rows. A table-name prefix check
        # proves nothing, because the name was just built from user_id. Names
        # can also collide across users: user "a" + collection "b_c" and user
        # "a_b" + collection "c" both map to "a_b_c".
        # NOTE: this loads the whole table, which is costly for very large
        # collections.
        owners = table.to_pandas()['user_id']
        if (owners != user_id).any():
            logging.error(f"Unauthorized deletion attempt for collection {collection_id} by user {user_id}")
            raise HTTPException(
                status_code=403,
                detail="Not authorized to delete this collection"
            )
        try:
            db.drop_table(full_collection_id)
        except Exception as e:
            logging.error(f"Error deleting collection {collection_id}: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error deleting collection: {str(e)}"
            ) from e
        return {
            "message": f"Collection {collection_id} deleted successfully",
            "collection_id": collection_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error deleting collection {collection_id}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e
async def query_collection_tool(input_data: QueryInput):
    """Run query_collection and return a compact, stringified result list.

    Each hit is reduced to its text, distance, ``document_id``, and
    ``chunk_index`` (taken from the metadata's ``location``).

    Args:
        input_data: Same request model as query_collection.

    Returns:
        ``str()`` of a list of dicts. This is a Python repr, not JSON.

    Raises:
        HTTPException: Propagated unchanged from query_collection (e.g. 404
            for a missing collection). 500 for any other failure.
    """
    try:
        response = await query_collection(input_data)
        results = [
            {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index")
                }
            }
            for r in response.results
        ]
        return str(results)
    except HTTPException:
        # Keep query_collection's status code (400/404/500) instead of
        # re-wrapping it as a generic 500 "Unexpected error".
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e