NitinBot001 committed on
Commit
edf5647
·
verified ·
1 Parent(s): 8342615

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +264 -491
app.py CHANGED
@@ -1,16 +1,22 @@
1
  import os
 
 
 
 
 
2
  import time
3
- import gradio as gr
 
 
 
 
 
 
4
  import uvicorn
5
- from fastapi import FastAPI, HTTPException, Depends, File, UploadFile
6
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
- from pydantic import BaseModel
8
- from typing import Optional, Dict, Any
9
- import threading
10
- import logging
11
- from langchain._community.document_loaders import TextLoader
12
  from langchain.text_splitter import RecursiveCharacterTextSplitter
13
- from langchain_community.vectorstores import FAISS
14
  from langchain.chains import RetrievalQA
15
  from langchain.prompts import PromptTemplate
16
  from langchain.callbacks.base import BaseCallbackHandler
@@ -18,39 +24,99 @@ from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmb
18
  import tiktoken
19
 
20
  # Configure logging
21
- logging.basicConfig(level=logging.INFO)
 
 
 
22
  logger = logging.getLogger(__name__)
23
 
24
- # --- Configuration ---
25
- CHUNK_SIZE = 800
26
- CHUNK_OVERLAP = 100
27
- MAX_TOKENS = 512
28
- TEMPERATURE = 0.5
29
- RETRIEVAL_K = 5
 
 
 
 
 
 
 
 
 
30
 
31
- # --- Token Counting Setup ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  try:
33
  tokenizer = tiktoken.get_encoding("cl100k_base")
34
  except:
35
- print("Tiktoken encoder 'cl100k_base' not found. Using basic split().")
36
  tokenizer = type('obj', (object,), {'encode': lambda x: x.split()})()
37
 
38
- def estimate_tokens(text):
39
  """Estimates token count for a given text."""
40
  return len(tokenizer.encode(text))
41
 
42
- # Custom Callback Handler to track LLM token usage
43
  class TokenUsageCallbackHandler(BaseCallbackHandler):
44
  """Callback handler to track token usage in LLM calls."""
 
45
  def __init__(self):
46
  super().__init__()
47
- self.reset_counters()
48
-
49
- def reset_counters(self):
50
  self.total_prompt_tokens = 0
51
  self.total_completion_tokens = 0
52
  self.total_llm_calls = 0
53
-
 
54
  def on_llm_end(self, response, **kwargs):
55
  """Collect token usage from the LLM response."""
56
  self.total_llm_calls += 1
@@ -63,120 +129,92 @@ class TokenUsageCallbackHandler(BaseCallbackHandler):
63
 
64
  self.total_prompt_tokens += prompt_tokens
65
  self.total_completion_tokens += completion_tokens
66
-
67
- def get_total_tokens(self):
68
- """Returns the total prompt and completion tokens."""
 
 
 
 
 
 
 
 
 
 
69
  return {
70
  "total_prompt_tokens": self.total_prompt_tokens,
71
  "total_completion_tokens": self.total_completion_tokens,
72
- "total_llm_tokens": self.total_prompt_tokens + self.total_completion_tokens,
73
- "total_llm_calls": self.total_llm_calls
74
  }
75
 
76
- # --- Pydantic Models for API ---
77
- class InitializeRequest(BaseModel):
78
- api_key: str
79
- document_content: Optional[str] = None
80
-
81
- class QueryRequest(BaseModel):
82
- query: str
83
- api_key: str
84
-
85
- class InitializeResponse(BaseModel):
86
- success: bool
87
- message: str
88
- chunks: Optional[int] = None
89
- estimated_tokens: Optional[int] = None
90
-
91
- class QueryResponse(BaseModel):
92
- success: bool
93
- answer: str
94
- response_time: float
95
- query_tokens: int
96
- llm_tokens: Dict[str, int]
97
- session_stats: Dict[str, int]
98
-
99
- class StatsResponse(BaseModel):
100
- total_queries: int
101
- total_embedding_tokens: int
102
- total_llm_tokens: int
103
- total_llm_calls: int
104
- initialization_complete: bool
105
-
106
- # --- Global Variables ---
107
- class RAGSystem:
108
- def __init__(self):
109
- self.vector_store = None
110
- self.qa_chain = None
111
- self.token_callback_handler = TokenUsageCallbackHandler()
112
- self.session_stats = {
113
- "total_queries": 0,
114
- "total_embedding_tokens": 0,
115
- "initialization_complete": False
116
- }
117
- self.current_api_key = None
118
-
119
- # Global RAG system instance
120
- rag_system = RAGSystem()
121
-
122
- def initialize_rag_system(api_key, file_content=None):
123
- """Initialize the RAG system with API key and optional file content."""
124
- global rag_system
125
 
126
  try:
