CROP-RAG-API / app.py
NitinBot001's picture
Update app.py
7600fa1 verified
raw
history blame
12.5 kB
import os
import logging
import asyncio
from typing import Optional, Dict, Any, List
from datetime import datetime
import json
import time
from pathlib import Path
from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel, Field
import uvicorn
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.callbacks.base import BaseCallbackHandler
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
import tiktoken
# Logging setup: INFO-level records with timestamp, logger name and level,
# plus the module-scoped logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
    title="Maize Crop RAG System",
    description="AI-powered Q&A system for maize agriculture",
    version="1.0.0",
)

# Configure CORS.
# Every origin is allowed, so credentialed (cookie-bearing) cross-origin
# requests are switched off: wildcard origins combined with
# allow_credentials=True is not a supported configuration and would let any
# website issue credentialed calls. The endpoints in this file do not read
# cookies; the API key travels in the request body.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for the RAG system.
# The three handles stay None until initialize_rag_system() assigns them;
# is_initialized is the flag the endpoints check before serving queries.
vector_store = qa_chain = token_callback_handler = None
is_initialized = False
# Configuration
class Config:
    """Settings for the RAG pipeline, held as class attributes.

    ``GOOGLE_API_KEY`` is read from the environment when the class body is
    executed and may be overwritten later on the shared ``config`` instance.
    """

    # Credentials
    GOOGLE_API_KEY: str = os.environ.get("GOOGLE_API_KEY", "")

    # Text splitting
    CHUNK_SIZE: int = 800
    CHUNK_OVERLAP: int = 100

    # Query retry policy (attempt count and base delay in seconds)
    MAX_RETRIES: int = 3
    RATE_LIMIT_DELAY: float = 1.0

    # Models and generation parameters
    MODEL_NAME: str = "gemma-3-27b-it"
    EMBEDDING_MODEL: str = "models/embedding-001"
    TEMPERATURE: float = 0.5
    MAX_OUTPUT_TOKENS: int = 512

    # Retrieval
    RETRIEVER_K: int = 5

    # On-disk locations
    INDEX_PATH: str = "faiss_maize_index"
    DATA_PATH: str = "data/maize_data.txt"


# Shared settings object used by the rest of the module.
config = Config()
# Request/Response Models
class QueryRequest(BaseModel):
    # Request body for POST /api/query.
    # The question text must be between 1 and 500 characters long.
    query: str = Field(..., max_length=500, min_length=1)
class QueryResponse(BaseModel):
    # Response body for POST /api/query.

    # Text produced by the QA chain.
    answer: str
    # Excerpts of the retrieved documents backing the answer.
    sources: List[str] = []
    # Token counts recorded for the LLM call; empty when none were recorded.
    token_usage: Dict[str, int] = {}
    # Elapsed seconds spent answering the query.
    processing_time: float
    # ISO-8601 time at which the response was produced.
    timestamp: str
class SystemStatus(BaseModel):
    # Response body for GET /api/status.

    # "ready" or "not_initialized".
    status: str
    # Mirrors the module-level is_initialized flag.
    is_initialized: bool
    # Names of the configured chat and embedding models.
    model_name: str
    embedding_model: str
    # True once a vector store object exists.
    vector_store_ready: bool
    # Number of entries in the vector store's docstore (0 without a store).
    total_chunks: int = 0
    # True when a non-empty Google API key is configured.
    api_key_configured: bool
class InitializeRequest(BaseModel):
    # Request body for POST /api/initialize: a non-empty Google API key.
    api_key: str = Field(..., min_length=1)
# Token counting utilities
class _WhitespaceTokenizer:
    """Fallback tokenizer: one token per whitespace-separated word.

    Exposes the same ``encode(text)`` call that the module uses on a tiktoken
    encoding, so callers do not need to know which one is active.
    """

    def encode(self, text):
        """Return the whitespace-separated pieces of ``text`` as a list."""
        return text.split()


try:
    tokenizer = tiktoken.get_encoding("cl100k_base")
