import os import logging import asyncio from typing import Optional, Dict, Any, List from datetime import datetime import json import time from pathlib import Path import random from fastapi import FastAPI, HTTPException, File, UploadFile, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, JSONResponse from pydantic import BaseModel, Field import uvicorn from langchain_community.document_loaders import TextLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import FAISS from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate from langchain.callbacks.base import BaseCallbackHandler from langchain_openai import ChatOpenAI, OpenAIEmbeddings import tiktoken # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # Initialize FastAPI app app = FastAPI( title="Maize Crop RAG System", description="AI-powered Q&A system for maize agriculture", version="1.0.0" ) # Configure CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global variables for the RAG system vector_store = None qa_chain = None token_callback_handler = None is_initialized = False # Configuration class Config: # API Keys - separate for each service OPENAI_LLM_API_KEY = os.getenv("OPENAI_LLM_API_KEY", os.getenv("OPENAI_API_KEY", "")) OPENAI_EMBEDDING_API_KEY = os.getenv("OPENAI_EMBEDDING_API_KEY", os.getenv("OPENAI_API_KEY", "")) # Base URLs - separate for each service OPENAI_LLM_BASE_URL = os.getenv("OPENAI_LLM_BASE_URL", os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")) OPENAI_EMBEDDING_BASE_URL = os.getenv("OPENAI_EMBEDDING_BASE_URL", os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")) # Model Configuration LLM_MODEL = os.getenv("LLM_MODEL", "gpt-3.5-turbo") EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-ada-002") # Document Processing CHUNK_SIZE = 1000 CHUNK_OVERLAP = 200 # Rate Limiting MAX_RETRIES = 5 RATE_LIMIT_DELAY = 2.0 EMBEDDING_BATCH_SIZE = 10 EMBEDDING_DELAY = 1.0 # Model Parameters TEMPERATURE = 0.8 MAX_OUTPUT_TOKENS = 120000 RETRIEVER_K = 20 # Paths INDEX_PATH = "faiss_maize_index" DATA_PATH = "data/maize_data.txt" config = Config() # Request/Response Models class QueryRequest(BaseModel): query: str = Field(..., min_length=1, max_length=120000) class QueryResponse(BaseModel): answer: str sources: List[str] = [] token_usage: Dict[str, int] = {} processing_time: float timestamp: str class SystemStatus(BaseModel): status: str is_initialized: bool model_name: str embedding_model: str llm_base_url: str embedding_base_url: str vector_store_ready: bool total_chunks: int = 0 llm_api_key_configured: bool embedding_api_key_configured: bool class InitializeRequest(BaseModel): llm_api_key: Optional[str] = Field(default=None, description="API key for LLM service") embedding_api_key: Optional[str] = Field(default=None, description="API key for embedding service") # Backward compatibility - if provided, will be used for both services if individual keys not specified api_key: Optional[str] = Field(default=None, description="Fallback API key for both services") llm_base_url: Optional[str] = Field(default=None, description="Base URL for LLM/text generation API") embedding_base_url: Optional[str] = Field(default=None, description="Base URL for embedding API") llm_model: Optional[str] = Field(default=None, description="LLM model name") embedding_model: Optional[str] = Field(default=None, description="Embedding model name") # Token counting utilities try: tokenizer = tiktoken.get_encoding("cl100k_base") except: logger.warning("Tiktoken encoder not found. Using basic split().") tokenizer = type('obj', (object,), {'encode': lambda x: x.split()})() def estimate_tokens(text: str) -> int: """Estimates token count for a given text.""" try: return len(tokenizer.encode(text)) except: return len(text.split()) * 1.3 # Rough estimate # Rate limiting helper functions async def rate_limited_embedding_creation(chunks, embeddings): """Create embeddings with rate limiting to avoid API limits.""" logger.info(f"Creating embeddings for {len(chunks)} chunks with rate limiting...") # Process chunks in smaller batches batch_size = config.EMBEDDING_BATCH_SIZE all_embeddings = [] for i in range(0, len(chunks), batch_size): batch = chunks[i:i + batch_size] logger.info(f"Processing batch {i//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size} ({len(batch)} chunks)") retry_count = 0 max_retries = 5 while retry_count < max_retries: try: # Create vector store for this batch if i == 0: # First batch - create new vector store vector_store_batch = FAISS.from_documents(batch, embeddings) all_embeddings.append(vector_store_batch) else: # Subsequent batches - merge with existing vector_store_batch = FAISS.from_documents(batch, embeddings) all_embeddings.append(vector_store_batch) logger.info(f"Successfully processed batch {i//batch_size + 1}") break except Exception as e: retry_count += 1 delay = config.EMBEDDING_DELAY * (2 ** retry_count) + random.uniform(0, 1) logger.warning(f"Batch {i//batch_size + 1} failed (attempt {retry_count}): {str(e)}") if "rate limit" in str(e).lower() or "429" in str(e): logger.info(f"Rate limit detected. Waiting {delay:.2f} seconds before retry...") else: logger.info(f"API error detected. Waiting {delay:.2f} seconds before retry...") await asyncio.sleep(delay) if retry_count >= max_retries: raise Exception(f"Failed to process batch after {max_retries} attempts: {str(e)}") # Delay between batches to respect rate limits if i + batch_size < len(chunks): delay = config.EMBEDDING_DELAY + random.uniform(0.2, 0.5) logger.info(f"Waiting {delay:.2f} seconds before next batch...") await asyncio.sleep(delay) # Merge all vector stores logger.info("Merging all vector store batches...") final_vector_store = all_embeddings[0] for i in range(1, len(all_embeddings)): final_vector_store.merge_from(all_embeddings[i]) logger.info(f"Merged batch {i + 1}/{len(all_embeddings)}") logger.info("Successfully created and merged all embeddings") return final_vector_store # Custom Callback Handler for OpenAI class TokenUsageCallbackHandler(BaseCallbackHandler): """Callback handler to track token usage in OpenAI calls.""" def __init__(self): super().__init__() self.reset() def reset(self): self.total_prompt_tokens = 0 self.total_completion_tokens = 0 self.total_llm_calls = 0 self.last_call_tokens = {} def on_llm_end(self, response, **kwargs): """Collect token usage from the OpenAI response.""" self.total_llm_calls += 1 llm_output = response.llm_output # OpenAI token usage structure if llm_output and 'token_usage' in llm_output: usage = llm_output['token_usage'] prompt_tokens = usage.get('prompt_tokens', 0) completion_tokens = usage.get('completion_tokens', 0) self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens self.last_call_tokens = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens } logger.info(f"Token usage - Prompt: {prompt_tokens}, Completion: {completion_tokens}") else: # Fallback token estimation if usage not available logger.info("Token usage not available from API response") def get_last_call_usage(self): return self.last_call_tokens def get_total_usage(self): return { "total_prompt_tokens": self.total_prompt_tokens, "total_completion_tokens": self.total_completion_tokens, "total_tokens": self.total_prompt_tokens + self.total_completion_tokens, "total_calls": self.total_llm_calls } # RAG System Functions async def initialize_rag_system( llm_api_key: str = None, embedding_api_key: str = None, api_key: str = None, # Fallback for backward compatibility llm_base_url: str = None, embedding_base_url: str = None, llm_model: str = None, embedding_model: str = None ): """Initialize or reinitialize the RAG system with separate OpenAI compatible APIs and keys.""" global vector_store, qa_chain, token_callback_handler, is_initialized, config try: # Handle API key configuration with fallback logic if llm_api_key: config.OPENAI_LLM_API_KEY = llm_api_key elif api_key: config.OPENAI_LLM_API_KEY = api_key elif not config.OPENAI_LLM_API_KEY: raise ValueError("LLM API key not provided") if embedding_api_key: config.OPENAI_EMBEDDING_API_KEY = embedding_api_key elif api_key: config.OPENAI_EMBEDDING_API_KEY = api_key elif not config.OPENAI_EMBEDDING_API_KEY: raise ValueError("Embedding API key not provided") # Update base URLs if llm_base_url: config.OPENAI_LLM_BASE_URL = llm_base_url if embedding_base_url: config.OPENAI_EMBEDDING_BASE_URL = embedding_base_url if llm_model: config.LLM_MODEL = llm_model if embedding_model: config.EMBEDDING_MODEL = embedding_model logger.info(f"Initializing RAG system with:") logger.info(f" - LLM Base URL: {config.OPENAI_LLM_BASE_URL}") logger.info(f" - LLM API Key: {'*' * (len(config.OPENAI_LLM_API_KEY) - 8) + config.OPENAI_LLM_API_KEY[-8:] if len(config.OPENAI_LLM_API_KEY) > 8 else '*' * len(config.OPENAI_LLM_API_KEY)}") logger.info(f" - Embedding Base URL: {config.OPENAI_EMBEDDING_BASE_URL}") logger.info(f" - Embedding API Key: {'*' * (len(config.OPENAI_EMBEDDING_API_KEY) - 8) + config.OPENAI_EMBEDDING_API_KEY[-8:] if len(config.OPENAI_EMBEDDING_API_KEY) > 8 else '*' * len(config.OPENAI_EMBEDDING_API_KEY)}") logger.info(f" - LLM Model: {config.LLM_MODEL}") logger.info(f" - Embedding Model: {config.EMBEDDING_MODEL}") # Initialize token callback handler token_callback_handler = TokenUsageCallbackHandler() # Load and split documents if not os.path.exists(config.DATA_PATH): raise FileNotFoundError(f"Data file not found: {config.DATA_PATH}") loader = TextLoader(config.DATA_PATH, encoding='utf-8') documents = loader.load() text_splitter = RecursiveCharacterTextSplitter( chunk_size=config.CHUNK_SIZE, chunk_overlap=config.CHUNK_OVERLAP, separators=["\n\n", "\n", " ", ""] ) chunks = text_splitter.split_documents(documents) logger.info(f"Document split into {len(chunks)} chunks") # Check if we have too many chunks that might cause rate limiting if len(chunks) > 200: logger.warning(f"Large number of chunks ({len(chunks)}). Consider increasing chunk_size to reduce API calls.") # Initialize OpenAI embeddings with separate API key and base URL embeddings = OpenAIEmbeddings( model=config.EMBEDDING_MODEL, openai_api_key=config.OPENAI_EMBEDDING_API_KEY, # Use embedding-specific API key openai_api_base=config.OPENAI_EMBEDDING_BASE_URL, chunk_size=1000 ) # Test embedding connection try: test_embedding = await asyncio.get_event_loop().run_in_executor( None, embeddings.embed_query, "test connection" ) logger.info("Successfully connected to embedding API") except Exception as e: logger.error(f"Failed to connect to embedding API: {str(e)}") raise # Create or load FAISS index with rate limiting if os.path.exists(config.INDEX_PATH): try: vector_store = FAISS.load_local( config.INDEX_PATH, embeddings, allow_dangerous_deserialization=True ) logger.info(f"Loaded existing FAISS index from '{config.INDEX_PATH}'") except Exception as e: logger.warning(f"Failed to load existing index: {str(e)}") logger.info("Creating new index...") vector_store = await rate_limited_embedding_creation(chunks, embeddings) vector_store.save_local(config.INDEX_PATH) logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'") else: vector_store = await rate_limited_embedding_creation(chunks, embeddings) vector_store.save_local(config.INDEX_PATH) logger.info(f"Created new FAISS index at '{config.INDEX_PATH}'") # Initialize OpenAI LLM with separate API key and base URL llm = ChatOpenAI( model_name=config.LLM_MODEL, openai_api_key=config.OPENAI_LLM_API_KEY, # Use LLM-specific API key openai_api_base=config.OPENAI_LLM_BASE_URL, temperature=config.TEMPERATURE, max_tokens=config.MAX_OUTPUT_TOKENS, callbacks=[token_callback_handler], request_timeout=30 ) # Test LLM connection try: test_response = llm.invoke("Test connection") logger.info("Successfully connected to LLM API") except Exception as e: logger.error(f"Failed to connect to LLM API: {str(e)}") raise # Create prompt template prompt_template = PromptTemplate( input_variables=["context", "question"], template="""You are an expert in maize agriculture. Use the following context ONLY to answer the question accurately and helpfully. If the query asks for personal information of any person, do not provide it and instead explain that you cannot share personal information. Provide clear, concise answers in easy-to-understand language. If the context doesn't contain enough information to answer the question completely, say so. Context: {context} Question: {question} Answer:""" ) # Set up QA chain qa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type="stuff", retriever=vector_store.as_retriever( search_type="similarity", search_kwargs={"k": config.RETRIEVER_K} ), chain_type_kwargs={"prompt": prompt_template}, callbacks=[token_callback_handler], return_source_documents=True ) is_initialized = True logger.info("RAG system initialized successfully") return True except Exception as e: logger.error(f"Failed to initialize RAG system: {str(e)}") is_initialized = False raise # API Endpoints @app.on_event("startup") async def startup_event(): """Initialize the system on startup if API keys are available.""" if config.OPENAI_LLM_API_KEY and config.OPENAI_EMBEDDING_API_KEY: try: await initialize_rag_system() except Exception as e: logger.warning(f"Could not initialize on startup: {str(e)}") @app.get("/", response_class=HTMLResponse) async def root(): """Serve the main HTML page.""" try: with open("static/index.html", "r") as f: return f.read() except FileNotFoundError: return """
API is running. Please use the API endpoints or add static/index.html for web interface.