# Agrosure_RAG / api.py
# Author: agentsay — Hugging Face Space file (last update commit c13eb80, verified)
import collections
from collections.abc import MutableMapping

# Compatibility shim: ``collections.MutableMapping`` was removed in Python 3.10,
# but some transitive dependencies still import it from ``collections``.
collections.MutableMapping = MutableMapping  # Patch for deprecated MutableMapping

import os
import shutil
import json
import logging
from contextlib import asynccontextmanager
from typing import Dict
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA

import config  # Local module; must define GROQ_API_KEY

# Expose the Groq API key to the LangChain/Groq client via the environment.
os.environ["GROQ_API_KEY"] = config.GROQ_API_KEY

# DEBUG-level logging so each startup/indexing step is traceable in Space logs.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Globals populated by ``lifespan`` at startup; None until initialization completes.
rag_chain = None  # RetrievalQA chain used by the /query endpoint
retriever = None  # Chroma retriever backing the chain
session_states: Dict[str, str] = {}  # Store last_disease per session_id
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG pipeline before the app starts serving requests.

    Steps: wipe any stale Chroma index, load the crop-disease QA knowledge
    base, embed it into a fresh vector store, initialize the Groq LLM, and
    assemble a RetrievalQA chain. Populates the module-level ``rag_chain``
    and ``retriever`` globals, then yields for the server's lifetime.

    Raises:
        Exception: re-raised from any failed setup step so FastAPI aborts
        startup instead of serving with a half-built pipeline.
    """
    global rag_chain, retriever
    persist_directory = "/app/data/chroma_crop_rag"
    # Start from a clean vector store so stale embeddings never survive a redeploy.
    if os.path.exists(persist_directory):
        try:
            shutil.rmtree(persist_directory)
            logger.debug("Cleared existing ChromaDB directory: %s", persist_directory)
        except Exception as e:
            logger.error("Error clearing ChromaDB directory: %s", str(e))
            raise
    # Load the QA knowledge base: a JSON list of {"question": ..., "answer": ...}.
    try:
        with open("crop_disease_qa.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        logger.debug("JSON loaded, length: %d", len(data))
    except Exception as e:
        logger.error("Error loading JSON: %s", str(e))
        raise
    # Each answer becomes a retrievable document; its question is kept as metadata.
    documents = [
        Document(page_content=item["answer"], metadata={"question": item["question"]})
        for item in data
    ]
    logger.debug("Documents created: %d", len(documents))
    # Overlapping chunks preserve context across split boundaries.
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
    docs = splitter.split_documents(documents)
    logger.debug("Documents after splitting: %d", len(docs))
    # Embedding model + persistent Chroma vector store, top-6 similarity retrieval.
    try:
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        logger.debug("Embedding model initialized")
        db = Chroma.from_documents(docs, embedding_model, persist_directory=persist_directory)
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 6})
        logger.debug("ChromaDB initialized")
    except Exception as e:
        logger.error("ChromaDB/Embedding error: %s", str(e))
        raise
    # Groq-hosted Llama 3 chat model; moderate temperature for a conversational tone.
    try:
        llm = init_chat_model(
            "llama3-8b-8192",
            model_provider="groq",
            temperature=0.5
        )
        logger.debug("Groq LLM initialized")
    except Exception as e:
        logger.error("Groq LLM initialization error: %s", str(e))
        raise
    # FIX: the previous template was copy-pasted from an e-Shram (labour-portal)
    # assistant. This service answers crop-disease questions, so the persona,
    # example follow-ups, and the follow-up default now match that domain.
    # (Also removed a stray double period at the end of the instructions.)
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You're a friendly crop-health expert helping farmers with crop disease queries. Answer in a warm, conversational tone, like chatting with a neighbor. Keep it clear, engaging, and avoid technical jargon. Use the provided context for accuracy. If a follow-up (e.g., 'how to treat them?', 'what medicines should I use?'), assume it refers to the crop disease discussed earlier unless specified. If the context lacks details, give a practical, general response with actionable tips. Keep answers under 100 words.
Context: {context}
Question: {question}
Answer:
"""
    )
    # "stuff" chain: all retrieved chunks are concatenated into {context}.
    try:
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            chain_type_kwargs={"prompt": prompt_template}
        )
        logger.debug("RAG chain initialized")
    except Exception as e:
        logger.error("RAG chain initialization error: %s", str(e))
        raise
    yield  # FastAPI serves requests while the context manager is suspended here
# Initialize FastAPI; ``lifespan`` builds the RAG pipeline before serving begins.
app = FastAPI(title="Crop Health Assistant API", lifespan=lifespan)
# CORS wide open so any frontend origin can call this API.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests — confirm whether
# credentials (cookies/auth headers) are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# Pydantic request model
class QueryRequest(BaseModel):
    """Payload for POST /query: the user's question plus an optional
    session identifier used for lightweight follow-up memory."""

    query: str
    session_id: str = "default"
# Query endpoint
@app.post("/query")
async def query_crop_health(request: QueryRequest):
    """Answer a crop-health question via the RAG chain.

    Lightweight per-session memory: a recognized bare follow-up question is
    rewritten to mention the last disease discussed in that session. Sending
    the literal query "exit" (case/whitespace-insensitive) ends the session.

    Raises:
        HTTPException: 503 if the RAG chain is not yet initialized,
        500 if chain execution fails.
    """
    query = request.query
    session_id = request.session_id
    normalized = query.strip().lower()
    if normalized == "exit":
        session_states.pop(session_id, None)
        return JSONResponse(content={"message": "Session ended"})
    # Guard against requests arriving before (or after a failed) startup;
    # previously this fell through to an opaque AttributeError -> 500.
    if rag_chain is None:
        raise HTTPException(status_code=503, detail="RAG pipeline not initialized")
    # Rewrite bare follow-ups to reference the disease from earlier in the session.
    modified_query = query
    last_disease = session_states.get(session_id)
    if last_disease and normalized in {
        "how to treat them?", "how to fix it?",
        "how to manage it?", "what medicines should i use?",
    }:
        modified_query = f"What medicines or treatments for {last_disease}?"
    try:
        response = rag_chain.invoke({"query": modified_query})["result"]
        # Simple heuristic to remember the disease under discussion.
        # NOTE(review): hard-coded to a single disease — extend as the KB grows.
        if "blight" in normalized or "potato" in normalized:
            session_states[session_id] = "Early blight in Potato"
        return JSONResponse(content={"question": query, "answer": response})
    except Exception as e:
        logger.error("RAG chain execution error for query '%s': %s", query, str(e))
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
# Session reset endpoint
@app.delete("/reset-session/{session_id}")
async def reset_session(session_id: str):
global session_states
session_states.pop(session_id, None)
return JSONResponse(content={"message": f"Session {session_id} reset"})
# Run FastAPI with Uvicorn
if __name__ == "__main__":
import uvicorn
print("Starting FastAPI server")
uvicorn.run(app, host="0.0.0.0", port=7860)