except Exception:
    # Any failure to obtain the encoding degrades to whitespace splitting.
    # ``Exception`` (not a bare ``except``) keeps KeyboardInterrupt and
    # SystemExit propagating during import.
    logger.warning("Tiktoken encoder not found. Using basic split().")
    tokenizer = _WhitespaceTokenizer()
def estimate_tokens(text: str) -> int:
    """Return how many tokens the module-level ``tokenizer`` yields for ``text``."""
    encoded = tokenizer.encode(text)
    return len(encoded)
# Custom Callback Handler
class TokenUsageCallbackHandler(BaseCallbackHandler):
    """Accumulates prompt/completion token counts reported at the end of LLM calls."""

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Zero all running totals and forget the most recent call's usage."""
        self.total_prompt_tokens = self.total_completion_tokens = 0
        self.total_llm_calls = 0
        self.last_call_tokens = {}

    def on_llm_end(self, response, **kwargs):
        """Count the call and, when usage metadata is present, record it.

        Usage is read from ``response.llm_output['usage_metadata']`` using the
        ``prompt_token_count`` and ``candidates_token_count`` keys. Calls
        without that metadata still increment the call counter but leave the
        token totals and ``last_call_tokens`` untouched.
        """
        self.total_llm_calls += 1

        llm_output = response.llm_output
        if not llm_output or 'usage_metadata' not in llm_output:
            return

        usage = llm_output['usage_metadata']
        n_prompt = usage.get('prompt_token_count', 0)
        n_completion = usage.get('candidates_token_count', 0)

        self.total_prompt_tokens += n_prompt
        self.total_completion_tokens += n_completion
        self.last_call_tokens = {
            "prompt_tokens": n_prompt,
            "completion_tokens": n_completion,
            "total_tokens": n_prompt + n_completion
        }
        logger.info(f"Token usage - Prompt: {n_prompt}, Completion: {n_completion}")

    def get_last_call_usage(self):
        """Return the usage dict of the latest call that reported usage (may be empty)."""
        return self.last_call_tokens

    def get_total_usage(self):
        """Return cumulative prompt, completion and combined token counts plus the call count."""
        prompt_total = self.total_prompt_tokens
        completion_total = self.total_completion_tokens
        return {
            "total_prompt_tokens": prompt_total,
            "total_completion_tokens": completion_total,
            "total_tokens": prompt_total + completion_total,
            "total_calls": self.total_llm_calls
        }
# RAG System Functions
def _load_or_build_vector_store(embeddings):
    """Return the FAISS store, reusing the saved index when one exists on disk."""
    if os.path.exists(config.INDEX_PATH):
        # NOTE(review): allow_dangerous_deserialization opts in to loading
        # pickled data. That is only acceptable while the index directory is
        # written exclusively by this application; confirm nothing untrusted
        # can place files under INDEX_PATH.
        store = FAISS.load_local(
            config.INDEX_PATH,
            embeddings,
            allow_dangerous_deserialization=True
        )
        logger.info(f"Loaded existing FAISS index from '{config.INDEX_PATH}'")
        return store

    # The document is only read and split when an index has to be built;
    # a reused index never needs the chunks.
    documents = TextLoader(config.DATA_PATH).load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=config.CHUNK_SIZE,
        chunk_overlap=config.CHUNK_OVERLAP
    )
    chunks = text_splitter.split_documents(documents)
    logger.info(f"Document split into {len(chunks)} chunks")

    store = FAISS.from_documents(chunks, embeddings)
    store.save_local(config.INDEX_PATH)
    logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'")
    return store


