# Source: main.py, uploaded by shreyankisiri (commit 967892a, verified)
from utils.llms import llm
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel,Field
from typing import *
import os
import uuid
import json
from langgraph.graph import StateGraph,START,END
from datetime import datetime
class DemoState(BaseModel):
    """State object passed between the nodes of the planning graph.

    ``response`` defaults to ``None`` and ``output_files_path`` to an empty
    list, so a state can be built from just a ``query``. Callers that pass
    every field explicitly (as ``call_agent`` does) behave exactly as before.
    """

    # The natural-language geospatial task supplied by the user.
    query: str = Field(description="Query given by the user")
    # Parsed plan produced by the LLM chain; None until reasoning has run.
    response: Optional[Any] = Field(
        default=None,
        description="response given by the AI",
    )
    # Paths of files written by the pipeline (e.g. the saved JSON plan).
    output_files_path: Optional[List[str]] = Field(
        default_factory=list,
        description="The list of output file paths ",
    )
# Planner prompt fed to ChatPromptTemplate.from_template. `{query}` is the only
# template variable; the literal braces in the JSON example are doubled
# (`{{` / `}}`) so the template engine does not treat them as variables.
template = '''
You are a Geospatial AI Task Planner designed to break down high-level geospatial analysis problems into small, manageable sub-tasks. Your goal is to create a robust, modular pipeline that solves the user's request step-by-step using available tools and external APIs if necessary.
You are equipped with the following tools:
1. **DEMFetcher**: Retrieves Digital Elevation Model (DEM) data for terrain, slope, and elevation analysis.
2. **RainfallDataFetcher**: Fetches historical and real-time rainfall datasets (e.g., GPM, IMD).
3. **HydrologyAnalyzer**: Generates hydrological outputs such as flow direction, stream network, flow accumulation, and watershed delineation from DEM.
4. **InfrastructureExtractor**: Extracts vector layers of infrastructure elements like roads, buildings, power lines, bridges.
5. **DrainageExtractor**: Extracts drainage networks (natural or man-made) from shapefiles or OSM data.
6. **FloodHistoryFetcher**: Accesses past flood maps or raster datasets (e.g., satellite-based flood extents).
7. **LULCClassifier**: Provides land use/land cover raster data for the region of interest.
8. **PopulationDensityMapper**: Maps population density from datasets like WorldPop or census raster layers.
9. **HospitalLocator**: Provides nearest hospital raster/tile (GeoTIFF) for accessibility/rescue layer analysis.
10. **BBOX_Boundary_Generator**: Provides the bounding box (bbox) and latitude/longitude for any given place on the globe.
11. **OSM_API**: Fetches any geospatial data not covered above from OpenStreetMap (e.g., water bodies, critical facilities).
---
Now, take the following **user-defined geospatial task**:
**"{query}"**
Your job is to:
1. Break this complex task into atomic, clear, and minimal steps.
2. For each step, identify the most appropriate tool from the list above.
3. If no direct tool exists, intelligently assign either "OSM_API" or "LLM Reasoning" to infer data or logic.
4. Return the response as a list of JSON objects, each with an ID, Task Description, and Tool to Use.
The output format must be:
[
{{"id": 1, "task": "Brief task description", "tool": "ToolName or LLM Reasoning"}},
{{"id": 2, "task": "Next task description", "tool": "ToolName"}},
...
]
Ensure:
- Each task is independent and logically sequenced.
- All geospatial workflows like preprocessing, data cleaning, raster/vector conversion, masking, etc., are accounted for as separate steps if needed.
- The final steps should focus on combining results and producing decision-ready maps or datasets.
Start planning now.
'''
# Planner pipeline: render the prompt with the user's query, send it to the
# LLM, then parse the model's reply as JSON.
prompt = ChatPromptTemplate.from_template(template)
chain = (
    prompt
    | llm
    | JsonOutputParser()
)
def reasoning_node(state: DemoState):
    """Graph node: ask the planner chain to decompose ``state.query``.

    The parsed JSON is printed together with its type (debug output) and
    returned as a partial state update under the ``response`` key.
    """
    plan = chain.invoke({"query": state.query})
    print(plan, type(plan))
    return {"response": plan}
def write_to_json(state: DemoState):
    """Graph node: save ``state.response`` as indented JSON under ``outputs/``.

    The file name combines the current local time with a 6-hex-digit random
    suffix, which makes collisions between runs unlikely. Returns a partial
    state update whose ``output_files_path`` is a one-element list holding
    the path of the new file.
    """
    target_dir = "outputs"
    os.makedirs(target_dir, exist_ok=True)  # create on first use; no error if present

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix = uuid.uuid4().hex[:6]
    target = os.path.join(target_dir, f"geospatial_plan_{stamp}_{suffix}.json")

    with open(target, "w") as fh:
        json.dump(state.response, fh, indent=4)
    print(f"Saved plan to {target}")

    return {"output_files_path": [target]}
# Linear two-step graph: START -> reasoning_node -> write_to_json -> END.
builder = StateGraph(DemoState)

# Node names are given explicitly; they match the functions' __name__,
# which is what add_node would infer if only the callable were passed.
builder.add_node("reasoning_node", reasoning_node)
builder.add_node("write_to_json", write_to_json)

builder.add_edge(START, "reasoning_node")
builder.add_edge("reasoning_node", "write_to_json")
builder.add_edge("write_to_json", END)

agent = builder.compile()
def call_agent(query):
    """Run the planning graph for ``query`` and return the graph's result.

    The initial state starts with no response and no output files. The
    returned value is whatever ``agent.invoke`` produces. NOTE(review): for a
    compiled LangGraph this is presumably a mapping of state values rather
    than a ``DemoState`` instance; confirm before relying on attribute access.
    """
    initial_state = DemoState(
        query=query,
        response=None,
        output_files_path=[],
    )
    return agent.invoke(initial_state)