# Maize Crop RAG System — FastAPI application (Hugging Face Space)
| import os | |
| import logging | |
| import asyncio | |
| from typing import Optional, Dict, Any, List | |
| from datetime import datetime | |
| import json | |
| import time | |
| from pathlib import Path | |
| from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| from langchain_community.document_loaders import TextLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
| from langchain.callbacks.base import BaseCallbackHandler | |
| from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings | |
| import tiktoken | |
# --- Logging ---------------------------------------------------------------
# One root-level basicConfig; every module in this file logs through `logger`.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# --- FastAPI application ---------------------------------------------------
app = FastAPI(
    title="Maize Crop RAG System",
    description="AI-powered Q&A system for maize agriculture",
    version="1.0.0",
)

# Fully-open CORS so the bundled front-end (any origin) can call the API.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# --- Module-level RAG state, populated by initialize_rag_system() ----------
# vector_store: FAISS wrapper once an index is built/loaded
# qa_chain: RetrievalQA pipeline bound to the store
# token_callback_handler: accumulates token usage across LLM calls
vector_store = qa_chain = token_callback_handler = None
is_initialized = False  # flips True only after a successful initialization
# Configuration
class Config:
    """Central tunables for the RAG pipeline, read once at import time."""

    # Credentials — may be overwritten later by initialize_rag_system().
    GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY", "")

    # Document chunking
    CHUNK_SIZE: int = 800
    CHUNK_OVERLAP: int = 100

    # Retry / rate limiting for LLM calls
    MAX_RETRIES: int = 3
    RATE_LIMIT_DELAY: float = 1.0  # seconds; scaled linearly per attempt

    # Model selection and generation parameters
    MODEL_NAME: str = "gemma-3-27b-it"
    EMBEDDING_MODEL: str = "models/embedding-001"
    TEMPERATURE: float = 0.5
    MAX_OUTPUT_TOKENS: int = 512

    # Retrieval
    RETRIEVER_K: int = 5

    # On-disk locations
    INDEX_PATH: str = "faiss_maize_index"
    DATA_PATH: str = "data/maize_data.txt"


config = Config()
| # Request/Response Models | |
class QueryRequest(BaseModel):
    """Body of a query call: the user's question (1-500 characters)."""
    query: str = Field(..., min_length=1, max_length=500)
class QueryResponse(BaseModel):
    """Payload returned for a successfully processed query."""

    answer: str  # LLM-generated answer text
    # default_factory instead of mutable class-level defaults ([] / {}):
    # pydantic copies plain defaults, but the factory form is explicit and
    # safe should these models ever be converted to plain classes.
    sources: List[str] = Field(default_factory=list)  # up to 3 context snippets
    token_usage: Dict[str, int] = Field(default_factory=dict)  # last-call counts
    processing_time: float  # seconds, rounded to 2 decimal places
    timestamp: str  # ISO-8601 time the response was generated
class SystemStatus(BaseModel):
    """Health/readiness snapshot reported by the status endpoint."""
    status: str  # "ready" or "not_initialized"
    is_initialized: bool
    model_name: str
    embedding_model: str
    vector_store_ready: bool
    total_chunks: int = 0  # chunks in the FAISS docstore; 0 before init
    api_key_configured: bool
class InitializeRequest(BaseModel):
    """Body of the initialize call: a non-empty Google API key."""
    api_key: str = Field(..., min_length=1)
# Token counting utilities
try:
    tokenizer = tiktoken.get_encoding("cl100k_base")
except Exception:  # narrowed from bare `except:` so KeyboardInterrupt etc. propagate
    logging.getLogger(__name__).warning("Tiktoken encoder not found. Using basic split().")

    class _WhitespaceTokenizer:
        """Fallback encoder approximating tokens as whitespace-separated words.

        Replaces the previous dynamically-built object whose ``encode`` was a
        one-argument lambda: as a bound method it received (self, text) and
        raised TypeError on every call, so the fallback never actually worked.
        """

        @staticmethod
        def encode(text: str) -> list:
            return text.split()

    tokenizer = _WhitespaceTokenizer()