def _build_qa_chain(store, api_key, callback_handler):
    """Wire the chat model, prompt and retriever into a RetrievalQA chain."""
    llm = ChatGoogleGenerativeAI(
        model=config.MODEL_NAME,
        google_api_key=api_key,
        temperature=config.TEMPERATURE,
        max_tokens=config.MAX_OUTPUT_TOKENS,
        callbacks=[callback_handler]
    )
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are an expert in maize agriculture. Use the following context ONLY to answer the question accurately and helpfully.
If the context doesn't contain the answer, say "Based on the provided context, I cannot answer this question."
Context:
{context}
Question: {question}
Answer:"""
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=store.as_retriever(search_kwargs={"k": config.RETRIEVER_K}),
        chain_type_kwargs={"prompt": prompt_template},
        callbacks=[callback_handler],
        return_source_documents=True
    )


async def initialize_rag_system(api_key: Optional[str] = None):
    """Initialize or reinitialize the RAG system.

    Everything is built into local variables first. The module globals, the
    configured API key and the ``GOOGLE_API_KEY`` environment variable are
    only updated once every component has been created, so a failed attempt
    does not leave a rejected key or a half-replaced pipeline behind.

    Args:
        api_key: Google API key to use. When omitted, the key already held in
            ``config.GOOGLE_API_KEY`` is used.

    Returns:
        True on success.

    Raises:
        ValueError: no API key was supplied and none is configured.
        FileNotFoundError: the data file at ``config.DATA_PATH`` is missing.
        Exception: any error raised while building the embeddings, index or
            chain is logged with its traceback and re-raised; in every
            failure case ``is_initialized`` is set to False.
    """
    global vector_store, qa_chain, token_callback_handler, is_initialized
    try:
        # Use provided API key or the configured one
        resolved_key = api_key or config.GOOGLE_API_KEY
        if not resolved_key:
            raise ValueError("Google API key not provided")
        logger.info("Initializing RAG system...")

        # The data file must exist even when a saved index is reused.
        if not os.path.exists(config.DATA_PATH):
            raise FileNotFoundError(f"Data file not found: {config.DATA_PATH}")

        new_handler = TokenUsageCallbackHandler()
        embeddings = GoogleGenerativeAIEmbeddings(
            model=config.EMBEDDING_MODEL,
            google_api_key=resolved_key
        )
        new_store = _load_or_build_vector_store(embeddings)
        new_chain = _build_qa_chain(new_store, resolved_key, new_handler)
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Failed to initialize RAG system: %s", e)
        is_initialized = False
        raise

    # Commit the new state in one place now that nothing can fail.
    if api_key:
        config.GOOGLE_API_KEY = api_key
        os.environ["GOOGLE_API_KEY"] = api_key
    token_callback_handler = new_handler
    vector_store = new_store
    qa_chain = new_chain
    is_initialized = True
    logger.info("RAG system initialized successfully")
    return True
# API Endpoints
@app.on_event("startup")
async def startup_event():
    """Attempt to bring the RAG system up at startup when an API key is configured.

    Initialization failures are logged as warnings and do not stop the
    application from starting.
    """
    if not config.GOOGLE_API_KEY:
        return
    try:
        await initialize_rag_system()
    except Exception as e:
        logger.warning(f"Could not initialize on startup: {str(e)}")
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML page.

    Returns:
        The contents of ``static/index.html`` decoded as UTF-8.

    Raises:
        HTTPException: 404 when the page file does not exist.
    """
    index_file = Path("static/index.html")
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # host's locale encoding.
        return index_file.read_text(encoding="utf-8")
    except FileNotFoundError as e:
        raise HTTPException(
            status_code=404,
            detail="static/index.html not found"
        ) from e
@app.get("/api/status", response_model=SystemStatus)
async def get_status():
    """Report the initialization state, configured models and vector-store size."""
    store_ready = vector_store is not None
    # Chunk count is taken from the number of entries in the store's docstore.
    chunk_count = len(vector_store.docstore._dict) if vector_store else 0
    state = "ready" if is_initialized else "not_initialized"
    key_present = bool(config.GOOGLE_API_KEY)

    return SystemStatus(
        status=state,
        is_initialized=is_initialized,
        model_name=config.MODEL_NAME,
        embedding_model=config.EMBEDDING_MODEL,
        vector_store_ready=store_ready,
        total_chunks=chunk_count,
        api_key_configured=key_present
    )
