Kalpokoch committed on
Commit
01eba2b
·
1 Parent(s): e7409fa

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -43
app.py CHANGED
@@ -1,19 +1,15 @@
1
  #
2
  # ---------------- Universal Data AI ----------------
3
  #
4
- # Final app.py script
5
- # Combines:
6
- # 1. File Upload & Parsing (PDF, Image, Text)
7
- # 2. Text Chunking
8
- # 3. Vector Embedding & FAISS Indexing
9
- # 4. A Query Endpoint for Question Answering
10
  #
11
  # Last updated: August 8, 2025
12
  #
13
 
14
  import logging
15
  import uuid
16
- import io
17
 
18
  # FastAPI & Pydantic
19
  from fastapi import FastAPI, UploadFile, File, HTTPException
@@ -33,33 +29,26 @@ from transformers import pipeline
33
 
34
  # --- 1. INITIAL SETUP & MODEL LOADING ---
35
 
36
- # Configure logging to see outputs in Hugging Face Space logs
37
  logging.basicConfig(level=logging.INFO)
38
  logger = logging.getLogger(__name__)
39
 
40
- # Initialize FastAPI app
41
  app = FastAPI(
42
  title="Universal Data AI",
43
  description="Ephemeral data analysis tool with in-memory vector search.",
44
- version="1.0.0",
45
  )
46
 
47
- # Add CORS middleware to allow frontend requests
48
  app.add_middleware(
49
  CORSMiddleware,
50
- allow_origins=["*"], # Allow all for simplicity, can be restricted to your frontend URL
51
  allow_credentials=True,
52
  allow_methods=["*"],
53
  allow_headers=["*"],
54
  )
55
 
56
- # Load AI models on startup
57
- # This can take a moment when the app first boots.
58
  try:
59
  logger.info("Loading AI models...")
60
- # Model for creating vector embeddings
61
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
62
- # Pipeline for question-answering
63
  qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
64
  logger.info("AI models loaded successfully.")
65
  except Exception as e:
@@ -67,23 +56,19 @@ except Exception as e:
67
  embedding_model = None
68
  qa_pipeline = None
69
 
70
- # In-memory dictionary to act as our temporary session database
71
  SESSION_DATA = {}
72
 
73
  # --- 2. DATA MODELS ---
74
 
75
  class QueryRequest(BaseModel):
76
- """Defines the request body for the /query endpoint."""
77
  question: str
78
 
79
  class UploadResponse(BaseModel):
80
- """Defines the response for a successful file upload."""
81
  session_id: str
82
  filename: str
83
  chunks_created: int
84
 
85
  class QueryResponse(BaseModel):
86
- """Defines the response for a successful query."""
87
  answer: str
88
  score: float
89
  context: str
@@ -91,36 +76,40 @@ class QueryResponse(BaseModel):
91
  # --- 3. HELPER FUNCTIONS ---
92
 
93
  def parse_pdf(content: bytes) -> str:
94
- """Extracts text from PDF bytes."""
95
  doc = fitz.open(stream=content, filetype="pdf")
96
- text = "".join(page.get_text() for page in doc)
97
- return text
98
 
99
  def parse_image(content: bytes) -> str:
100
- """Extracts text from image bytes using OCR."""
101
  image = Image.open(io.BytesIO(content))
102
  return pytesseract.image_to_string(image)
103
 
104
  def chunk_text(text: str, chunk_size: int = 256, overlap: int = 32) -> list[str]:
105
- """Splits text into overlapping chunks of words."""
106
  words = text.split()
107
  if not words: return []
108
  return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size - overlap)]
109
 
 
110
  def deserialize_index(serialized_index: bytes) -> faiss.Index:
111
- """Loads a FAISS index from its byte representation."""
112
- return faiss.read_index(faiss.VectorReader(serialized_index))
 
 
 
 
 
 
 
 
 
113
 
114
  # --- 4. API ENDPOINTS ---
115
 
116
  @app.get("/")
117
  def read_root():
118
- """Root endpoint for health checks."""
119
  return {"status": "ok", "message": "Welcome to Universal Data AI"}
120
 
121
  @app.post("/upload", response_model=UploadResponse)
122
  async def upload_file(file: UploadFile = File(...)):
123
- """Handles file upload, parsing, and AI indexing."""
124
  if not embedding_model:
125
  raise HTTPException(status_code=503, detail="AI models are not available.")
126
 
@@ -128,7 +117,6 @@ async def upload_file(file: UploadFile = File(...)):
128
  logger.info(f"Upload received for session {session_id}: {file.filename}")
