NitinBot001 committed on
Commit
72a6cfc
·
verified ·
1 Parent(s): 95cb046

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -23
app.py CHANGED
@@ -6,6 +6,7 @@ from datetime import datetime
6
  import json
7
  import time
8
  from pathlib import Path
 
9
 
10
  from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks
11
  from fastapi.middleware.cors import CORSMiddleware
@@ -55,15 +56,17 @@ is_initialized = False
55
  # Configuration
56
  class Config:
57
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
58
- CHUNK_SIZE = 800
59
- CHUNK_OVERLAP = 100
60
- MAX_RETRIES = 3
61
- RATE_LIMIT_DELAY = 1.0
 
 
62
  MODEL_NAME = "gemma-3-27b-it"
63
  EMBEDDING_MODEL = "models/embedding-001"
64
  TEMPERATURE = 0.5
65
- MAX_OUTPUT_TOKENS = 10000
66
- RETRIEVER_K = 15
67
  INDEX_PATH = "faiss_maize_index"
68
  DATA_PATH = "data/maize_data.txt"
69
 
@@ -71,7 +74,7 @@ config = Config()
71
 
72
  # Request/Response Models
73
  class QueryRequest(BaseModel):
74
- query: str = Field(..., min_length=1, max_length=10000)
75
 
76
  class QueryResponse(BaseModel):
77
  answer: str
@@ -103,6 +106,64 @@ def estimate_tokens(text: str) -> int:
103
  """Estimates token count for a given text."""
104
  return len(tokenizer.encode(text))
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # Custom Callback Handler
107
  class TokenUsageCallbackHandler(BaseCallbackHandler):
108
  """Callback handler to track token usage in LLM calls."""
@@ -181,26 +242,37 @@ async def initialize_rag_system(api_key: str = None):
181
  chunks = text_splitter.split_documents(documents)
182
  logger.info(f"Document split into {len(chunks)} chunks")
183
 
184
- # Initialize embeddings
 
 
 
 
185
  embeddings = GoogleGenerativeAIEmbeddings(
186
  model=config.EMBEDDING_MODEL,
187
  google_api_key=config.GOOGLE_API_KEY
188
  )
189
 
190
- # Create or load FAISS index
191
  if os.path.exists(config.INDEX_PATH):
192
- vector_store = FAISS.load_local(
193
- config.INDEX_PATH,
194
- embeddings,
195
- allow_dangerous_deserialization=True
196
- )
197
- logger.info(f"Loaded existing FAISS index from '{config.INDEX_PATH}'")
 
 
 
 
 
 
 
198
  else:
199
- vector_store = FAISS.from_documents(chunks, embeddings)
200
  vector_store.save_local(config.INDEX_PATH)
201
  logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'")
202
 
