sairika commited on
Commit
efaba82
·
verified ·
1 Parent(s): 06320e3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +408 -0
app.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import sqlite3
5
+ import pandas as pd
6
+ from typing import List, Optional, Dict, Any
7
+ from pathlib import Path
8
+ import asyncio
9
+ import uuid
10
+
11
+ from fastapi import FastAPI, File, UploadFile, HTTPException
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from pydantic import BaseModel
14
+ import uvicorn
15
+
16
+ # Document processing
17
+ import PyPDF2
18
+ import pdfplumber
19
+ from docx import Document
20
+ import pytesseract
21
+ from PIL import Image
22
+
23
+ # ML/AI components
24
+ import torch
25
+ from sentence_transformers import SentenceTransformer
26
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
27
+ import faiss
28
+ import numpy as np
29
+ import pickle
30
+
31
+ # Configuration
32
+ class Config:
33
+ UPLOAD_DIR = "uploads"
34
+ VECTOR_STORE_DIR = "vector_store"
35
+ CHUNK_SIZE = 500
36
+ CHUNK_OVERLAP = 50
37
+ MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
38
+
39
+ # Hugging Face Models (Free)
40
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
41
+ LLM_MODEL = "microsoft/DialoGPT-medium" # For conversational responses
42
+ # Alternative: "google/flan-t5-base" for better text generation
43
+
44
+ config = Config()
45
+
46
+ # Ensure directories exist
47
+ os.makedirs(config.UPLOAD_DIR, exist_ok=True)
48
+ os.makedirs(config.VECTOR_STORE_DIR, exist_ok=True)
49
+
50
+ # Pydantic models
51
+ class QueryRequest(BaseModel):
52
+ question: str
53
+ image_base64: Optional[str] = None
54
+ file_id: Optional[str] = None
55
+
56
+ class QueryResponse(BaseModel):
57
+ answer: str
58
+ context: List[str]
59
+ sources: List[Dict[str, Any]]
60
+ confidence: float
61
+
62
+ class UploadResponse(BaseModel):
63
+ file_id: str
64
+ filename: str
65
+ file_type: str
66
+ chunks_created: int
67
+ message: str
68
+
69
+ # Document Processor Class
70
+ class DocumentProcessor:
71
+ def __init__(self):
72
+ self.embedding_model = SentenceTransformer(config.EMBEDDING_MODEL)
73
+
74
+ def extract_text_from_pdf(self, file_path: str) -> str:
75
+ """Extract text from PDF using pdfplumber"""
76
+ text = ""
77
+ try:
78
+ with pdfplumber.open(file_path) as pdf:
79
+ for page in pdf.pages:
80
+ page_text = page.extract_text()
81
+ if page_text:
82
+ text += page_text + "\n"
83
+ except Exception as e:
84
+ # Fallback to PyPDF2
85
+ with open(file_path, 'rb') as file:
86
+ pdf_reader = PyPDF2.PdfReader(file)
87
+ for page in pdf_reader.pages:
88
+ text += page.extract_text() + "\n"
89
+ return text
90
+
91
+ def extract_text_from_docx(self, file_path: str) -> str:
92
+ """Extract text from Word document"""
93
+ doc = Document(file_path)
94
+ text = ""
95
+ for paragraph in doc.paragraphs:
96
+ text += paragraph.text + "\n"
97
+ return text
98
+
99
+ def extract_text_from_image(self, image_data: bytes) -> str:
100
+ """Extract text from image using OCR"""
101
+ try:
102
+ image = Image.open(io.BytesIO(image_data))
103
+ text = pytesseract.image_to_string(image)
104
+ return text
105
+ except Exception as e:
106
+ raise HTTPException(status_code=400, f"OCR failed: {str(e)}")
107
+
108
+ def extract_text_from_csv(self, file_path: str) -> str:
109
+ """Extract text from CSV"""
110
+ df = pd.read_csv(file_path)
111
+ return df.to_string()
112
+
113
+ def extract_text_from_db(self, file_path: str) -> str:
114
+ """Extract text from SQLite database"""
115
+ conn = sqlite3.connect(file_path)
116
+ text = ""
117
+
118
+ # Get all table names
119
+ cursor = conn.cursor()
120
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
121
+ tables = cursor.fetchall()
122
+
123
+ for table in tables:
124
+ table_name = table[0]
125
+ df = pd.read_sql_query(f"SELECT * FROM {table_name}", conn)
126
+ text += f"Table: {table_name}\n"
127
+ text += df.to_string() + "\n\n"
128
+
129
+ conn.close()
130
+ return text
131
+
132
+ def chunk_text(self, text: str) -> List[str]:
133
+ """Split text into chunks with overlap"""
134
+ chunks = []
135
+ words = text.split()
136
+
137
+ for i in range(0, len(words), config.CHUNK_SIZE - config.CHUNK_OVERLAP):
138
+ chunk = " ".join(words[i:i + config.CHUNK_SIZE])
139
+ chunks.append(chunk)
140
+
141
+ return chunks
142
+
143
+ def process_document(self, file_path: str, file_type: str) -> List[str]:
144
+ """Process document based on file type"""
145
+ text = ""
146
+
147
+ if file_type.lower() == '.pdf':
148
+ text = self.extract_text_from_pdf(file_path)
149
+ elif file_type.lower() == '.docx':
150
+ text = self.extract_text_from_docx(file_path)
151
+ elif file_type.lower() == '.txt':
152
+ with open(file_path, 'r', encoding='utf-8') as f:
153
+ text = f.read()
154
+ elif file_type.lower() in ['.jpg', '.jpeg', '.png']:
155
+ with open(file_path, 'rb') as f:
156
+ text = self.extract_text_from_image(f.read())
157
+ elif file_type.lower() == '.csv':
158
+ text = self.extract_text_from_csv(file_path)
159
+ elif file_type.lower() == '.db':
160
+ text = self.extract_text_from_db(file_path)
161
+ else:
162
+ raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}")
163
+
164
+ return self.chunk_text(text)
165
+
166
+ # Vector Store Class
167
+ class VectorStore:
168
+ def __init__(self, embedding_model: SentenceTransformer):
169
+ self.embedding_model = embedding_model
170
+ self.dimension = 384 # all-MiniLM-L6-v2 embedding dimension
171
+ self.index = faiss.IndexFlatIP(self.dimension) # Inner product for similarity
172
+ self.chunks = []
173
+ self.metadata = []
174
+
175
+ def add_documents(self, chunks: List[str], file_id: str, filename: str):
176
+ """Add documents to vector store"""
177
+ embeddings = self.embedding_model.encode(chunks)
178
+
179
+ # Normalize embeddings for inner product similarity
180
+ faiss.normalize_L2(embeddings)
181
+
182
+ self.index.add(embeddings.astype(np.float32))
183
+
184
+ for i, chunk in enumerate(chunks):
185
+ self.chunks.append(chunk)
186
+ self.metadata.append({
187
+ 'file_id': file_id,
188
+ 'filename': filename,
189
+ 'chunk_index': i,
190
+ 'text': chunk
191
+ })
192
+
193
+ def search(self, query: str, k: int = 5) -> List[Dict]:
194
+ """Search for similar documents"""
195
+ query_embedding = self.embedding_model.encode([query])
196
+ faiss.normalize_L2(query_embedding)
197
+
198
+ scores, indices = self.index.search(query_embedding.astype(np.float32), k)
199
+
200
+ results = []
201
+ for score, idx in zip(scores[0], indices[0]):
202
+ if idx != -1: # Valid index
203
+ results.append({
204
+ 'text': self.chunks[idx],
205
+ 'metadata': self.metadata[idx],
206
+ 'score': float(score)
207
+ })
208
+
209
+ return results
210
+
211
+ def save(self, path: str):
212
+ """Save vector store to disk"""
213
+ faiss.write_index(self.index, f"{path}/index.faiss")
214
+ with open(f"{path}/data.pkl", 'wb') as f:
215
+ pickle.dump({
216
+ 'chunks': self.chunks,
217
+ 'metadata': self.metadata
218
+ }, f)
219
+
220
+ def load(self, path: str):
221
+ """Load vector store from disk"""
222
+ if os.path.exists(f"{path}/index.faiss"):
223
+ self.index = faiss.read_index(f"{path}/index.faiss")
224
+ with open(f"{path}/data.pkl", 'rb') as f:
225
+ data = pickle.load(f)
226
+ self.chunks = data['chunks']
227
+ self.metadata = data['metadata']
228
+
229
+ # LLM Handler Class
230
+ class LLMHandler:
231
+ def __init__(self):
232
+ # Using Flan-T5 for better text generation
233
+ self.model_name = "google/flan-t5-base"
234
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
235
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
236
+ self.generator = pipeline(
237
+ "text2text-generation",
238
+ model=self.model,
239
+ tokenizer=self.tokenizer,
240
+ max_length=512,
241
+ temperature=0.7,
242
+ do_sample=True
243
+ )
244
+
245
+ def generate_answer(self, question: str, context: List[str]) -> str:
246
+ """Generate answer using LLM"""
247
+ # Construct prompt
248
+ context_text = "\n".join(context[:3]) # Use top 3 contexts
249
+
250
+ prompt = f"""Based on the following context, answer the question accurately and concisely.
251
+
252
+ Context:
253
+ {context_text}
254
+
255
+ Question: {question}
256
+
257
+ Answer:"""
258
+
259
+ try:
260
+ response = self.generator(
261
+ prompt,
262
+ max_length=200,
263
+ num_return_sequences=1,
264
+ pad_token_id=self.tokenizer.eos_token_id
265
+ )
266
+
267
+ answer = response[0]['generated_text']
268
+ # Clean up the answer
269
+ if "Answer:" in answer:
270
+ answer = answer.split("Answer:")[-1].strip()
271
+
272
+ return answer
273
+
274
+ except Exception as e:
275
+ return f"I apologize, but I encountered an error generating the answer: {str(e)}"
276
+
277
+ # Initialize components
278
+ document_processor = DocumentProcessor()
279
+ vector_store = VectorStore(document_processor.embedding_model)
280
+ llm_handler = LLMHandler()
281
+
282
+ # Load existing vector store if available
283
+ vector_store.load(config.VECTOR_STORE_DIR)
284
+
285
+ # FastAPI app
286
+ app = FastAPI(
287
+ title="Smart RAG API",
288
+ description="Retrieval-Augmented Generation API for document Q&A",
289
+ version="1.0.0"
290
+ )
291
+
292
+ app.add_middleware(
293
+ CORSMiddleware,
294
+ allow_origins=["*"],
295
+ allow_credentials=True,
296
+ allow_methods=["*"],
297
+ allow_headers=["*"],
298
+ )
299
+
300
+ @app.post("/upload", response_model=UploadResponse)
301
+ async def upload_file(file: UploadFile = File(...)):
302
+ """Upload and process a document"""
303
+
304
+ # Validate file size
305
+ file_content = await file.read()
306
+ if len(file_content) > config.MAX_FILE_SIZE:
307
+ raise HTTPException(status_code=413, detail="File too large")
308
+
309
+ # Generate file ID
310
+ file_id = str(uuid.uuid4())
311
+ file_extension = Path(file.filename).suffix.lower()
312
+
313
+ # Save file
314
+ file_path = os.path.join(config.UPLOAD_DIR, f"{file_id}_{file.filename}")
315
+ with open(file_path, "wb") as f:
316
+ f.write(file_content)
317
+
318
+ try:
319
+ # Process document
320
+ chunks = document_processor.process_document(file_path, file_extension)
321
+
322
+ # Add to vector store
323
+ vector_store.add_documents(chunks, file_id, file.filename)
324
+
325
+ # Save vector store
326
+ vector_store.save(config.VECTOR_STORE_DIR)
327
+
328
+ return UploadResponse(
329
+ file_id=file_id,
330
+ filename=file.filename,
331
+ file_type=file_extension,
332
+ chunks_created=len(chunks),
333
+ message="File uploaded and processed successfully"
334
+ )
335
+
336
+ except Exception as e:
337
+ # Clean up file on error
338
+ os.remove(file_path)
339
+ raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
340
+
341
+ @app.post("/query", response_model=QueryResponse)
342
+ async def query_documents(request: QueryRequest):
343
+ """Query documents with a question"""
344
+
345
+ question = request.question
346
+
347
+ # Handle image-based questions
348
+ if request.image_base64:
349
+ try:
350
+ # Decode base64 image
351
+ image_data = base64.b64decode(request.image_base64)
352
+
353
+ # Extract text from image
354
+ ocr_text = document_processor.extract_text_from_image(image_data)
355
+
356
+ # Combine question with OCR text
357
+ question = f"{request.question} [Image content: {ocr_text}]"
358
+
359
+ except Exception as e:
360
+ raise HTTPException(status_code=400, detail=f"Image processing failed: {str(e)}")
361
+
362
+ # Search vector store
363
+ search_results = vector_store.search(question, k=5)
364
+
365
+ if not search_results:
366
+ raise HTTPException(status_code=404, detail="No relevant documents found")
367
+
368
+ # Extract context and sources
369
+ contexts = [result['text'] for result in search_results]
370
+ sources = [result['metadata'] for result in search_results]
371
+
372
+ # Generate answer
373
+ answer = llm_handler.generate_answer(request.question, contexts)
374
+
375
+ # Calculate confidence (average similarity score)
376
+ confidence = sum(result['score'] for result in search_results) / len(search_results)
377
+
378
+ return QueryResponse(
379
+ answer=answer,
380
+ context=contexts,
381
+ sources=sources,
382
+ confidence=confidence
383
+ )
384
+
385
+ @app.get("/health")
386
+ async def health_check():
387
+ """Health check endpoint"""
388
+ return {
389
+ "status": "healthy",
390
+ "documents_indexed": len(vector_store.chunks),
391
+ "model_loaded": llm_handler.model is not None
392
+ }
393
+
394
+ @app.get("/")
395
+ async def root():
396
+ """Root endpoint with API information"""
397
+ return {
398
+ "message": "Smart RAG API",
399
+ "version": "1.0.0",
400
+ "endpoints": {
401
+ "/upload": "POST - Upload documents",
402
+ "/query": "POST - Query documents",
403
+ "/health": "GET - Health check"
404
+ }
405
+ }
406
+
407
+ if __name__ == "__main__":
408
+ uvicorn.run(app, host="0.0.0.0", port=8000)