129
  content = await file.read()
130
 
131
- # Step 1: Parse content based on file type
132
  content_type = file.content_type
133
  if content_type == "application/pdf": text = parse_pdf(content)
134
  elif content_type and content_type.startswith("image/"): text = parse_image(content)
@@ -138,22 +126,28 @@ async def upload_file(file: UploadFile = File(...)):
138
  if not text.strip():
139
  raise HTTPException(status_code=400, detail="No text could be extracted from the file.")
140
 
141
- # Step 2: Chunk the text
142
  text_chunks = chunk_text(text)
143
  if not text_chunks:
144
  raise HTTPException(status_code=400, detail="Document too short to be processed.")
145
 
146
- # Step 3: Generate embeddings
147
  embeddings = embedding_model.encode(text_chunks, convert_to_numpy=True).astype('float32')
148
-
149
- # Step 4: Create and store FAISS index
150
  index = faiss.IndexFlatL2(embeddings.shape[1])
151
  index.add(embeddings)
152
 
 
 
 
 
 
 
 
 
 
 
 
153
  SESSION_DATA[session_id] = {
154
- "filename": file.filename,
155
  "chunks": text_chunks,
156
- "index": faiss.write_index_buf(index), # Store the index as bytes
157
  }
158
 
159
  logger.info(f"Session {session_id} created with {len(text_chunks)} chunks.")
@@ -161,32 +155,27 @@ async def upload_file(file: UploadFile = File(...)):
161
 
162
  @app.post("/query/{session_id}", response_model=QueryResponse)
163
  async def query_session(session_id: str, request: QueryRequest):
164
- """Answers a question based on the indexed content of a session."""
165
  if not qa_pipeline or not embedding_model:
166
  raise HTTPException(status_code=503, detail="AI models are not available.")
167
 
168
- # Step 1: Retrieve session data
169
  session = SESSION_DATA.get(session_id)
170
  if not session:
171
  raise HTTPException(status_code=404, detail="Session not found.")
172
 
173
- # Step 2: Find relevant context using vector search
174
- question_embedding = embedding_model.encode([request.question]).astype('float32')
175
  index = deserialize_index(session["index"])
 
176
 
177
- # Search for the top 3 most relevant chunks
178
  k = min(3, index.ntotal)
179
  distances, indices = index.search(question_embedding, k)
180
 
181
  relevant_chunks = [session["chunks"][i] for i in indices[0]]
182
  context = " ".join(relevant_chunks)
183
 
184
- # Step 3: Use the QA model to find the answer within the context
185
  result = qa_pipeline(question=request.question, context=context)
186
 
187
  logger.info(f"Query for session {session_id} answered with score: {result['score']:.4f}")
188
  return {
189
  "answer": result["answer"],
190
  "score": result["score"],
191
- "context": context
192
  }
 
1
  #
2
  # ---------------- Universal Data AI ----------------
3
  #
4
+ # Final app.py script (v3) with robust FAISS I/O
5
+ # Corrects previous serialization errors.
 
 
 
 
6
  #
7
  # Last updated: August 8, 2025
8
  #
9
 
10
  import logging
11
  import uuid
12
+ import io # Ensure io is imported
13
 
14
  # FastAPI & Pydantic
15
  from fastapi import FastAPI, UploadFile, File, HTTPException
 
29
 
30
  # --- 1. INITIAL SETUP & MODEL LOADING ---
31
 
 
32
  logging.basicConfig(level=logging.INFO)
33
  logger = logging.getLogger(__name__)
34
 
 
35
  app = FastAPI(
36
  title="Universal Data AI",
37
  description="Ephemeral data analysis tool with in-memory vector search.",
38
+ version="1.0.1", # Version bump
39
  )
40
 
 
41
  app.add_middleware(
42
  CORSMiddleware,
43
+ allow_origins=["*"],
44
  allow_credentials=True,
45
  allow_methods=["*"],
46
  allow_headers=["*"],
47
  )
48
 
 
 
49
  try:
50
  logger.info("Loading AI models...")
 
51
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
52
  qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
53
  logger.info("AI models loaded successfully.")
54
  except Exception as e:
 
56
  embedding_model = None
57
  qa_pipeline = None
58
 
 
59
  SESSION_DATA = {}
60
 
61
  # --- 2. DATA MODELS ---
62
 
63
  class QueryRequest(BaseModel):
 
64
  question: str
65
 
66
  class UploadResponse(BaseModel):
 
67
  session_id: str
68
  filename: str
69
  chunks_created: int
70
 
71
  class QueryResponse(BaseModel):
 