127
- # Set API key
128
- os.environ["GOOGLE_API_KEY"] = api_key
129
- rag_system.current_api_key = api_key
130
-
131
- # Initialize embeddings
132
- embeddings = GoogleGenerativeAIEmbeddings(
133
- model="models/embedding-001",
134
- google_api_key=api_key
135
- )
136
 
137
- # Initialize LLM
138
- llm = ChatGoogleGenerativeAI(
139
- model="gemini-1.5-flash",
140
- google_api_key=api_key,
141
- temperature=TEMPERATURE,
142
- max_tokens=MAX_TOKENS,
143
- callbacks=[rag_system.token_callback_handler],
144
- verbose=False
145
- )
146
 
147
- # Load or use default document
148
- if file_content:
149
- # Save uploaded file content
150
- with open("uploaded_document.txt", "w", encoding="utf-8") as f:
151
- f.write(file_content)
152
- loader = TextLoader("uploaded_document.txt")
153
- else:
154
- # Check if default maize_data.txt exists
155
- if os.path.exists("maize_data.txt"):
156
- loader = TextLoader("maize_data.txt")
157
- else:
158
- return "❌ No document found. Please upload a file or ensure maize_data.txt exists."
159
 
160
  # Load and split documents
 
 
 
 
161
  documents = loader.load()
 
162
  text_splitter = RecursiveCharacterTextSplitter(
163
- chunk_size=CHUNK_SIZE,
164
- chunk_overlap=CHUNK_OVERLAP
165
  )
166
  chunks = text_splitter.split_documents(documents)
 
167
 
168
- # Estimate embedding tokens
169
- initial_embedding_tokens = sum(estimate_tokens(chunk.page_content) for chunk in chunks)
170
- rag_system.session_stats["total_embedding_tokens"] = initial_embedding_tokens
 
 
171
 
