agentsay committed on
Commit
0694d44
·
verified ·
1 Parent(s): 72800b1

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +197 -196
api.py CHANGED
@@ -1,197 +1,198 @@
1
- import collections
2
- from collections.abc import MutableMapping
3
- collections.MutableMapping = MutableMapping # Patch for deprecated MutableMapping
4
-
5
- import os
6
- import shutil
7
- import json
8
- import logging
9
- from contextlib import asynccontextmanager
10
- from typing import Dict
11
-
12
- from fastapi import FastAPI, HTTPException
13
- from fastapi.responses import JSONResponse
14
- from fastapi.middleware.cors import CORSMiddleware
15
- from pydantic import BaseModel
16
-
17
- from langchain.chat_models import init_chat_model
18
- from langchain_core.documents import Document
19
- from langchain_core.prompts import PromptTemplate
20
- from langchain_text_splitters import RecursiveCharacterTextSplitter
21
- from langchain_community.vectorstores import Chroma
22
- from langchain_huggingface import HuggingFaceEmbeddings
23
- from langchain.chains import RetrievalQA
24
-
25
- import config # Ensure config.py has GROQ_API_KEY
26
-
27
- # Set environment variable for Groq API key
28
- os.environ["GROQ_API_KEY"] = config.GROQ_API_KEY
29
-
30
- # Setup logging
31
- logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
32
- logger = logging.getLogger(__name__)
33
-
34
- # Global variables for RAG components
35
- rag_chain = None
36
- retriever = None
37
- session_states: Dict[str, str] = {} # Store last_disease per session_id
38
-
39
-
40
- @asynccontextmanager
41
- async def lifespan(app: FastAPI):
42
- global rag_chain, retriever
43
- persist_directory = "/data/chroma_crop_rag"
44
-
45
- # Clear existing ChromaDB collection
46
- if os.path.exists(persist_directory):
47
- try:
48
- shutil.rmtree(persist_directory)
49
- logger.debug("Cleared existing ChromaDB directory: %s", persist_directory)
50
- except Exception as e:
51
- logger.error("Error clearing ChromaDB directory: %s", str(e))
52
- raise
53
-
54
- # Load JSON QA Knowledge Base
55
- try:
56
- with open("crop_disease_qa.json", "r", encoding="utf-8") as f:
57
- data = json.load(f)
58
- logger.debug("JSON loaded, length: %d", len(data))
59
- except Exception as e:
60
- logger.error("Error loading JSON: %s", str(e))
61
- raise
62
-
63
- # Convert to Documents
64
- documents = [
65
- Document(page_content=item["answer"], metadata={"question": item["question"]})
66
- for item in data
67
- ]
68
- logger.debug("Documents created: %d", len(documents))
69
-
70
- # Chunk Documents
71
- splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
72
- docs = splitter.split_documents(documents)
73
- logger.debug("Documents after splitting: %d", len(docs))
74
-
75
- # Embedding + Vectorstore
76
- try:
77
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
78
- logger.debug("Embedding model initialized")
79
- db = Chroma.from_documents(docs, embedding_model, persist_directory=persist_directory)
80
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 6})
81
- logger.debug("ChromaDB initialized")
82
- except Exception as e:
83
- logger.error("ChromaDB/Embedding error: %s", str(e))
84
- raise
85
-
86
- # Groq LLM
87
- try:
88
- llm = init_chat_model(
89
- "llama3-8b-8192",
90
- model_provider="groq",
91
- temperature=0.5
92
- )
93
- logger.debug("Groq LLM initialized")
94
- except Exception as e:
95
- logger.error("Groq LLM initialization error: %s", str(e))
96
- raise
97
-
98
- # Prompt Template
99
- prompt_template = PromptTemplate(
100
- input_variables=["context", "question"],
101
- template="""
102
- You're a friendly agricultural expert helping farmers with crop health.
103
- Answer in a warm, conversational tone, like you're chatting with a neighbor.
104
- Keep it clear, engaging, and avoid overly technical terms.
105
- Use the provided context from the knowledge base to ensure accuracy.
106
- If the question is a follow-up (e.g., 'how to treat them?', 'how to fix it?', 'what medicines should I use?'),
107
- assume it refers to the previously discussed disease or crop (e.g., Early blight in Potato) unless specified otherwise.
108
- If the context doesn't cover the question, provide a practical, general response with actionable tips.
109
- make sure the aize of the answer is not more than 100 words.
110
- Context: {context}
111
-
112
- Question: {question}
113
-
114
- Answer:
115
- """
116
- )
117
-
118
- # RAG Chain
119
- try:
120
- rag_chain = RetrievalQA.from_chain_type(
121
- llm=llm,
122
- retriever=retriever,
123
- chain_type="stuff",
124
- chain_type_kwargs={"prompt": prompt_template}
125
- )
126
- logger.debug("RAG chain initialized")
127
- except Exception as e:
128
- logger.error("RAG chain initialization error: %s", str(e))
129
- raise
130
-
131
- yield # FastAPI is now running
132
-
133
-
134
- # Initialize FastAPI with lifespan
135
- app = FastAPI(title="Crop Health Assistant API", lifespan=lifespan)
136
-
137
- # Add CORS middleware
138
- app.add_middleware(
139
- CORSMiddleware,
140
- allow_origins=["*"], # Allows all origins
141
- allow_credentials=True,
142
- allow_methods=["*"], # Allows all methods
143
- allow_headers=["*"], # Allows all headers
144
- )
145
-
146
-
147
- # Pydantic request model
148
- class QueryRequest(BaseModel):
149
- query: str
150
- session_id: str = "default"
151
-
152
-
153
- # Query endpoint
154
- @app.post("/query")
155
- async def query_crop_health(request: QueryRequest):
156
- global session_states
157
- query = request.query
158
- session_id = request.session_id
159
-
160
- if query.lower() == "exit":
161
- session_states.pop(session_id, None)
162
- return JSONResponse(content={"message": "Session ended"})
163
-
164
- # Handle follow-up queries
165
- modified_query = query
166
- last_disease = session_states.get(session_id)
167
- if last_disease and query.lower() in [
168
- "how to treat them?", "how to fix it?",
169
- "how to manage it?", "what medicines should i use?"
170
- ]:
171
- modified_query = f"What medicines or treatments for {last_disease}?"
172
-
173
- try:
174
- response = rag_chain.invoke({"query": modified_query})["result"]
175
- # Simple heuristic to update last disease
176
- if "blight" in query.lower() or "potato" in query.lower():
177
- session_states[session_id] = "Early blight in Potato"
178
-
179
- return JSONResponse(content={"question": query, "answer": response})
180
- except Exception as e:
181
- logger.error("RAG chain execution error for query '%s': %s", query, str(e))
182
- raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
183
-
184
-
185
- # Session reset endpoint
186
- @app.delete("/reset-session/{session_id}")
187
- async def reset_session(session_id: str):
188
- global session_states
189
- session_states.pop(session_id, None)
190
- return JSONResponse(content={"message": f"Session {session_id} reset"})
191
-
192
-
193
- # Run FastAPI with Uvicorn
194
- if __name__ == "__main__":
195
- import uvicorn
196
- print("Starting FastAPI server")
 
