# RAG System backend (FastAPI) — document upload, indexing, and query service.
| import os | |
| import pandas as pd | |
| import numpy as np | |
| from pathlib import Path | |
| import faiss | |
| import pickle | |
| import json | |
| from typing import List, Dict, Tuple, Any, Optional | |
| import logging | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import docx | |
| from openpyxl import load_workbook | |
| import PyPDF2 | |
| import fitz | |
| from sentence_transformers import SentenceTransformer | |
| import nltk | |
| from nltk.tokenize import sent_tokenize | |
app = FastAPI(title="RAG System API", description="Document Search and Query System")

# NOTE(review): wide-open CORS (all origins/methods/headers with credentials);
# acceptable for a demo, restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Best-effort download of the sentence tokenizer model used by TextChunker.
# Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
try:
    nltk.download('punkt', quiet=True)
except Exception:
    logger.warning("Could not download NLTK data")
@dataclass
class DocumentChunk:
    """A retrievable piece of a source document.

    The rest of the file constructs this with keyword arguments
    (``DocumentChunk(content=..., source=..., ...)``), which requires the
    generated ``__init__`` — the original class was missing the
    ``@dataclass`` decorator and would raise TypeError on construction.
    """
    content: str      # chunk text that gets embedded and retrieved
    source: str       # originating file path (plus sheet name for Excel)
    page_number: int  # 1-based page/row locator within the source
    chunk_index: int  # position of this chunk within its page
    embedding: Optional[np.ndarray] = None  # set after encoding, if ever
class DocumentResponse(BaseModel):
    """Response payload for the document-upload endpoint."""
    message: str       # human-readable status summary
    files: List[str]   # base names of the files that were saved
    success: bool      # True when indexing produced at least one chunk
class QueryRequest(BaseModel):
    """Request body for the query endpoint."""
    question: str   # natural-language question to run against the index
    top_k: int = 3  # number of chunks to retrieve
class QueryResponse(BaseModel):
    """Response body for the query endpoint (mirrors the dict built by RAGSystem.query)."""
    answer: str                   # generated or extractive answer text
    sources: List[str]            # de-duplicated source paths of retrieved chunks
    confidence: float             # mean similarity score of the retrieved chunks
    chunks: List[Dict[str, Any]]  # per-chunk previews: source, content, score
class DocumentProcessor:
    """Loads supported document formats into plain-text records.

    Every loader returns a list of dicts with keys ``content``, ``source``,
    ``page_number`` and ``metadata``.  Loaders are best-effort: they log and
    return ``[]`` on failure so one bad file cannot abort a batch upload.
    """

    def __init__(self):
        self.supported_formats = ['.csv', '.docx', '.xlsx', '.xls', '.pdf', '.txt']

    def load_csv(self, file_path: str) -> List[Dict]:
        """One record per CSV row; columns flattened to ``col:val | col:val``."""
        try:
            df = pd.read_csv(file_path)
            return [{
                'content': " | ".join([f"{col}:{val}" for col, val in row.items()]),
                'source': file_path,
                'page_number': idx + 1,
                'metadata': {'row_index': idx}
            } for idx, row in df.iterrows()]
        except Exception as e:
            logger.error(f"CSV Error: {e}")
            return []

    def load_txt(self, file_path: str) -> List[Dict]:
        """Whole UTF-8 text file as a single record."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            return [{
                'content': content,
                'source': file_path,
                'page_number': 1,
                'metadata': {'content_type': 'text'}
            }]
        except Exception as e:
            logger.error(f"TXT Error: {e}")
            return []

    def load_pdf(self, file_path: str) -> List[Dict]:
        """One record per non-empty PDF page (PyMuPDF extraction)."""
        try:
            doc = fitz.open(file_path)
            try:
                pages = []
                for i, page in enumerate(doc):
                    # Extract once per page: the original called get_text()
                    # twice (once for the filter, once for the content).
                    text = page.get_text().strip()
                    if text:
                        pages.append({
                            'content': text,
                            'source': file_path,
                            'page_number': i + 1,
                            'metadata': {'total_pages': len(doc)}
                        })
                return pages
            finally:
                doc.close()  # release the file handle promptly
        except Exception as e:
            logger.error(f"PDF Error: {e}")
            return []

    def load_docx(self, file_path: str) -> List[Dict]:
        """All paragraphs of a .docx joined into one record."""
        try:
            doc = docx.Document(file_path)
            content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
            return [{
                'content': content,
                'source': file_path,
                'page_number': 1,
                'metadata': {'content_type': 'docx'}
            }]
        except Exception as e:
            logger.error(f"DOCX Error: {e}")
            return []

    def load_excel(self, file_path: str) -> List[Dict]:
        """One record per worksheet; non-empty rows flattened to ``cell | cell`` lines."""
        try:
            wb = load_workbook(file_path)
            results = []
            for sheet_name in wb.sheetnames:
                ws = wb[sheet_name]
                data = []
                for row in ws.iter_rows(values_only=True):
                    if any(cell is not None for cell in row):
                        data.append([str(cell) if cell is not None else "" for cell in row])
                content = "\n".join([" | ".join(row) for row in data])
                results.append({
                    'content': content,
                    'source': f"{file_path} - {sheet_name}",
                    'page_number': 1,
                    'metadata': {'sheet_name': sheet_name}
                })
            return results
        except Exception as e:
            logger.error(f"Excel Error: {e}")
            return []

    def process_document(self, file_path: str) -> List[Dict]:
        """Dispatch to the loader for the file's extension; [] if unsupported."""
        handlers = {
            '.csv': self.load_csv,
            '.txt': self.load_txt,
            '.pdf': self.load_pdf,
            '.docx': self.load_docx,
            '.xlsx': self.load_excel,
            '.xls': self.load_excel,
        }
        handler = handlers.get(Path(file_path).suffix.lower())
        return handler(file_path) if handler else []