172
- # Create vector store
173
- rag_system.vector_store = FAISS.from_documents(chunks, embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  # Create prompt template
176
  prompt_template = PromptTemplate(
177
  input_variables=["context", "question"],
178
  template="""
179
- You are an expert in maize agriculture. Use the following context ONLY to answer the question accurately and helpfully. If the context doesn't contain the answer, say "Based on the provided context, I cannot answer this question.".
 
180
 
181
  Context:
182
  {context}
@@ -187,415 +225,150 @@ Answer:"""
187
  )
188
 
189
  # Set up QA chain
190
- rag_system.qa_chain = RetrievalQA.from_chain_type(
191
  llm=llm,
192
  chain_type="stuff",
193
- retriever=rag_system.vector_store.as_retriever(search_kwargs={"k": RETRIEVAL_K}),
194
  chain_type_kwargs={"prompt": prompt_template},
195
- callbacks=[rag_system.token_callback_handler],
196
  return_source_documents=True
197
  )
198
 
199
- rag_system.session_stats["initialization_complete"] = True
200
-
201
- return f"✅ RAG system initialized successfully!\n📄 Document processed: {len(chunks)} chunks\n🔢 Estimated embedding tokens: ~{initial_embedding_tokens}"
202
-
203
- except Exception as e:
204
- logger.error(f"Initialization failed: {str(e)}")
205
- return f"❌ Initialization failed: {str(e)}"
206
-
207
- def process_query(query, api_key):
208
- """Process a user query through the RAG system."""
209
- global rag_system
210
-
211
- if not api_key:
212
- return "❌ Please provide a Google API key first.", ""
213
-
214
- if not rag_system.qa_chain:
215
- return "❌ RAG system not initialized. Please initialize first.", ""
216
-
217
- if not query.strip():
218
- return "❌ Please enter a question.", ""
219
-
220
- try:
221
- # Estimate query embedding tokens
222
- query_tokens = estimate_tokens(query)
223
- rag_system.session_stats["total_embedding_tokens"] += query_tokens
224
- rag_system.session_stats["total_queries"] += 1
225
-
226
- # Process query
227
- start_time = time.time()
228
- result = rag_system.qa_chain({"query": query})
229
- end_time = time.time()
230
-
231
- # Get token usage
232
- llm_tokens = rag_system.token_callback_handler.get_total_tokens()
233
-
234
- # Format response
235
- answer = result['result']
236
-
237
- # Create stats summary
238
- stats = f"""
239
- 📊 **Query Statistics:**
240
- - Response time: {end_time - start_time:.2f} seconds
241
- - Query tokens (estimated): ~{query_tokens}
242
- - LLM tokens (this query): Prompt: {llm_tokens['total_prompt_tokens']}, Completion: {llm_tokens['total_completion_tokens']}
243
-
244
- 📈 **Session Statistics:**
245
- - Total queries: {rag_system.session_stats['total_queries']}
246
- - Total embedding tokens: ~{rag_system.session_stats['total_embedding_tokens']}
247
- - Total LLM calls: {llm_tokens['total_llm_calls']}
248
- - Total LLM tokens: {llm_tokens['total_llm_tokens']}
249
- """
250
-
251
- return answer, stats
252
 
253
  except Exception as e:
254
- logger.error(f"Error processing query: {str(e)}")
255
- return f"❌ Error processing query: {str(e)}", ""
256
-
257
- def upload_file_and_initialize(api_key, file):
258
- """Handle file upload and system initialization."""
259
- if not api_key:
260
- return "❌ Please provide a Google API key first."
261
-
262
- if file is None:
263
- return initialize_rag_system(api_key)
264
-
265
- try:
266
- # Read uploaded file
267
- file_content = file.decode('utf-8')
268
- return initialize_rag_system(api_key, file_content)
269
- except Exception as e:
270
- return f"❌ Error reading uploaded file: {str(e)}"
271
-
272
- def reset_session():
273
- """Reset the session statistics."""
274
- global rag_system
275
- rag_system.token_callback_handler.reset_counters()
276
- rag_system.session_stats = {
277
- "total_queries": 0,
278
- "total_embedding_tokens": 0,
279
- "initialization_complete": False
280
- }
281
- return "🔄 Session statistics reset."
282
-
283
- # --- FastAPI Setup ---
284
- app = FastAPI(
285
- title="Maize RAG Q&A System API",
286
- description="API for the Maize Agriculture RAG Q&A System",
287
- version="1.0.0"
288
- )
289
-
290
- # Optional: Add API key authentication for API endpoints
291
- security = HTTPBearer(auto_error=False)
292
-
293
- async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
294
- """Extract API key from Authorization header (optional)"""
295
- if credentials:
296
- return credentials.credentials
297
- return None
298
-
299
- # --- API Endpoints ---
300
-
301
- @app.get("/")
302
  async def root():
303
- """Root endpoint"""
304
- return {"message": "Maize RAG Q&A System API", "status": "running"}
305
-
306
- @app.get("/health")
307
- async def health_check():
308
- """Health check endpoint"""
309
- return {
310
- "status": "healthy",
311
- "system_initialized": rag_system.session_stats["initialization_complete"]
312
- }
 
 
 
 
 
 
313
 
314
- @app.post("/initialize", response_model=InitializeResponse)
315
  async def initialize_system(request: InitializeRequest):
316
- """Initialize the RAG system"""
317
  try:
318
- result = initialize_rag_system(request.api_key, request.document_content)
319
-
320
- if "" in result:
321
- # Parse successful result
322
- lines = result.split('\n')
323
- chunks = None
324
- tokens = None
325
-
326
- for line in lines:
327
- if "chunks" in line:
328
- chunks = int(line.split(': ')[1].split(' ')[0])
329
- elif "tokens" in line:
330
- tokens = int(line.split('~')[1])
331
-
332
- return InitializeResponse(
333
- success=True,
334
- message=result,
335
- chunks=chunks,
336
- estimated_tokens=tokens
337
- )
338
- else:
339
- return InitializeResponse(
340
- success=False,
341
- message=result
342
- )
343
-
344
  except Exception as e:
345
- logger.error(f"API initialization error: {str(e)}")
346
  raise HTTPException(status_code=500, detail=str(e))
347
 
348
- @app.post("/query", response_model=QueryResponse)
349
- async def query_system(request: QueryRequest):
350
- """Query the RAG system"""
 
 
 
 
 
 
351
  try:
352
- if not rag_system.session_stats["initialization_complete"]:
353
- raise HTTPException(status_code=400, detail="System not initialized")
354
 
355
- # Estimate query embedding tokens
356
- query_tokens = estimate_tokens(request.query)
357
- rag_system.session_stats["total_embedding_tokens"] += query_tokens
358
- rag_system.session_stats["total_queries"] += 1
359
 
360
- # Process query
361
- start_time = time.time()
362
- result = rag_system.qa_chain({"query": request.query})
363
- end_time = time.time()
 
 
 
 
 
364
 
365
- # Get token usage
366
- llm_tokens = rag_system.token_callback_handler.get_total_tokens()
367
 
368
- response_time = end_time - start_time
 
 
 
 
 
 
 
 
 
369
 
370
  return QueryResponse(
371
- success=True,
372
  answer=result['result'],
373
- response_time=response_time,
374
- query_tokens=query_tokens,
375
- llm_tokens=llm_tokens,
376
- session_stats=rag_system.session_stats
377
  )
378
-
379
  except Exception as e:
380
- logger.error(f"API query error: {str(e)}")
381
  raise HTTPException(status_code=500, detail=str(e))
382
 
383
- @app.get("/stats", response_model=StatsResponse)
384
- async def get_stats():
385
- """Get current session statistics"""
386
- llm_tokens = rag_system.token_callback_handler.get_total_tokens()
 
387
 
388
- return StatsResponse(
389
- total_queries=rag_system.session_stats["total_queries"],
390
- total_embedding_tokens=rag_system.session_stats["total_embedding_tokens"],
391
- total_llm_tokens=llm_tokens["total_llm_tokens"],
392
- total_llm_calls=llm_tokens["total_llm_calls"],
393
- initialization_complete=rag_system.session_stats["initialization_complete"]
394
- )
395
 
396
- @app.post("/reset")
397
- async def reset_system():
398
- """Reset session statistics"""
399
- reset_session()
400
- return {"message": "Session reset successfully"}
401
-
402
- @app.post("/upload-document")
403
- async def upload_document(
404
- file: UploadFile = File(...),
405
- api_key: str = None
406
- ):
407
- """Upload a document and initialize the system"""
408
  try:
409
- if not api_key:
410
- raise HTTPException(status_code=400, detail="API key required")
411
-
412
- # Read uploaded file
413
  content = await file.read()
414
- file_content = content.decode('utf-8')
415
-
416
- # Initialize system with uploaded content
417
- result = initialize_rag_system(api_key, file_content)
418
-
419
- if "✅" in result:
420
- return {"success": True, "message": result}
 
 
 
 
 
421
  else:
422
- return {"success": False, "message": result}
423
-
424
  except Exception as e:
425
- logger.error(f"Document upload error: {str(e)}")
426
  raise HTTPException(status_code=500, detail=str(e))
427
 
428
- # Create Gradio interface
429
- def create_interface():
430
- with gr.Blocks(title="Maize RAG Q&A System", theme=gr.themes.Soft()) as demo:
431
- gr.Markdown("""
432
- # 🌽 Maize Agriculture RAG Q&A System
433
-
434
- This system uses Retrieval-Augmented Generation (RAG) to answer questions about maize agriculture.
435
- Upload your own document or use the default maize dataset.
436
- """)
437
-
438
- with gr.Row():
439
- with gr.Column(scale=2):
440
- api_key_input = gr.Textbox(
441
- label="🔑 Google API Key",
442
- placeholder="Enter your Google Generative AI API key",
443
- type="password",
444
- info="Get your API key from Google AI Studio"
445
- )
446
-
447
- with gr.Column(scale=1):
448
- reset_btn = gr.Button("🔄 Reset Session", variant="secondary")
449
-
450
- with gr.Row():
451
- with gr.Column():
452
- file_upload = gr.File(
453
- label="📁 Upload Document (Optional)",
454
- file_types=[".txt"],
455
- info="Upload a text file or use the default maize dataset"
456
- )
457
-
458
- init_btn = gr.Button("🚀 Initialize RAG System", variant="primary")
459
- init_output = gr.Textbox(
460
- label="📋 Initialization Status",
461
- lines=3,
462
- interactive=False
463
- )
464
-
465
- gr.Markdown("## 💬 Ask Questions")
466
-
467
- with gr.Row():
468
- with gr.Column(scale=3):
469
- query_input = gr.Textbox(
470
- label="❓ Your Question",
471
- placeholder="Ask something about maize agriculture...",
472
- lines=2
473
- )
474
-
475
- # Sample questions
476
- sample_questions = [
477
- "What are the main pests affecting maize crops?",
478
- "How should maize be irrigated?",
479
- "What is the ideal soil type for maize?",
480
- "What are the nutritional requirements of maize?",
481
- "When is the best time to harvest maize?"
482
- ]
483
-
484
- gr.Examples(
485
- examples=sample_questions,
486
- inputs=query_input,
487
- label="💡 Sample Questions"
488
- )
489
-
490
- with gr.Column(scale=1):
491
- submit_btn = gr.Button("🔍 Ask", variant="primary")
492
-
493
- with gr.Row():
494
- with gr.Column(scale=2):
495
- answer_output = gr.Textbox(
496
- label="🤖 Answer",
497
- lines=6,
498
- interactive=False
499
- )
500
-
501
- with gr.Column(scale=1):
502
- stats_output = gr.Markdown(
503
- label="📊 Statistics",
504
- value="Statistics will appear here after queries."
505
- )
506
-
507
- # Event handlers
508
- init_btn.click(
509
- upload_file_and_initialize,
510
- inputs=[api_key_input, file_upload],
511
- outputs=init_output
512
- )
513
-
514
- submit_btn.click(
515
- process_query,
516
- inputs=[query_input, api_key_input],
517
- outputs=[answer_output, stats_output]
518
- )
519
-
520
- query_input.submit(
521
- process_query,
522
- inputs=[query_input, api_key_input],
523
- outputs=[answer_output, stats_output]
524
- )
525
-
526
- reset_btn.click(
527
- reset_session,
528
- outputs=init_output
529
- )
530
-
531
- gr.Markdown("""
532
- ## 📝 Instructions:
533
- 1. **Enter your Google API Key** (required)
534
- 2. **Upload a document** (optional - uses default maize dataset if not provided)
535
- 3. **Initialize the RAG system** by clicking "Initialize RAG System"
536
- 4. **Ask questions** about the document content
537
- 5. **View statistics** to monitor token usage and costs
538
-
539
- ## 💰 Cost Information:
540
- - **Gemini 1.5 Flash**: Input: $0.075/1M tokens, Output: $0.30/1M tokens
541
- - **Embedding Model**: $0.025/1M tokens
542
-
543
- Token usage is estimated and displayed for cost tracking.
544
- """)
545
-
546
- return demo
547
-
548
- # Create and launch the interface
549
- def run_gradio():
550
- """Run Gradio interface"""
551
- demo = create_interface()
552
- demo.launch(
553
- server_name="0.0.0.0",
554
- server_port=7860,
555
- show_error=True,
556
- quiet=True # Reduce Gradio logs in combined mode
557
- )
558
-
559
- def run_fastapi():
560
- """Run FastAPI server"""
561
- uvicorn.run(
562
- app,
563
- host="0.0.0.0",
564
- port=8000,
565
- log_level="info"
566
- )
567
 
568
  if __name__ == "__main__":
569
- import sys
570
-
571
- if len(sys.argv) > 1:
572
- mode = sys.argv[1]
573
-
574
- if mode == "api":
575
- # Run only FastAPI
576
- print("Starting FastAPI server on port 8000...")
577
- run_fastapi()
578
- elif mode == "gradio":
579
- # Run only Gradio
580
- print("Starting Gradio interface on port 7860...")
581
- run_gradio()
582
- elif mode == "both":
583
- # Run both servers
584
- print("Starting both FastAPI (port 8000) and Gradio (port 7860)...")
585
-
586
- # Start FastAPI in a separate thread
587
- fastapi_thread = threading.Thread(target=run_fastapi)
588
- fastapi_thread.daemon = True
589
- fastapi_thread.start()
590
-
591
- # Start Gradio in main thread
592
- time.sleep(2) # Give FastAPI time to start
593
- run_gradio()
594
- else:
595
- print("Usage: python app.py [api|gradio|both]")
596
- print("Default: gradio only")
597
- run_gradio()
598
- else:
599
- # Default: run only Gradio (for Hugging Face Spaces compatibility)
600
- print("Starting Gradio interface on port 7860...")
601
- run_gradio()
 
1
  import os
2
+ import logging
3
+ import asyncio
4
+ from typing import Optional, Dict, Any, List
5
+ from datetime import datetime
6
+ import json
7
  import time
8
+ from pathlib import Path
9
+
10
+ from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.staticfiles import StaticFiles
13
+ from fastapi.responses import HTMLResponse, JSONResponse
14
+ from pydantic import BaseModel, Field
15
  import uvicorn
16
+
17
+ from langchain.document_loaders import TextLoader
 
 
 
 
 
18
  from langchain.text_splitter import RecursiveCharacterTextSplitter
19
+ from langchain.vectorstores import FAISS
20
  from langchain.chains import RetrievalQA
21
  from langchain.prompts import PromptTemplate
22
  from langchain.callbacks.base import BaseCallbackHandler
 
24
  import tiktoken
25
 
26
  # Configure logging
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
30
+ )
31
  logger = logging.getLogger(__name__)