197
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
import collections
from collections.abc import MutableMapping
# collections.MutableMapping was removed from the top-level collections
# module in Python 3.10; re-attach the abc class so older third-party
# code that still references collections.MutableMapping keeps importing.
collections.MutableMapping = MutableMapping  # Patch for deprecated MutableMapping
4
+
5
+ import os
6
+ import shutil
7
+ import json
8
+ import logging
9
+ from contextlib import asynccontextmanager
10
+ from typing import Dict
11
+
12
+ from fastapi import FastAPI, HTTPException
13
+ from fastapi.responses import JSONResponse
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from pydantic import BaseModel
16
+
17
+ from langchain.chat_models import init_chat_model
18
+ from langchain_core.documents import Document
19
+ from langchain_core.prompts import PromptTemplate
20
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
21
+ from langchain_community.vectorstores import Chroma
22
+ from langchain_huggingface import HuggingFaceEmbeddings
23
+ from langchain.chains import RetrievalQA
24
+
25
import config  # Ensure config.py has GROQ_API_KEY

# Set environment variable for Groq API key; the Groq client reads it
# from the environment rather than taking it as an argument here.
os.environ["GROQ_API_KEY"] = config.GROQ_API_KEY

# Setup logging (DEBUG level: this service logs each init step below)
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Global variables for RAG components
rag_chain = None   # RetrievalQA chain; populated by lifespan() at startup
retriever = None   # Chroma retriever; populated by lifespan() at startup
session_states: Dict[str, str] = {}  # Store last_disease per session_id
38
+
39
+
40
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the whole RAG pipeline at startup, then yield to FastAPI.

    Steps: wipe any stale Chroma index, load the JSON QA knowledge base,
    split it into chunks, embed into a fresh Chroma store, initialize the
    Groq LLM, and assemble the RetrievalQA chain into the module-level
    ``rag_chain`` / ``retriever`` globals used by the endpoints.

    Raises on any initialization failure so the server does not start in
    a half-configured state.
    """
    global rag_chain, retriever
    persist_directory = "/app/data/chroma_crop_rag"

    # Clear existing ChromaDB collection so the index is rebuilt from the
    # current knowledge base on every startup (avoids stale/duplicate docs).
    if os.path.exists(persist_directory):
        try:
            shutil.rmtree(persist_directory)
            logger.debug("Cleared existing ChromaDB directory: %s", persist_directory)
        except Exception as e:
            logger.error("Error clearing ChromaDB directory: %s", str(e))
            raise

    # Load JSON QA Knowledge Base (list of {"question": ..., "answer": ...})
    try:
        with open("crop_disease_qa.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        logger.debug("JSON loaded, length: %d", len(data))
    except Exception as e:
        logger.error("Error loading JSON: %s", str(e))
        raise

    # Convert to Documents: the answer is the searchable content, the
    # question is kept as metadata.
    documents = [
        Document(page_content=item["answer"], metadata={"question": item["question"]})
        for item in data
    ]
    logger.debug("Documents created: %d", len(documents))

    # Chunk Documents
    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
    docs = splitter.split_documents(documents)
    logger.debug("Documents after splitting: %d", len(docs))

    # Embedding + Vectorstore
    try:
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        logger.debug("Embedding model initialized")
        db = Chroma.from_documents(docs, embedding_model, persist_directory=persist_directory)
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 6})
        logger.debug("ChromaDB initialized")
    except Exception as e:
        logger.error("ChromaDB/Embedding error: %s", str(e))
        raise

    # Groq LLM
    try:
        llm = init_chat_model(
            "llama3-8b-8192",
            model_provider="groq",
            temperature=0.5
        )
        logger.debug("Groq LLM initialized")
    except Exception as e:
        logger.error("Groq LLM initialization error: %s", str(e))
        raise

    # Prompt Template (fix: "aize" -> "size" in the length instruction so
    # the LLM actually receives an intelligible constraint)
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You're a friendly agricultural expert helping farmers with crop health.
Answer in a warm, conversational tone, like you're chatting with a neighbor.
Keep it clear, engaging, and avoid overly technical terms.
Use the provided context from the knowledge base to ensure accuracy.
If the question is a follow-up (e.g., 'how to treat them?', 'how to fix it?', 'what medicines should I use?'),
assume it refers to the previously discussed disease or crop (e.g., Early blight in Potato) unless specified otherwise.
If the context doesn't cover the question, provide a practical, general response with actionable tips.
Make sure the size of the answer is not more than 100 words.
Context: {context}

Question: {question}

Answer:
"""
    )

    # RAG Chain: "stuff" chain type concatenates retrieved chunks into the
    # prompt's {context} slot.
    try:
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            chain_type_kwargs={"prompt": prompt_template}
        )
        logger.debug("RAG chain initialized")
    except Exception as e:
        logger.error("RAG chain initialization error: %s", str(e))
        raise

    yield  # FastAPI is now running