class TextChunker:
    """Splits document records into sentence-aligned chunks with overlap.

    ``chunk_size`` is measured in characters while ``overlap`` is measured in
    words carried over from the previous chunk.
    NOTE(review): the mixed units are kept from the original design — confirm
    intent before changing.
    """

    def __init__(self, chunk_size=512, overlap=50):
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_document(self, doc_data: Dict) -> List[DocumentChunk]:
        """Return ordered DocumentChunk objects covering doc_data['content']."""
        content = doc_data['content']
        # Short documents become a single chunk, untouched.
        if len(content) <= self.chunk_size:
            return [DocumentChunk(
                content=content,
                source=doc_data['source'],
                page_number=doc_data['page_number'],
                chunk_index=0
            )]
        sentences = sent_tokenize(content)
        chunks, current_chunk = [], ""
        for sentence in sentences:
            # Flush only when the buffer is non-empty: a single over-long
            # sentence previously made this append an empty-string chunk.
            if current_chunk and len(current_chunk) + len(sentence) > self.chunk_size:
                chunks.append(current_chunk.strip())
                # Seed the next chunk with the last `overlap` words for context.
                current_chunk = " ".join(current_chunk.split()[-self.overlap:]) + " " + sentence
            else:
                current_chunk += " " + sentence if current_chunk else sentence
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        return [DocumentChunk(
            content=chunk,
            source=doc_data['source'],
            page_number=doc_data['page_number'],
            chunk_index=i
        ) for i, chunk in enumerate(chunks)]
class VectorStore:
    """In-memory FAISS inner-product index over L2-normalized embeddings
    (i.e. cosine similarity), with a parallel list of DocumentChunk payloads."""

    def __init__(self, dimension=384):
        # 384 matches the all-MiniLM-L6-v2 embedding size used elsewhere in this file.
        self.index = faiss.IndexFlatIP(dimension)
        self.chunks = []

    def add_chunks(self, chunks: List[DocumentChunk], embeddings: np.ndarray):
        """Normalize and add embeddings; self.chunks[i] corresponds to vector i."""
        # Cast/contiguize BEFORE normalizing: faiss.normalize_L2 mutates a
        # float32 buffer in place; normalizing first and casting after (as the
        # original did) would silently drop the normalization on non-float32 input.
        embeddings = np.ascontiguousarray(embeddings, dtype='float32')
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)
        self.chunks.extend(chunks)

    def search(self, query_embedding: np.ndarray, k=5) -> List[Tuple[DocumentChunk, float]]:
        """Return up to k (chunk, cosine-score) pairs, best first."""
        if self.index.ntotal == 0:
            # Searching an empty index would only yield -1 placeholders.
            return []
        query_embedding = query_embedding.reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, k)
        # FAISS pads missing results with index -1; the original `i < len(...)`
        # check let -1 through and silently returned self.chunks[-1].
        return [(self.chunks[i], float(s))
                for s, i in zip(scores[0], indices[0])
                if 0 <= i < len(self.chunks)]