32
 
33
+ # Initialize FastAPI app
34
+ app = FastAPI(
35
+ title="Maize Crop RAG System",
36
+ description="AI-powered Q&A system for maize agriculture",
37
+ version="1.0.0"
38
+ )
39
+
40
+ # Configure CORS
41
+ app.add_middleware(
42
+ CORSMiddleware,
43
+ allow_origins=["*"],
44
+ allow_credentials=True,
45
+ allow_methods=["*"],
46
+ allow_headers=["*"],
47
+ )
48
 
49
+ # Global variables for the RAG system
50
+ vector_store = None
51
+ qa_chain = None
52
+ token_callback_handler = None
53
+ is_initialized = False
54
+
55
+ # Configuration
56
+ class Config:
57
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
58
+ CHUNK_SIZE = 800
59
+ CHUNK_OVERLAP = 100
60
+ MAX_RETRIES = 3
61
+ RATE_LIMIT_DELAY = 1.0
62
+ MODEL_NAME = "gemini-1.5-flash"
63
+ EMBEDDING_MODEL = "models/embedding-001"
64
+ TEMPERATURE = 0.5
65
+ MAX_OUTPUT_TOKENS = 512
66
+ RETRIEVER_K = 5
67
+ INDEX_PATH = "faiss_maize_index"
68
+ DATA_PATH = "data/maize_data.txt"
69
+
70
+ config = Config()
71
+
72
+ # Request/Response Models
73
+ class QueryRequest(BaseModel):
74
+ query: str = Field(..., min_length=1, max_length=500)
75
+
76
+ class QueryResponse(BaseModel):
77
+ answer: str
78
+ sources: List[str] = []
79
+ token_usage: Dict[str, int] = {}
80
+ processing_time: float
81
+ timestamp: str
82
+
83
+ class SystemStatus(BaseModel):
84
+ status: str
85
+ is_initialized: bool
86
+ model_name: str
87
+ embedding_model: str
88
+ vector_store_ready: bool
89
+ total_chunks: int = 0
90
+ api_key_configured: bool
91
+
92
+ class InitializeRequest(BaseModel):
93
+ api_key: str = Field(..., min_length=1)
94
+
95
# Token counting utilities
class _WhitespaceTokenizer:
    """Minimal stand-in for a tiktoken encoding that splits on whitespace.

    A real class is used on purpose: a one-argument lambda stored as a class
    attribute becomes a bound method, so calling ``encode(text)`` on an
    instance would pass two arguments and raise ``TypeError``.
    """

    def encode(self, text):
        """Return the whitespace-separated pieces of ``text`` as a list."""
        return text.split()