133
+
134
+
135
# Initialize FastAPI with lifespan (RAG pipeline is built at startup)
app = FastAPI(title="Crop Health Assistant API", lifespan=lifespan)

# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive (and browsers reject credentialed wildcard responses) —
# confirm the intended origin policy before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
146
+
147
+
148
# Pydantic request model
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    query: str  # the user's question; the literal "exit" ends the session
    session_id: str = "default"  # key into session_states for follow-up context
152
+
153
+
154
# Query endpoint
@app.post("/query")
async def query_crop_health(request: QueryRequest):
    """Answer a crop-health question through the RAG chain.

    Special cases:
      * ``query == "exit"`` clears the session state and ends the session.
      * A known follow-up phrasing is rewritten to reference the last
        disease stored for this session.

    Returns JSON with the original question and the generated answer.
    Raises HTTP 503 if the RAG chain is not initialized (startup failed or
    has not finished) and HTTP 500 on chain execution errors.
    """
    global session_states
    query = request.query
    session_id = request.session_id
    # Normalize once: tolerates stray whitespace/case from clients for the
    # "exit" command and the follow-up phrase matching below.
    normalized = query.strip().lower()

    if normalized == "exit":
        session_states.pop(session_id, None)
        return JSONResponse(content={"message": "Session ended"})

    # Guard: without this, a request arriving before lifespan() completed
    # (or after it failed) would crash with AttributeError on None.
    if rag_chain is None:
        raise HTTPException(status_code=503, detail="RAG chain not initialized")

    # Handle follow-up queries by rewriting them against the last disease
    modified_query = query
    last_disease = session_states.get(session_id)
    if last_disease and normalized in (
        "how to treat them?", "how to fix it?",
        "how to manage it?", "what medicines should i use?"
    ):
        modified_query = f"What medicines or treatments for {last_disease}?"

    try:
        response = rag_chain.invoke({"query": modified_query})["result"]
        # Simple heuristic to update last disease
        if "blight" in normalized or "potato" in normalized:
            session_states[session_id] = "Early blight in Potato"

        return JSONResponse(content={"question": query, "answer": response})
    except Exception as e:
        logger.error("RAG chain execution error for query '%s': %s", query, str(e))
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
184
+
185
+
186
# Session reset endpoint
@app.delete("/reset-session/{session_id}")
async def reset_session(session_id: str):
    """Drop any stored follow-up state for the given session id."""
    global session_states
    # Removing a missing session is a no-op, matching the previous
    # pop(..., None) semantics.
    if session_id in session_states:
        del session_states[session_id]
    return JSONResponse(content={"message": f"Session {session_id} reset"})
192
+
193
+
194
# Run FastAPI with Uvicorn when executed directly (not when imported)
if __name__ == "__main__":
    import uvicorn
    print("Starting FastAPI server")
    # Binds on all interfaces; port 7860 — presumably the Hugging Face
    # Spaces default — confirm against the deployment configuration.
    uvicorn.run(app, host="0.0.0.0", port=7860)