Kalpokoch commited on
Commit
bc8a612
·
1 Parent(s): d1af85f

added vector embedding and query endpoint

Browse files
Files changed (1) hide show
  1. app.py +157 -67
app.py CHANGED
@@ -1,101 +1,191 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import uuid
 
 
 
 
 
 
 
 
 
 
4
  import logging
 
5
  import io
6
- import fitz
 
 
 
 
 
 
 
7
  from PIL import Image
8
  import pytesseract
9
- import numpy as np
10
 
11
- # NEW: Import AI and search libraries
12
- from sentence_transformers import SentenceTransformer
13
  import faiss
 
 
 
 
14
 
15
- # --- Basic Setup (Logging, FastAPI, CORS) ---
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
- app = FastAPI()
 
 
 
 
 
 
 
 
19
  app.add_middleware(
20
- CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
 
 
 
 
21
  )
22
 
23
- # --- AI MODEL LOADING ---
24
- # This happens only once when the app starts.
25
- # 'all-MiniLM-L6-v2' is a great, lightweight model for CPU.
26
  try:
27
- logger.info("Loading sentence-transformer model...")
28
- model = SentenceTransformer('all-MiniLM-L6-v2')
29
- logger.info("Model loaded successfully.")
 
 
 
30
  except Exception as e:
31
- logger.error(f"Failed to load sentence-transformer model: {e}")
32
- model = None
 
33
 
34
- # In-memory session store
35
  SESSION_DATA = {}
36
 
37
- # --- Parsing Functions (parse_pdf, parse_image - keep these as they are) ---
38
- def parse_pdf(content: bytes) -> str: # ... your existing function ...
39
- def parse_image(content: bytes) -> str: # ... your existing function ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # --- NEW: Helper function for chunking text ---
42
  def chunk_text(text: str, chunk_size: int = 256, overlap: int = 32) -> list[str]:
43
  """Splits text into overlapping chunks of words."""
44
  words = text.split()
45
- if not words:
46
- return []
47
- chunks = []
48
- for i in range(0, len(words), chunk_size - overlap):
49
- chunk = " ".join(words[i:i + chunk_size])
50
- chunks.append(chunk)
51
- return chunks
52
-
53
- # --- MODIFIED: The /upload Endpoint ---
54
- @app.post("/upload")
 
 
 
 
 
55
  async def upload_file(file: UploadFile = File(...)):
56
- if not model:
57
- raise HTTPException(status_code=503, detail="AI model is not available.")
58
-
 
59
  session_id = str(uuid.uuid4())
60
- logger.info(f"New upload '{file.filename}'. Creating session_id: {session_id}")
61
  content = await file.read()
62
-
63
- # 1. PARSE (This part is the same as before)
64
- extracted_text = ""
65
- if file.content_type == "application/pdf": extracted_text = parse_pdf(content)
66
- elif file.content_type and file.content_type.startswith("image/"): extracted_text = parse_image(content)
67
- elif file.content_type == "text/plain": extracted_text = content.decode("utf-8")
68
- else: raise HTTPException(status_code=400, detail=f"Unsupported file type: {file.content_type}")
69
-
70
- if not extracted_text.strip():
71
- raise HTTPException(status_code=400, detail="Could not extract any text from the file.")
72
 
73
- # 2. CHUNK
74
- text_chunks = chunk_text(extracted_text)
75
- logger.info(f"Text chunked into {len(text_chunks)} pieces.")
76
- if not text_chunks:
77
- raise HTTPException(status_code=400, detail="Document is empty or too short to be chunked.")
 
78
 
79
- # 3. EMBED
80
- logger.info("Generating embeddings for text chunks...")
81
- embeddings = model.encode(text_chunks, convert_to_numpy=True)
82
- logger.info(f"Embeddings generated with shape: {embeddings.shape}")
83
 
84
- # 4. INDEX
85
- d = embeddings.shape[1] # Dimension of embeddings
86
- index = faiss.IndexFlatL2(d)
87
- index.add(embeddings.astype('float32')) # FAISS requires float32
88
- logger.info(f"FAISS index created with {index.ntotal} vectors.")
 
 
 
 
 
 
89
 