try:
    tokenizer = tiktoken.get_encoding("cl100k_base")
except Exception:
    # Best-effort fallback: if the encoding cannot be obtained for any reason,
    # degrade to a rough whitespace-based count instead of failing at import.
    # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit through.
    logging.getLogger(__name__).warning("Tiktoken encoder not found. Using basic split().")
    tokenizer = _WhitespaceTokenizer()


def estimate_tokens(text: str) -> int:
    """Estimate the token count of ``text``.

    Args:
        text: The string to measure.

    Returns:
        The number of items produced by ``tokenizer.encode(text)``: tiktoken
        tokens when the encoding loaded, otherwise whitespace-separated words.
    """
    return len(tokenizer.encode(text))
105
 
106
+ # Custom Callback Handler
107
  class TokenUsageCallbackHandler(BaseCallbackHandler):
108
  """Callback handler to track token usage in LLM calls."""
109
+
110
  def __init__(self):
111
  super().__init__()
112
+ self.reset()
113
+
114
+ def reset(self):
115
  self.total_prompt_tokens = 0
116
  self.total_completion_tokens = 0
117
  self.total_llm_calls = 0
118
+ self.last_call_tokens = {}
119
+
120
  def on_llm_end(self, response, **kwargs):
121
  """Collect token usage from the LLM response."""
