File size: 6,488 Bytes
0694d44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c13eb80
0694d44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c71fef6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import collections
from collections.abc import MutableMapping
# HACK: some older third-party code still references collections.MutableMapping,
# which was removed from `collections` in Python 3.10 — re-attach it before those
# imports run. Must stay ABOVE the langchain/chromadb imports below.
collections.MutableMapping = MutableMapping  # Patch for deprecated MutableMapping

import os
import shutil
import json
import logging
from contextlib import asynccontextmanager
from typing import Dict

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA

import config  # Ensure config.py has GROQ_API_KEY

# Set environment variable for Groq API key (read by the Groq client inside langchain)
os.environ["GROQ_API_KEY"] = config.GROQ_API_KEY

# Setup logging
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Global variables for RAG components, populated once by the lifespan handler at startup.
rag_chain = None
retriever = None
session_states: Dict[str, str] = {}  # Store last_disease per session_id

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: build the RAG pipeline once at startup.

    Populates the module-level ``rag_chain`` and ``retriever`` globals by:
    rebuilding the Chroma vector store from ``crop_disease_qa.json``,
    initializing the HuggingFace embedding model and the Groq LLM, and
    wiring them into a RetrievalQA chain. Raises on any setup failure so
    the server refuses to start half-configured.
    """
    global rag_chain, retriever
    # NOTE(review): hard-coded container path — assumes a Docker volume at /app/data.
    persist_directory = "/app/data/chroma_crop_rag"

    # Clear existing ChromaDB collection so stale embeddings never leak into a new run.
    if os.path.exists(persist_directory):
        try:
            shutil.rmtree(persist_directory)
            logger.debug("Cleared existing ChromaDB directory: %s", persist_directory)
        except Exception as e:
            logger.error("Error clearing ChromaDB directory: %s", str(e))
            raise

    # Load JSON QA Knowledge Base (list of {"question": ..., "answer": ...} items).
    try:
        with open("crop_disease_qa.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        logger.debug("JSON loaded, length: %d", len(data))
    except Exception as e:
        logger.error("Error loading JSON: %s", str(e))
        raise

    # Convert to Documents: the answer is the retrievable text, the question is metadata.
    documents = [
        Document(page_content=item["answer"], metadata={"question": item["question"]})
        for item in data
    ]
    logger.debug("Documents created: %d", len(documents))

    # Chunk Documents — small chunks with generous overlap suit short QA answers.
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
    docs = splitter.split_documents(documents)
    logger.debug("Documents after splitting: %d", len(docs))

    # Embedding + Vectorstore
    try:
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        logger.debug("Embedding model initialized")
        db = Chroma.from_documents(docs, embedding_model, persist_directory=persist_directory)
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 6})
        logger.debug("ChromaDB initialized")
    except Exception as e:
        logger.error("ChromaDB/Embedding error: %s", str(e))
        raise

    # Groq LLM
    try:
        llm = init_chat_model(
            "llama3-8b-8192",
            model_provider="groq",
            temperature=0.5
        )
        logger.debug("Groq LLM initialized")
    except Exception as e:
        logger.error("Groq LLM initialization error: %s", str(e))
        raise

    # Prompt Template.
    # FIX: the previous prompt was copy-pasted from an unrelated e-Shram/migrant-worker
    # project and told the LLM to answer about the e-Shram portal; reworded for the
    # crop-disease domain this service actually indexes.
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
        You're a friendly expert helping farmers with crop disease questions. Answer in a warm, conversational tone, like chatting with a neighbor. Keep it clear, engaging, and avoid technical jargon. Use the provided context for accuracy. If a follow-up (e.g., 'how to treat it?', 'what medicines should I use?'), assume it refers to the crop disease discussed earlier unless specified. If the context lacks details, give a practical, general response with actionable tips. Keep answers under 100 words.
        Context: {context}

        Question: {question}

        Answer:
        """
    )

    # RAG Chain: "stuff" chain type concatenates all retrieved chunks into one prompt.
    try:
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            chain_type_kwargs={"prompt": prompt_template}
        )
        logger.debug("RAG chain initialized")
    except Exception as e:
        logger.error("RAG chain initialization error: %s", str(e))
        raise

    yield  # FastAPI is now running


# Initialize FastAPI with lifespan
# Initialize FastAPI with lifespan so the RAG pipeline is built before serving requests.
app = FastAPI(title="Crop Health Assistant API", lifespan=lifespan)

# Add CORS middleware.
# NOTE(review): wildcard origins together with allow_credentials=True is rejected by
# browsers for credentialed requests and is overly permissive — confirm whether a
# concrete origin list is needed in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)


# Pydantic request model
# Pydantic request model
class QueryRequest(BaseModel):
    """Body for POST /query."""
    # The user's question; the literal string "exit" (any case) ends the session.
    query: str
    # Conversation identifier used to track follow-up context; callers that omit it share "default".
    session_id: str = "default"


# Query endpoint
# Query endpoint
@app.post("/query")
async def query_crop_health(request: QueryRequest):
    """Answer a crop-health question through the RAG chain.

    Rewrites vague follow-up questions ("how to treat them?") using the last
    disease remembered for the session, runs the RAG chain, and updates the
    per-session disease memory with a simple keyword heuristic.

    Raises:
        HTTPException 503: startup has not finished (rag_chain still None).
        HTTPException 500: the RAG chain itself failed.
    """
    # Strip padding so "exit " / " exit" still end the session; lower-case once.
    query = request.query.strip()
    session_id = request.session_id
    lowered = query.lower()

    if lowered == "exit":
        session_states.pop(session_id, None)
        return JSONResponse(content={"message": "Session ended"})

    # Handle follow-up queries: substitute the remembered disease into a concrete question.
    modified_query = query
    last_disease = session_states.get(session_id)
    follow_ups = {
        "how to treat them?", "how to fix it?",
        "how to manage it?", "what medicines should i use?",
    }
    if last_disease and lowered in follow_ups:
        modified_query = f"What medicines or treatments for {last_disease}?"

    if rag_chain is None:
        # Startup (lifespan) hasn't completed or failed; report clearly instead of
        # crashing with an AttributeError inside the try below.
        raise HTTPException(status_code=503, detail="Service is still initializing")

    try:
        response = rag_chain.invoke({"query": modified_query})["result"]
    except Exception as e:
        logger.error("RAG chain execution error for query '%s': %s", query, str(e))
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")

    # Simple heuristic to update last disease — TODO: replace with real entity extraction.
    if "blight" in lowered or "potato" in lowered:
        session_states[session_id] = "Early blight in Potato"

    return JSONResponse(content={"question": query, "answer": response})


# Session reset endpoint
# Session reset endpoint
@app.delete("/reset-session/{session_id}")
async def reset_session(session_id: str):
    """Forget any stored conversational state for *session_id* (no-op if unknown)."""
    session_states.pop(session_id, None)
    return JSONResponse(content={"message": f"Session {session_id} reset"})


# Run FastAPI with Uvicorn
# Run FastAPI with Uvicorn when executed directly (in containers it is usually
# launched by an external `uvicorn`/`gunicorn` command instead).
if __name__ == "__main__":
    import uvicorn
    print("Starting FastAPI server")
    # 0.0.0.0 binds all interfaces; 7860 matches the Hugging Face Spaces default port.
    uvicorn.run(app, host="0.0.0.0", port=7860)