90
- # Store the index AND the original text chunks in the session
91
  SESSION_DATA[session_id] = {
92
- "filename": file.filename,
93
  "chunks": text_chunks,
94
- "index": index.serialize() # Serialize the index for storage
95
  }
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  return {
98
- "session_id": session_id,
99
- "filename": file.filename,
100
- "chunks_created": len(text_chunks)
101
  }
 
1
+ #
2
+ # ---------------- Universal Data AI ----------------
3
+ #
4
+ # Final app.py script
5
+ # Combines:
6
+ # 1. File Upload & Parsing (PDF, Image, Text)
7
+ # 2. Text Chunking
8
+ # 3. Vector Embedding & FAISS Indexing
9
+ # 4. A Query Endpoint for Question Answering
10
+ #
11
+ # Last updated: August 8, 2025
12
+ #
13
+
14
  import logging
15
+ import uuid
16
  import io
17
+
18
+ # FastAPI & Pydantic
19
+ from fastapi import FastAPI, UploadFile, File, HTTPException
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ from pydantic import BaseModel
22
+
23
+ # Parsing Libraries
24
+ import fitz # PyMuPDF
25
  from PIL import Image
26
  import pytesseract
 
27
 
28
+ # AI & Search Libraries
29
+ import numpy as np
30
  import faiss
31
+ from sentence_transformers import SentenceTransformer
32
+ from transformers import pipeline
33
+
34
+ # --- 1. INITIAL SETUP & MODEL LOADING ---
35
 
36
+ # Configure logging to see outputs in Hugging Face Space logs
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger(__name__)
39
+
40
+ # Initialize FastAPI app
41
+ app = FastAPI(
42
+ title="Universal Data AI",
43
+ description="Ephemeral data analysis tool with in-memory vector search.",
44
+ version="1.0.0",
45
+ )
46
+
47
+ # Add CORS middleware to allow frontend requests
48
  app.add_middleware(
49
+ CORSMiddleware,
50
+ allow_origins=["*"], # Allow all for simplicity, can be restricted to your frontend URL
51
+ allow_credentials=True,
52
+ allow_methods=["*"],
53
+ allow_headers=["*"],
54
  )
55
 
56
+ # Load AI models on startup
57
+ # This can take a moment when the app first boots.
 
58
  try:
59
+ logger.info("Loading AI models...")
60
+ # Model for creating vector embeddings
61
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
62
+ # Pipeline for question-answering
63
+ qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
64
+ logger.info("AI models loaded successfully.")
65
  except Exception as e:
66
+ logger.critical(f"Fatal error: Could not load AI models. {e}")
67
+ embedding_model = None
68
+ qa_pipeline = None
69
 
70
+ # In-memory dictionary to act as our temporary session database
71
  SESSION_DATA = {}
72
 
73
+ # --- 2. DATA MODELS ---
74
+
75
+ class QueryRequest(BaseModel):
76
+ """Defines the request body for the /query endpoint."""
77
+ question: str
78
+
79
+ class UploadResponse(BaseModel):
80
+ """Defines the response for a successful file upload."""
81
+ session_id: str
82
+ filename: str
83
+ chunks_created: int
84
+
85
+ class QueryResponse(BaseModel):
86
+ """Defines the response for a successful query."""
87
+ answer: str
88
+ score: float
89
+ context: str
90
+
91
+ # --- 3. HELPER FUNCTIONS ---
92
+
93
+ def parse_pdf(content: bytes) -> str:
94
+ """Extracts text from PDF bytes."""
95
+ doc = fitz.open(stream=content, filetype="pdf")
96
+ text = "".join(page.get_text() for page in doc)
97
+ return text
98
+
99
+ def parse_image(content: bytes) -> str:
100
+ """Extracts text from image bytes using OCR."""
101
+ image = Image.open(io.BytesIO(content))
102
+ return pytesseract.image_to_string(image)
103
 
 
104
  def chunk_text(text: str, chunk_size: int = 256, overlap: int = 32) -> list[str]:
105
  """Splits text into overlapping chunks of words."""
106
  words = text.split()
107
+ if not words: return []
108
+ return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size - overlap)]
109
+
110
+ def deserialize_index(serialized_index: bytes) -> faiss.Index:
111
+ """Loads a FAISS index from its byte representation."""
112
+ return faiss.read_index(faiss.VectorReader(serialized_index))
113
+
114
+ # --- 4. API ENDPOINTS ---
115
+
116
+ @app.get("/")
117
+ def read_root():
118
+ """Root endpoint for health checks."""
119
+ return {"status": "ok", "message": "Welcome to Universal Data AI"}
120
+
121
+ @app.post("/upload", response_model=UploadResponse)
122
  async def upload_file(file: UploadFile = File(...)):
123
+ """Handles file upload, parsing, and AI indexing."""
124
+ if not embedding_model:
125
+ raise HTTPException(status_code=503, detail="AI models are not available.")
126
+
127
  session_id = str(uuid.uuid4())
128
+ logger.info(f"Upload received for session {session_id}: {file.filename}")
129
  content = await file.read()
 
 
 
 
 
 
 
 
 
 
130
 
131
+ # Step 1: Parse content based on file type
132
+ content_type = file.content_type
133
+ if content_type == "application/pdf": text = parse_pdf(content)
134
+ elif content_type and content_type.startswith("image/"): text = parse_image(content)
135
+ elif content_type == "text/plain": text = content.decode("utf-8")
136
+ else: raise HTTPException(status_code=400, detail=f"Unsupported file type: {content_type}")
137
 
138
+ if not text.strip():
139
+ raise HTTPException(status_code=400, detail="No text could be extracted from the file.")
 
 
140
 
141
+ # Step 2: Chunk the text
142
+ text_chunks = chunk_text(text)
143
+ if not text_chunks:
144
+ raise HTTPException(status_code=400, detail="Document too short to be processed.")
145
+
146
+ # Step 3: Generate embeddings
147
+ embeddings = embedding_model.encode(text_chunks, convert_to_numpy=True).astype('float32')
148
+
149
+ # Step 4: Create and store FAISS index
150
+ index = faiss.IndexFlatL2(embeddings.shape[1])
151
+ index.add(embeddings)
152
 
 
153
  SESSION_DATA[session_id] = {
 
154
  "chunks": text_chunks,
155
+ "index": index.serialize(), # Store the index as bytes
156
  }
157
 
158
+ logger.info(f"Session {session_id} created with {len(text_chunks)} chunks.")
159
+ return {"session_id": session_id, "filename": file.filename, "chunks_created": len(text_chunks)}
160
+
161
+ @app.post("/query/{session_id}", response_model=QueryResponse)
162
+ async def query_session(session_id: str, request: QueryRequest):
163
+ """Answers a question based on the indexed content of a session."""
164
+ if not qa_pipeline or not embedding_model:
165
+ raise HTTPException(status_code=503, detail="AI models are not available.")
166
+
167
+ # Step 1: Retrieve session data
168
+ session = SESSION_DATA.get(session_id)
169
+ if not session:
170
+ raise HTTPException(status_code=404, detail="Session not found.")
171
+
172
+ # Step 2: Find relevant context using vector search
173
+ question_embedding = embedding_model.encode([request.question]).astype('float32')
174
+ index = deserialize_index(session["index"])
175
+
176
+ # Search for the top 3 most relevant chunks
177
+ k = min(3, index.ntotal)
178
+ distances, indices = index.search(question_embedding, k)
179
+
180
+ relevant_chunks = [session["chunks"][i] for i in indices[0]]
181
+ context = " ".join(relevant_chunks)
182
+
183
+ # Step 3: Use the QA model to find the answer within the context
184
+ result = qa_pipeline(question=request.question, context=context)
185
+
186
+ logger.info(f"Query for session {session_id} answered with score: {result['score']:.4f}")
187
  return {
188
+ "answer": result["answer"],
189
+ "score": result["score"],
190
+ "context": context
191
  }