122
  self.total_llm_calls += 1
 
129
 
130
  self.total_prompt_tokens += prompt_tokens
131
  self.total_completion_tokens += completion_tokens
132
+
133
+ self.last_call_tokens = {
134
+ "prompt_tokens": prompt_tokens,
135
+ "completion_tokens": completion_tokens,
136
+ "total_tokens": prompt_tokens + completion_tokens
137
+ }
138
+
139
+ logger.info(f"Token usage - Prompt: {prompt_tokens}, Completion: {completion_tokens}")
140
+
141
+ def get_last_call_usage(self):
142
+ return self.last_call_tokens
143
+
144
+ def get_total_usage(self):
145
  return {
146
  "total_prompt_tokens": self.total_prompt_tokens,
147
  "total_completion_tokens": self.total_completion_tokens,
148
+ "total_tokens": self.total_prompt_tokens + self.total_completion_tokens,
149
+ "total_calls": self.total_llm_calls
150
  }
151
 
152
+ # RAG System Functions
153
+ async def initialize_rag_system(api_key: str = None):
154
+ """Initialize or reinitialize the RAG system."""
155
+ global vector_store, qa_chain, token_callback_handler, is_initialized, config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  try:
158
+ # Use provided API key or environment variable
159
+ if api_key:
160
+ config.GOOGLE_API_KEY = api_key
161
+ os.environ["GOOGLE_API_KEY"] = api_key
162
+ elif not config.GOOGLE_API_KEY:
163
+ raise ValueError("Google API key not provided")
 
 
 
164
 
165
+ logger.info("Initializing RAG system...")
 
 
 
 
 
 
 
 
166
 
167
+ # Initialize token callback handler
168
+ token_callback_handler = TokenUsageCallbackHandler()
 
 
 
 
 
 
 
 
 
 
169
 
170
  # Load and split documents
171
+ if not os.path.exists(config.DATA_PATH):
172
+ raise FileNotFoundError(f"Data file not found: {config.DATA_PATH}")
173
+
174
+ loader = TextLoader(config.DATA_PATH)
175
  documents = loader.load()
176
+
177
  text_splitter = RecursiveCharacterTextSplitter(
178
+ chunk_size=config.CHUNK_SIZE,
179
+ chunk_overlap=config.CHUNK_OVERLAP
180
  )