def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text* using the active tokenizer."""
    return len(tokenizer.encode(text))
# Custom Callback Handler
class TokenUsageCallbackHandler(BaseCallbackHandler):
    """Callback handler to track token usage in LLM calls."""

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Zero out every accumulated counter and the last-call snapshot."""
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_llm_calls = 0
        self.last_call_tokens = {}

    def on_llm_end(self, response, **kwargs):
        """Collect token usage from the LLM response."""
        self.total_llm_calls += 1
        meta = response.llm_output
        # Usage metadata is optional on the provider response — skip if absent.
        if not meta or 'usage_metadata' not in meta:
            return
        usage = meta['usage_metadata']
        prompt = usage.get('prompt_token_count', 0)
        completion = usage.get('candidates_token_count', 0)
        self.total_prompt_tokens += prompt
        self.total_completion_tokens += completion
        self.last_call_tokens = {
            "prompt_tokens": prompt,
            "completion_tokens": completion,
            "total_tokens": prompt + completion,
        }
        logger.info(f"Token usage - Prompt: {prompt}, Completion: {completion}")

    def get_last_call_usage(self):
        """Return token counts from the most recent call (empty dict if none)."""
        return self.last_call_tokens

    def get_total_usage(self):
        """Return aggregated counts across every call since the last reset()."""
        combined = self.total_prompt_tokens + self.total_completion_tokens
        return {
            "total_prompt_tokens": self.total_prompt_tokens,
            "total_completion_tokens": self.total_completion_tokens,
            "total_tokens": combined,
            "total_calls": self.total_llm_calls,
        }
# RAG System Functions
async def initialize_rag_system(api_key: str = None):
    """Initialize or reinitialize the RAG system.

    Loads the maize data file, splits it into chunks, builds (or reloads) a
    FAISS vector store, and wires everything into a RetrievalQA chain backed
    by a Google Generative AI chat model.

    Mutates module globals: vector_store, qa_chain, token_callback_handler,
    is_initialized, plus config.GOOGLE_API_KEY and os.environ.

    Args:
        api_key: Optional Google API key; falls back to the configured value.

    Returns:
        True on success.

    Raises:
        ValueError: if no API key is available from any source.
        FileNotFoundError: if config.DATA_PATH does not exist.

    NOTE(review): although declared async, every call below is synchronous and
    will block the event loop for the duration of indexing — confirm this is
    acceptable for startup/upload paths.
    """
    global vector_store, qa_chain, token_callback_handler, is_initialized, config
    try:
        # Use provided API key or environment variable
        if api_key:
            config.GOOGLE_API_KEY = api_key
            os.environ["GOOGLE_API_KEY"] = api_key
        elif not config.GOOGLE_API_KEY:
            raise ValueError("Google API key not provided")
        logger.info("Initializing RAG system...")
        # Initialize token callback handler (fresh counters for this session)
        token_callback_handler = TokenUsageCallbackHandler()
        # Load and split documents
        if not os.path.exists(config.DATA_PATH):
            raise FileNotFoundError(f"Data file not found: {config.DATA_PATH}")
        loader = TextLoader(config.DATA_PATH)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP
        )
        chunks = text_splitter.split_documents(documents)
        logger.info(f"Document split into {len(chunks)} chunks")
        # Initialize embeddings
        embeddings = GoogleGenerativeAIEmbeddings(
            model=config.EMBEDDING_MODEL,
            google_api_key=config.GOOGLE_API_KEY
        )
        # Create or load FAISS index.
        # NOTE(review): when an index already exists on disk it is reused as-is,
        # so the freshly computed `chunks` are ignored and a stale index can be
        # served if DATA_PATH changed; the upload endpoint works around this by
        # deleting the index directory first.
        if os.path.exists(config.INDEX_PATH):
            vector_store = FAISS.load_local(
                config.INDEX_PATH,
                embeddings,
                # FAISS indexes are pickled; only safe because this process
                # wrote the file itself.
                allow_dangerous_deserialization=True
            )
            logger.info(f"Loaded existing FAISS index from '{config.INDEX_PATH}'")
        else:
            vector_store = FAISS.from_documents(chunks, embeddings)
            vector_store.save_local(config.INDEX_PATH)
            logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'")
        # Initialize LLM with the token-usage callback attached
        llm = ChatGoogleGenerativeAI(
            model=config.MODEL_NAME,
            google_api_key=config.GOOGLE_API_KEY,
            temperature=config.TEMPERATURE,
            max_tokens=config.MAX_OUTPUT_TOKENS,
            callbacks=[token_callback_handler]
        )
        # Create prompt template (grounds the model strictly in retrieved context)
        prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