203
- # Initialize LLM
204
  llm = ChatGoogleGenerativeAI(
205
  model=config.MODEL_NAME,
206
  google_api_key=config.GOOGLE_API_KEY,
@@ -217,7 +289,6 @@ You are an expert in maize agriculture. Use the following context ONLY to answer
217
  If there have any query about getting personal information of a person then don't get it and reply full answer accordingly context.
218
  Answer should be concise clear and with easy language.
219
 
220
-
221
  Context:
222
  {context}
223
 
@@ -258,8 +329,11 @@ async def startup_event():
258
  @app.get("/", response_class=HTMLResponse)
259
  async def root():
260
  """Serve the main HTML page."""
261
- with open("static/index.html", "r") as f:
262
- return f.read()
 
 
 
263
 
264
  @app.get("/api/status", response_model=SystemStatus)
265
  async def get_status():
@@ -302,7 +376,7 @@ async def process_query(request: QueryRequest):
302
  if token_callback_handler:
303
  token_callback_handler.last_call_tokens = {}
304
 
305
- # Process query with retry logic
306
  for attempt in range(config.MAX_RETRIES):
307
  try:
308
  result = qa_chain({"query": request.query})
@@ -310,7 +384,11 @@ async def process_query(request: QueryRequest):
310
  except Exception as e:
311
  if attempt == config.MAX_RETRIES - 1:
312
  raise
313
- await asyncio.sleep(config.RATE_LIMIT_DELAY * (attempt + 1))
 
 
 
 
314
 
315
  processing_time = time.time() - start_time
316
 
@@ -349,17 +427,24 @@ async def get_token_stats():
349
  async def upload_document(file: UploadFile = File(...)):
350
  """Upload a new document to replace the existing one."""
351
  try:
 
 
 
 
352
  # Save uploaded file
353
  content = await file.read()
354
  with open(config.DATA_PATH, "wb") as f:
355
  f.write(content)
356
 
 
 
357
  # Reinitialize the system with new data
358
  if config.GOOGLE_API_KEY:
359
  # Remove old index to force recreation
360
  if os.path.exists(config.INDEX_PATH):
361
  import shutil
362
  shutil.rmtree(config.INDEX_PATH)
 
363
 
364
  await initialize_rag_system()
365
  return {"success": True, "message": "Document uploaded and system reinitialized"}
@@ -367,10 +452,22 @@ async def upload_document(file: UploadFile = File(...)):
367
  return {"success": True, "message": "Document uploaded. Please initialize the system."}
368
 
369
  except Exception as e:
 
370
  raise HTTPException(status_code=500, detail=str(e))
371
 
 
 
 
 
 
 
 
 
 
 
372
  # Mount static files
373
- app.mount("/static", StaticFiles(directory="static"), name="static")
 
374
 
375
  if __name__ == "__main__":
376
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
6
  import json
7
  import time
8
  from pathlib import Path
9
+ import random
10
 
11
  from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks
12
  from fastapi.middleware.cors import CORSMiddleware
 
56
  # Configuration
57
  class Config:
58
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
59
+ CHUNK_SIZE = 500 # Reduced chunk size to create fewer embeddings
60
+ CHUNK_OVERLAP = 50 # Reduced overlap
61
+ MAX_RETRIES = 5 # Increased retries
62
+ RATE_LIMIT_DELAY = 2.0 # Increased delay
63
+ EMBEDDING_BATCH_SIZE = 5 # Process embeddings in small batches
64
+ EMBEDDING_DELAY = 1.5 # Delay between embedding batches
65
  MODEL_NAME = "gemma-3-27b-it"
66
  EMBEDDING_MODEL = "models/embedding-001"
67
  TEMPERATURE = 0.5
68
+ MAX_OUTPUT_TOKENS = 20000
69
+ RETRIEVER_K = 15 # Reduced retrieval count
70
  INDEX_PATH = "faiss_maize_index"
71
  DATA_PATH = "data/maize_data.txt"
72
 
 
74
 
75
  # Request/Response Models
76
  class QueryRequest(BaseModel):
77
+ query: str = Field(..., min_length=1, max_length=100000)
78
 
79
  class QueryResponse(BaseModel):
80
  answer: str
 
106
  """Estimates token count for a given text."""
107
  return len(tokenizer.encode(text))
108
 
109
+ # Rate limiting helper functions
110
async def rate_limited_embedding_creation(chunks, embeddings):
    """Build a FAISS vector store from ``chunks`` in small, throttled batches.

    The chunks are embedded ``config.EMBEDDING_BATCH_SIZE`` at a time. Each
    batch is retried with exponential back-off plus jitter when it fails, and
    a short pause is inserted between batches. The per-batch stores are merged
    into the first one, which is returned.

    Args:
        chunks: Sequence of documents to embed (must support ``len`` and
            slicing).
        embeddings: Embedding model object handed to ``FAISS.from_documents``.

    Returns:
        The merged FAISS vector store covering every chunk.

    Raises:
        ValueError: If ``chunks`` is empty, since no store can be built.
        Exception: If a batch still fails after the final retry; the
            underlying error is attached as ``__cause__``.
    """
    # Fail fast with a clear message instead of an IndexError when merging.
    if not chunks:
        raise ValueError("No document chunks to embed")

    logger.info(f"Creating embeddings for {len(chunks)} chunks with rate limiting...")

    batch_size = config.EMBEDDING_BATCH_SIZE
    # Kept separate from config.MAX_RETRIES, which drives the query retry loop.
    max_retries = 5
    total_batches = (len(chunks) + batch_size - 1) // batch_size
    batch_stores = []

    for batch_number, start in enumerate(range(0, len(chunks), batch_size), start=1):
        batch = chunks[start:start + batch_size]
        logger.info(f"Processing batch {batch_number}/{total_batches} ({len(batch)} chunks)")

        for attempt in range(1, max_retries + 1):
            try:
                # NOTE(review): this is a synchronous call inside a coroutine,
                # so the event loop is blocked while the batch is embedded.
                batch_store = FAISS.from_documents(batch, embeddings)
            except Exception as exc:
                logger.warning(f"Batch {batch_number} failed (attempt {attempt}): {str(exc)}")
                # Give up immediately after the last attempt rather than
                # sleeping through one more back-off period first.
                if attempt == max_retries:
                    raise Exception(
                        f"Failed to process batch after {max_retries} attempts: {str(exc)}"
                    ) from exc
                delay = config.EMBEDDING_DELAY * (2 ** attempt) + random.uniform(0, 1)
                logger.info(f"Waiting {delay:.2f} seconds before retry...")
                await asyncio.sleep(delay)
            else:
                batch_stores.append(batch_store)
                logger.info(f"Successfully processed batch {batch_number}")
                break

        # Pause between batches (but not after the last one) to respect rate limits.
        if batch_number < total_batches:
            delay = config.EMBEDDING_DELAY + random.uniform(0.5, 1.0)
            logger.info(f"Waiting {delay:.2f} seconds before next batch...")
            await asyncio.sleep(delay)

    # Fold every later batch store into the first one.
    logger.info("Merging all vector store batches...")
    final_vector_store = batch_stores[0]
    for position, store in enumerate(batch_stores[1:], start=2):
        final_vector_store.merge_from(store)
        logger.info(f"Merged batch {position}/{len(batch_stores)}")

    logger.info("Successfully created and merged all embeddings")
    return final_vector_store
166
+
167
  # Custom Callback Handler
168
  class TokenUsageCallbackHandler(BaseCallbackHandler):
169
  """Callback handler to track token usage in LLM calls."""
 
242
  chunks = text_splitter.split_documents(documents)
243
  logger.info(f"Document split into {len(chunks)} chunks")
244
 
245
+ # Check if we have too many chunks that might cause rate limiting
246
+ if len(chunks) > 100:
247
+ logger.warning(f"Large number of chunks ({len(chunks)}). Consider increasing chunk_size or reducing document size to avoid rate limits.")
248
+
249
+ # Initialize embeddings with retry logic
250
  embeddings = GoogleGenerativeAIEmbeddings(
251
  model=config.EMBEDDING_MODEL,
252
  google_api_key=config.GOOGLE_API_KEY
253
  )
254
 
255
+ # Create or load FAISS index with rate limiting
256
  if os.path.exists(config.INDEX_PATH):
257
+ try:
258
+ vector_store = FAISS.load_local(
259
+ config.INDEX_PATH,
260
+ embeddings,
261
+ allow_dangerous_deserialization=True
262
+ )
263
+ logger.info(f"Loaded existing FAISS index from '{config.INDEX_PATH}'")
264
+ except Exception as e:
265
+ logger.warning(f"Failed to load existing index: {str(e)}")
266
+ logger.info("Creating new index...")
267
+ vector_store = await rate_limited_embedding_creation(chunks, embeddings)
268
+ vector_store.save_local(config.INDEX_PATH)
269
+ logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'")
270
  else:
271
+ vector_store = await rate_limited_embedding_creation(chunks, embeddings)
272
  vector_store.save_local(config.INDEX_PATH)
273
  logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'")
274
 
275
+ # Initialize LLM with retry and rate limiting
276
  llm = ChatGoogleGenerativeAI(
277
  model=config.MODEL_NAME,
278
  google_api_key=config.GOOGLE_API_KEY,
 
289
  If there have any query about getting personal information of a person then don't get it and reply full answer accordingly context.
290
  Answer should be concise clear and with easy language.
291
 
 
292
  Context:
293
  {context}
294
 
 
329
  @app.get("/", response_class=HTMLResponse)
330
  async def root():
331
  """Serve the main HTML page."""
332
+ try:
333
+ with open("static/index.html", "r") as f:
334
+ return f.read()
335
+ except FileNotFoundError:
336
+ return "<h1>Static files not found. Please ensure static/index.html exists.</h1>"
337
 
338
  @app.get("/api/status", response_model=SystemStatus)
339
  async def get_status():
 
376
  if token_callback_handler:
377
  token_callback_handler.last_call_tokens = {}
378
 
379
+ # Process query with retry logic and exponential backoff
380
  for attempt in range(config.MAX_RETRIES):
381
  try:
382
  result = qa_chain({"query": request.query})
 
384
  except Exception as e:
385
  if attempt == config.MAX_RETRIES - 1:
386
  raise
387
+
388
+ delay = config.RATE_LIMIT_DELAY * (2 ** attempt) + random.uniform(0, 1)
389
+ logger.warning(f"Query attempt {attempt + 1} failed: {str(e)}")
390
+ logger.info(f"Retrying in {delay:.2f} seconds...")
391
+ await asyncio.sleep(delay)
392
 
393
  processing_time = time.time() - start_time
394
 
 
427
  async def upload_document(file: UploadFile = File(...)):
428
  """Upload a new document to replace the existing one."""
429
  try:
430
+ # Validate file
431
+ if not file.filename.endswith('.txt'):
432
+ raise HTTPException(status_code=400, detail="Only .txt files are supported")
433
+
434
  # Save uploaded file
435
  content = await file.read()
436
  with open(config.DATA_PATH, "wb") as f:
437
  f.write(content)
438
 
439
+ logger.info(f"Uploaded new document: {file.filename}")
440
+
441
  # Reinitialize the system with new data
442
  if config.GOOGLE_API_KEY:
443
  # Remove old index to force recreation
444
  if os.path.exists(config.INDEX_PATH):
445
  import shutil
446
  shutil.rmtree(config.INDEX_PATH)
447
+ logger.info("Removed old FAISS index")
448
 
449
  await initialize_rag_system()
450
  return {"success": True, "message": "Document uploaded and system reinitialized"}
 
452
  return {"success": True, "message": "Document uploaded. Please initialize the system."}
453
 
454
  except Exception as e:
455
+ logger.error(f"Error uploading document: {str(e)}")
456
  raise HTTPException(status_code=500, detail=str(e))
457
 
458
+ # Health check endpoint
459
@app.get("/health")
async def health_check():
    """Report a fixed "healthy" status, the current time, and init state."""
    current_time = datetime.now()
    payload = dict(
        status="healthy",
        timestamp=current_time.isoformat(),
        system_initialized=is_initialized,
    )
    return payload
467
+
468
  # Mount static files
469
# Mount only when "static" is a real directory. A plain file with that name
# would pass an existence check but cannot be served as a directory.
if os.path.isdir("static"):
    app.mount("/static", StaticFiles(directory="static"), name="static")

if __name__ == "__main__":
    # Run the app directly with uvicorn, listening on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)