72
  answer: str
73
  score: float
74
  context: str
 
76
  # --- 3. HELPER FUNCTIONS ---
77
 
78
  def parse_pdf(content: bytes) -> str:
 
79
  doc = fitz.open(stream=content, filetype="pdf")
80
+ return "".join(page.get_text() for page in doc)
 
81
 
82
  def parse_image(content: bytes) -> str:
 
83
  image = Image.open(io.BytesIO(content))
84
  return pytesseract.image_to_string(image)
85
 
86
  def chunk_text(text: str, chunk_size: int = 256, overlap: int = 32) -> list[str]:
 
87
  words = text.split()
88
  if not words: return []
89
  return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size - overlap)]
90
 
91
+ # --- THIS FUNCTION IS CORRECTED ---
92
  def deserialize_index(serialized_index: bytes) -> faiss.Index:
93
+ """
94
+ Loads a FAISS index from its byte representation using a robust method.
95
+ """
96
+ try:
97
+ bio = io.BytesIO(serialized_index)
98
+ # Use PyCallbackIOReader to read from the in-memory binary stream
99
+ reader = faiss.PyCallbackIOReader(bio.read)
100
+ return faiss.read_index(reader)
101
+ except Exception as e:
102
+ logger.error(f"Failed to deserialize FAISS index: {e}")
103
+ raise
104
 
105
  # --- 4. API ENDPOINTS ---
106
 
107
  @app.get("/")
108
  def read_root():
 
109
  return {"status": "ok", "message": "Welcome to Universal Data AI"}
110
 
111
  @app.post("/upload", response_model=UploadResponse)
112
  async def upload_file(file: UploadFile = File(...)):
 
113
  if not embedding_model:
114
  raise HTTPException(status_code=503, detail="AI models are not available.")
115
 
 
117
  logger.info(f"Upload received for session {session_id}: {file.filename}")
118
  content = await file.read()
119
 
 
120
  content_type = file.content_type
121
  if content_type == "application/pdf": text = parse_pdf(content)
122
  elif content_type and content_type.startswith("image/"): text = parse_image(content)
 
126
  if not text.strip():
127
  raise HTTPException(status_code=400, detail="No text could be extracted from the file.")
128
 
 
129
  text_chunks = chunk_text(text)
130
  if not text_chunks:
131
  raise HTTPException(status_code=400, detail="Document too short to be processed.")
132
 
 
133
  embeddings = embedding_model.encode(text_chunks, convert_to_numpy=True).astype('float32')
 
 
134
  index = faiss.IndexFlatL2(embeddings.shape[1])
135
  index.add(embeddings)
136
 
137
+ # --- THIS SECTION IS CORRECTED ---
138
+ try:
139
+ # Use PyCallbackIOWriter to write the index to an in-memory binary stream
140
+ bio = io.BytesIO()
141
+ writer = faiss.PyCallbackIOWriter(bio.write)
142
+ faiss.write_index(index, writer)
143
+ serialized_index = bio.getvalue()
144
+ except Exception as e:
145
+ logger.error(f"Failed to serialize FAISS index: {e}")
146
+ raise HTTPException(status_code=500, detail="Failed to create document index.")
147
+
148
  SESSION_DATA[session_id] = {
 
149
  "chunks": text_chunks,
150
+ "index": serialized_index, # Store the index as bytes
151
  }
152
 
153
  logger.info(f"Session {session_id} created with {len(text_chunks)} chunks.")
 
155
 
156
  @app.post("/query/{session_id}", response_model=QueryResponse)
157
  async def query_session(session_id: str, request: QueryRequest):
 
158
  if not qa_pipeline or not embedding_model:
159
  raise HTTPException(status_code=503, detail="AI models are not available.")
160
 
 
161
  session = SESSION_DATA.get(session_id)
162
  if not session:
163
  raise HTTPException(status_code=404, detail="Session not found.")
164
 
 
 
165
  index = deserialize_index(session["index"])
166
+ question_embedding = embedding_model.encode([request.question]).astype('float32')
167
 
 
168
  k = min(3, index.ntotal)
169
  distances, indices = index.search(question_embedding, k)
170
 
171
  relevant_chunks = [session["chunks"][i] for i in indices[0]]
172
  context = " ".join(relevant_chunks)
173
 
 
174
  result = qa_pipeline(question=request.question, context=context)
175
 
176
  logger.info(f"Query for session {session_id} answered with score: {result['score']:.4f}")
177
  return {
178
  "answer": result["answer"],
179
  "score": result["score"],
180
+ "context": context,
181
  }