181
  chunks = text_splitter.split_documents(documents)
182
+ logger.info(f"Document split into {len(chunks)} chunks")
183
 
184
+ # Initialize embeddings
185
+ embeddings = GoogleGenerativeAIEmbeddings(
186
+ model=config.EMBEDDING_MODEL,
187
+ google_api_key=config.GOOGLE_API_KEY
188
+ )
189
 
190
+ # Create or load FAISS index
191
+ if os.path.exists(config.INDEX_PATH):
192
+ vector_store = FAISS.load_local(
193
+ config.INDEX_PATH,
194
+ embeddings,
195
+ allow_dangerous_deserialization=True
196
+ )
197
+ logger.info(f"Loaded existing FAISS index from '{config.INDEX_PATH}'")
198
+ else:
199
+ vector_store = FAISS.from_documents(chunks, embeddings)
200
+ vector_store.save_local(config.INDEX_PATH)
201
+ logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'")
202
+
203
+ # Initialize LLM
204
+ llm = ChatGoogleGenerativeAI(
205
+ model=config.MODEL_NAME,
206
+ google_api_key=config.GOOGLE_API_KEY,
207
+ temperature=config.TEMPERATURE,
208
+ max_tokens=config.MAX_OUTPUT_TOKENS,
209
+ callbacks=[token_callback_handler]
210
+ )
211
 
212
  # Create prompt template