class RAGSystem:
    """End-to-end pipeline: load -> chunk -> embed -> retrieve -> answer."""

    def __init__(self):
        self.processor = DocumentProcessor()
        self.chunker = TextChunker()
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.vector_store = VectorStore()
        self.llm = None  # optional text generator; retrieval-only when absent
        self.indexed_files = []

    def index_documents(self, file_paths: List[str]) -> bool:
        """Chunk and embed every file; False when nothing usable was extracted."""
        collected = []
        for path in file_paths:
            logger.info(f"Processing document: {path}")
            for record in self.processor.process_document(path):
                collected.extend(self.chunker.chunk_document(record))
        if not collected:
            return False
        logger.info(f"Generated {len(collected)} chunks, creating embeddings...")
        vectors = self.embedder.encode([chunk.content for chunk in collected])
        self.vector_store.add_chunks(collected, vectors)
        self.indexed_files.extend(file_paths)
        logger.info(f"Successfully indexed {len(file_paths)} documents")
        return True

    def query(self, question: str, top_k=3) -> Dict[str, Any]:
        """Retrieve top_k chunks and build the answer dict QueryResponse expects."""
        logger.info(f"Processing query: {question}")
        question_vector = self.embedder.encode([question])[0]
        hits = self.vector_store.search(question_vector, top_k)
        if not hits:
            return {
                'answer': "No relevant information found in the uploaded documents.",
                'sources': [],
                'confidence': 0.0,
                'chunks': []
            }
        context = "\n\n".join(
            f"Source: {hit.source}\nContent: {hit.content}" for hit, _ in hits
        )
        if self.llm:
            prompt = f"Based on the following context, answer the question: {question}\n\nContext:\n{context}\n\nAnswer:"
            completion = self.llm(prompt, max_tokens=512, temperature=0.7)
            answer = completion['choices'][0]['text'].strip()
        else:
            # No LLM configured: fall back to an extractive preview of the context.
            answer = f"Based on the retrieved documents, here are the most relevant sections:\n\n{context[:1000]}..."
        previews = []
        for hit, score in hits:
            snippet = hit.content[:500] + ("..." if len(hit.content) > 500 else "")
            previews.append({
                'source': hit.source,
                'content': snippet,
                'score': float(score)
            })
        return {
            'answer': answer,
            'sources': list({hit.source for hit, _ in hits}),
            'confidence': float(np.mean([score for _, score in hits])),
            'chunks': previews
        }
# Single module-level RAG pipeline instance shared by all request handlers.
rag = RAGSystem()
@app.get("/", response_class=HTMLResponse)  # NOTE(review): route decorator was missing; path inferred — confirm
async def serve_frontend():
    """Serve the single-page frontend, or a 404 hint when index.html is absent."""
    try:
        with open("index.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        return HTMLResponse(
            content="<h1>Frontend not found</h1><p>Please place index.html in the same directory as app.py</p>",
            status_code=404
        )
@app.post("/upload", response_model=DocumentResponse)  # NOTE(review): route decorator was missing; path inferred — confirm
async def upload_files(files: List[UploadFile] = File(...)):
    """Upload and process documents"""
    saved_paths = []
    os.makedirs("data", exist_ok=True)
    try:
        for file in files:
            # Keep only the base name: a client-supplied filename such as
            # "../../etc/passwd" must not escape the data directory.
            file_path = f"data/{Path(file.filename).name}"
            with open(file_path, "wb") as f:
                content = await file.read()
                f.write(content)
            saved_paths.append(file_path)
            logger.info(f"Saved file: {file_path}")
        success = rag.index_documents(saved_paths)
        return DocumentResponse(
            message=f"Successfully processed {len(saved_paths)} files" if success else "Failed to process files",
            files=[Path(p).name for p in saved_paths],
            success=success
        )
    except Exception as e:
        logger.error(f"Upload error: {e}")
        # Best-effort cleanup of partially saved files; narrowed from a bare
        # `except:` so only filesystem errors are ignored.
        for path in saved_paths:
            try:
                os.remove(path)
            except OSError:
                pass
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
@app.post("/query", response_model=QueryResponse)  # NOTE(review): route decorator was missing; path inferred — confirm
async def process_query(query: QueryRequest):
    """Process a query against the indexed documents"""
    try:
        if len(rag.vector_store.chunks) == 0:
            raise HTTPException(
                status_code=400,
                detail="No documents have been uploaded and indexed yet. Please upload documents first."
            )
        result = rag.query(query.question, query.top_k)
        return QueryResponse(**result)
    except HTTPException:
        # Re-raise intentional HTTP errors untouched (e.g. the 400 above)
        # instead of wrapping them in a 500 below.
        raise
    except Exception as e:
        logger.error(f"Query error: {e}")
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
@app.get("/health")  # NOTE(review): route decorator was missing; path inferred — confirm
def health_check():
    """Health check endpoint"""
    return {
        "status": "running",
        "indexed_files": len(rag.indexed_files),
        "total_chunks": len(rag.vector_store.chunks)
    }
if __name__ == "__main__":
    import uvicorn
    os.makedirs("data", exist_ok=True)
    # Plain-ASCII banners: the original literals contained mojibake ("π ...")
    # from a lost emoji encoding.
    print("Starting RAG System Server...")
    print("Frontend will be available at: http://localhost:8000")
    print("API docs available at: http://localhost:8000/docs")
    # reload=True only works when the app is passed as an import string
    # ("module:app"); with an app object uvicorn warns and disables it,
    # so the flag is dropped rather than silently ignored.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info"
    )