You are an expert in maize agriculture. Use the following context ONLY to answer the question accurately and helpfully.
If the context doesn't contain the answer, say "Based on the provided context, I cannot answer this question."
Context:
{context}
Question: {question}
Answer:"""
        )
        # Set up QA chain ("stuff" = concatenate all k retrieved chunks into one prompt)
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vector_store.as_retriever(search_kwargs={"k": config.RETRIEVER_K}),
            chain_type_kwargs={"prompt": prompt_template},
            callbacks=[token_callback_handler],
            return_source_documents=True
        )
        is_initialized = True
        logger.info("RAG system initialized successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize RAG system: {str(e)}")
        is_initialized = False
        raise
| # API Endpoints | |
@app.on_event("startup")  # registration added during review — hook was never wired to the app
async def startup_event():
    """Initialize the system on startup if an API key is available.

    Failures are logged but not fatal so the app can still boot and be
    initialized later through the initialize endpoint.
    """
    if config.GOOGLE_API_KEY:
        try:
            await initialize_rag_system()
        except Exception as e:
            logger.warning(f"Could not initialize on startup: {str(e)}")
@app.get("/", response_class=HTMLResponse)  # route added during review — handler was never registered
async def root():
    """Serve the main HTML page.

    response_class=HTMLResponse makes FastAPI return the raw string as
    text/html instead of JSON-encoding it.
    """
    with open("static/index.html", "r") as f:
        return f.read()
@app.get("/api/status", response_model=SystemStatus)  # path chosen during review — confirm against front-end
async def get_status():
    """Get system status (readiness, models, index size, key presence)."""
    return SystemStatus(
        status="ready" if is_initialized else "not_initialized",
        is_initialized=is_initialized,
        model_name=config.MODEL_NAME,
        embedding_model=config.EMBEDDING_MODEL,
        vector_store_ready=vector_store is not None,
        # NOTE(review): docstore._dict is a private LangChain attribute; this
        # chunk count may break across LangChain upgrades.
        total_chunks=len(vector_store.docstore._dict) if vector_store else 0,
        api_key_configured=bool(config.GOOGLE_API_KEY)
    )
@app.post("/api/initialize")  # path chosen during review — confirm against front-end
async def initialize_system(request: InitializeRequest):
    """Initialize the RAG system with the provided API key.

    Raises:
        HTTPException(500): if initialization fails for any reason.
    """
    try:
        await initialize_rag_system(request.api_key)
        return {
            "success": True,
            "message": "System initialized successfully"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/query", response_model=QueryResponse)  # path chosen during review — confirm against front-end
async def process_query(request: QueryRequest):
    """Process a query and return the answer.

    Retries transient LLM failures with linear backoff, then returns the
    answer, up to three source snippets, and the last call's token usage.

    Raises:
        HTTPException(503): system not yet initialized.
        HTTPException(500): query processing failed after all retries.
    """
    if not is_initialized:
        raise HTTPException(
            status_code=503,
            detail="System not initialized. Please provide API key."
        )
    try:
        start_time = time.time()
        # Reset token counter so stale usage from a prior query isn't reported
        if token_callback_handler:
            token_callback_handler.last_call_tokens = {}
        # Process query with retry logic (linear backoff: delay * attempt number)
        for attempt in range(config.MAX_RETRIES):
            try:
                result = qa_chain({"query": request.query})
                break
            except Exception as e:
                if attempt == config.MAX_RETRIES - 1:
                    raise
                await asyncio.sleep(config.RATE_LIMIT_DELAY * (attempt + 1))
        processing_time = time.time() - start_time
        # Extract sources: first 200 chars of up to 3 retrieved chunks
        sources = []
        if 'source_documents' in result:
            sources = [doc.page_content[:200] + "..."
                       for doc in result['source_documents'][:3]]
        # Get token usage recorded by the callback during this query
        token_usage = {}
        if token_callback_handler:
            token_usage = token_callback_handler.get_last_call_usage()
        return QueryResponse(
            answer=result['result'],
            sources=sources,
            token_usage=token_usage,
            processing_time=round(processing_time, 2),
            timestamp=datetime.now().isoformat()
        )
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/token-stats")  # path chosen during review — confirm against front-end
async def get_token_stats():
    """Get cumulative token usage statistics across all queries."""
    if not token_callback_handler:
        return {"message": "No token statistics available"}
    return token_callback_handler.get_total_usage()
@app.post("/api/upload")  # path chosen during review — confirm against front-end
async def upload_document(file: UploadFile = File(...)):
    """Upload a new document to replace the existing one.

    Overwrites DATA_PATH, drops any stale FAISS index, and reinitializes the
    system when an API key is already configured.

    Raises:
        HTTPException(500): on any write or reinitialization failure.
    """
    try:
        # Save uploaded file; ensure data/ exists so open() can't fail on first run
        content = await file.read()
        Path(config.DATA_PATH).parent.mkdir(parents=True, exist_ok=True)
        with open(config.DATA_PATH, "wb") as f:
            f.write(content)
        # Reinitialize the system with new data
        if config.GOOGLE_API_KEY:
            # Remove old index to force recreation from the new document
            if os.path.exists(config.INDEX_PATH):
                import shutil
                shutil.rmtree(config.INDEX_PATH)
            await initialize_rag_system()
            return {"success": True, "message": "Document uploaded and system reinitialized"}
        else:
            return {"success": True, "message": "Document uploaded. Please initialize the system."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Mount static files so the front-end assets resolve under /static/
app.mount("/static", StaticFiles(directory="static"), name="static")
if __name__ == "__main__":
    # Bind all interfaces on port 7860 — the standard Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)