213
  prompt_template = PromptTemplate(
214
  input_variables=["context", "question"],
215
  template="""
216
+ You are an expert in maize agriculture. Use the following context ONLY to answer the question accurately and helpfully.
217
+ If the context doesn't contain the answer, say "Based on the provided context, I cannot answer this question."
218
 
219
  Context:
220
  {context}
 
225
  )
226
 
227
  # Set up QA chain
228
+ qa_chain = RetrievalQA.from_chain_type(
229
  llm=llm,
230
  chain_type="stuff",
231
+ retriever=vector_store.as_retriever(search_kwargs={"k": config.RETRIEVER_K}),
232
  chain_type_kwargs={"prompt": prompt_template},
233
+ callbacks=[token_callback_handler],
234
  return_source_documents=True
235
  )
236
 
237
+ is_initialized = True
238
+ logger.info("RAG system initialized successfully")
239
+ return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  except Exception as e:
242
+ logger.error(f"Failed to initialize RAG system: {str(e)}")
243
+ is_initialized = False
244
+ raise
245
+
246
+ # API Endpoints
247
@app.on_event("startup")
async def startup_event():
    """Attempt a best-effort RAG initialization when the application starts.

    Does nothing unless ``config.GOOGLE_API_KEY`` is set. Any exception raised
    by ``initialize_rag_system`` is logged as a warning and not propagated, so
    the server still comes up.
    """
    if not config.GOOGLE_API_KEY:
        return

    try:
        await initialize_rag_system()
    except Exception as exc:
        logger.warning(f"Could not initialize on startup: {str(exc)}")
255
+
256
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML page.

    Returns:
        The text of ``static/index.html`` (resolved relative to the current
        working directory), sent back as an HTML response.

    Raises:
        FileNotFoundError: If ``static/index.html`` does not exist.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # encoding, which varies by OS/locale and can garble or reject non-ASCII
    # markup. NOTE(review): assumes index.html is saved as UTF-8 — confirm.
    return Path("static/index.html").read_text(encoding="utf-8")
261
+
262
@app.get("/api/status", response_model=SystemStatus)
async def get_status():
    """Report the current state of the RAG system.

    Returns:
        SystemStatus built from the module-level state: whether initialization
        completed, the configured model names, whether a vector store exists,
        how many vectors it holds, and whether an API key is configured.
    """
    store_ready = vector_store is not None

    # Read the count from the FAISS index's public ``ntotal`` rather than the
    # docstore's private ``_dict``, which is an implementation detail.
    # NOTE(review): assumes one indexed vector per stored chunk — confirm.
    chunk_count = vector_store.index.ntotal if store_ready else 0

    return SystemStatus(
        status="ready" if is_initialized else "not_initialized",
        is_initialized=is_initialized,
        model_name=config.MODEL_NAME,
        embedding_model=config.EMBEDDING_MODEL,
        vector_store_ready=store_ready,
        total_chunks=chunk_count,
        api_key_configured=bool(config.GOOGLE_API_KEY),
    )
274
 
275
@app.post("/api/initialize", response_model=Dict[str, Any])
async def initialize_system(request: InitializeRequest):
    """Initialize the RAG system using the API key supplied in the request.

    Args:
        request: Validated body carrying the ``api_key``.

    Returns:
        A dict with ``success`` set to True and a confirmation ``message``.

    Raises:
        HTTPException: 500, with the underlying error text as detail, when
            ``initialize_rag_system`` raises.
    """
    try:
        await initialize_rag_system(request.api_key)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    outcome = {
        "success": True,
        "message": "System initialized successfully"
    }
    return outcome
286
 
287
@app.post("/api/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """Answer a question with the initialized retrieval QA chain.

    Args:
        request: Validated body carrying the ``query`` text.

    Returns:
        QueryResponse with the answer, up to three source excerpts, the token
        usage recorded for the last LLM call, the elapsed time in seconds
        (including any retry delays) and an ISO-format timestamp.

    Raises:
        HTTPException: 503 when the system has not been initialized; 500 when
            the chain still fails on the final attempt or the response cannot
            be built.
    """
    if not is_initialized:
        raise HTTPException(
            status_code=503,
            detail="System not initialized. Please provide API key."
        )

    try:
        start_time = time.time()

        # Clear the per-call usage so this response cannot report the numbers
        # left behind by an earlier query.
        if token_callback_handler:
            token_callback_handler.last_call_tokens = {}

        # The chain call is synchronous, so run it in the default executor;
        # calling it directly here would stall the event loop (and every
        # other endpoint) for the duration of the request.
        # NOTE(review): all requests share one callback handler, so with
        # overlapping queries the reported per-call token usage may belong to
        # another request — confirm whether that matters.
        loop = asyncio.get_running_loop()
        payload = {"query": request.query}

        # Always make at least one attempt, even if MAX_RETRIES is set below 1;
        # otherwise ``result`` would never be assigned.
        attempts = max(1, config.MAX_RETRIES)
        for attempt in range(attempts):
            try:
                result = await loop.run_in_executor(None, qa_chain, payload)
                break
            except Exception as exc:
                if attempt == attempts - 1:
                    raise
                logger.warning(
                    "Query attempt %d of %d failed: %s", attempt + 1, attempts, exc
                )
                # Linear back-off: wait longer after each failed attempt.
                await asyncio.sleep(config.RATE_LIMIT_DELAY * (attempt + 1))

        processing_time = time.time() - start_time

        # Extract sources: a 200-character excerpt from each of the first
        # three retrieved documents.
        sources = []
        if 'source_documents' in result:
            sources = [doc.page_content[:200] + "..."
                       for doc in result['source_documents'][:3]]

        # Get token usage
        token_usage = {}
        if token_callback_handler:
            token_usage = token_callback_handler.get_last_call_usage()

        return QueryResponse(
            answer=result['result'],
            sources=sources,
            token_usage=token_usage,
            processing_time=round(processing_time, 2),
            timestamp=datetime.now().isoformat()
        )

    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
337
 
338
@app.get("/api/token-stats", response_model=Dict[str, Any])
async def get_token_stats():
    """Return cumulative token usage for the current callback handler.

    Returns:
        The handler's ``get_total_usage()`` dict when a handler exists;
        otherwise a dict with a single explanatory ``message``.
    """
    if token_callback_handler:
        return token_callback_handler.get_total_usage()

    return {"message": "No token statistics available"}
 
 
 
 
 
 
345
 
346
@app.post("/api/upload-document")
async def upload_document(file: UploadFile = File(...)):
    """Replace the source document and rebuild the index when possible.

    The uploaded bytes are written to ``config.DATA_PATH``. Any FAISS index
    saved at ``config.INDEX_PATH`` is then deleted, and the system is
    reinitialized if an API key is configured.

    Args:
        file: The uploaded document.

    Returns:
        A dict with ``success`` set to True and a ``message`` saying whether
        the system was reinitialized or still needs to be initialized.

    Raises:
        HTTPException: 500, with the underlying error text as detail, when
            saving the file, removing the index or reinitializing fails.
    """
    import shutil

    try:
        # Save uploaded file
        content = await file.read()

        # Create the data directory first: opening the file for writing fails
        # when its parent directory is missing.
        data_dir = os.path.dirname(config.DATA_PATH)
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)

        with open(config.DATA_PATH, "wb") as f:
            f.write(content)

        # Remove the old index regardless of whether a key is configured.
        # initialize_rag_system loads any index found at INDEX_PATH instead of
        # re-embedding, so a leftover index would keep serving the previous
        # document after a later initialization.
        if os.path.exists(config.INDEX_PATH):
            shutil.rmtree(config.INDEX_PATH)

        # Reinitialize the system with new data
        if config.GOOGLE_API_KEY:
            await initialize_rag_system()
            return {"success": True, "message": "Document uploaded and system reinitialized"}

        return {"success": True, "message": "Document uploaded. Please initialize the system."}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
369
 
370
+ # Mount static files
371
+ app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
  if __name__ == "__main__":
374
+ uvicorn.run(app, host="0.0.0.0", port=7860)