@app.post("/api/initialize", response_model=Dict[str, Any])
async def initialize_system(request: InitializeRequest):
    """Initialize the RAG system with the API key supplied in the request body.

    Any initialization error is reported to the client as HTTP 500 with the
    error text as the detail.
    """
    try:
        await initialize_rag_system(request.api_key)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    outcome = {
        "success": True,
        "message": "System initialized successfully"
    }
    return outcome
def _source_preview(text, limit=200):
    """Return ``text`` cut to ``limit`` characters, adding "..." only when it was cut."""
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


@app.post("/api/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """Process a query and return the answer.

    The QA chain is invoked up to ``config.MAX_RETRIES`` times (at least
    once), sleeping ``RATE_LIMIT_DELAY * attempt_number`` seconds between
    attempts.

    Returns:
        QueryResponse with the answer, up to three source previews, the token
        usage recorded for the call, the elapsed time and a timestamp.

    Raises:
        HTTPException: 503 when the system is not initialized; 500 when the
            query fails after all attempts.
    """
    if not is_initialized or qa_chain is None:
        raise HTTPException(
            status_code=503,
            detail="System not initialized. Please provide API key."
        )
    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()

        # Reset token counter for this query
        if token_callback_handler:
            token_callback_handler.last_call_tokens = {}

        # Process query with retry logic; always make at least one attempt so
        # `result` is guaranteed to be bound afterwards.
        attempts = max(1, config.MAX_RETRIES)
        for attempt in range(attempts):
            try:
                # NOTE(review): this is a synchronous call inside an async
                # endpoint, so the event loop is blocked while the chain
                # runs. Moving it to a worker thread would also require
                # per-request token accounting, because the callback handler
                # is shared module state.
                result = qa_chain.invoke({"query": request.query})
                break
            except Exception:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(config.RATE_LIMIT_DELAY * (attempt + 1))

        processing_time = time.perf_counter() - start_time

        # Extract sources
        sources = [
            _source_preview(doc.page_content)
            for doc in result.get('source_documents', [])[:3]
        ]

        # Get token usage
        token_usage = {}
        if token_callback_handler:
            token_usage = token_callback_handler.get_last_call_usage()

        return QueryResponse(
            answer=result['result'],
            sources=sources,
            token_usage=token_usage,
            processing_time=round(processing_time, 2),
            timestamp=datetime.now().isoformat()
        )
    except Exception as e:
        logger.exception("Error processing query: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/token-stats", response_model=Dict[str, Any])
async def get_token_stats():
    """Return cumulative token usage, or a placeholder message when no handler exists yet."""
    if token_callback_handler:
        return token_callback_handler.get_total_usage()
    return {"message": "No token statistics available"}
@app.post("/api/upload-document")
async def upload_document(file: UploadFile = File(...)):
    """Upload a new document to replace the existing one.

    The upload is written to ``config.DATA_PATH`` and any saved FAISS index
    is removed, since that index was built from the previous document. When
    an API key is configured the system is reinitialized immediately.

    Raises:
        HTTPException: 400 when the uploaded file is empty; 500 for any other
            failure while saving or reinitializing.
    """
    import shutil

    try:
        content = await file.read()
        if not content:
            # Reject before touching the existing data file.
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        # Save uploaded file, creating the data directory on first use.
        data_dir = os.path.dirname(config.DATA_PATH)
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)
        with open(config.DATA_PATH, "wb") as f:
            f.write(content)

        # Always drop the old index: otherwise a later initialization would
        # load it and keep answering from the previous document.
        if os.path.exists(config.INDEX_PATH):
            shutil.rmtree(config.INDEX_PATH)

        # Reinitialize the system with new data
        if config.GOOGLE_API_KEY:
            await initialize_rag_system()
            return {"success": True, "message": "Document uploaded and system reinitialized"}
        return {"success": True, "message": "Document uploaded. Please initialize the system."}
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Mount static files: everything under ./static is served at /static.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Script entry point: run the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, port=7860, host="0.0.0.0")