Spaces:
Runtime error
Runtime error
Suchinthana
commited on
Commit
Β·
37cd808
1
Parent(s):
bc61fb2
Minimizing
Browse files
app.py
CHANGED
|
@@ -12,43 +12,23 @@ from diffusers import StableDiffusionInpaintPipeline
|
|
| 12 |
import spaces
|
| 13 |
import logging
|
| 14 |
import math
|
| 15 |
-
from typing import List, Union
|
| 16 |
|
| 17 |
# Set up logging
|
| 18 |
-
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
-
logger.info("Script starting. Initializing APIs and models.")
|
| 22 |
-
|
| 23 |
# Initialize APIs
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
logger.info("OpenAI client initialized.")
|
| 27 |
-
except KeyError:
|
| 28 |
-
logger.error("OPENAI_API_KEY environment variable not set!")
|
| 29 |
-
# Handle this critical error, perhaps exit or raise
|
| 30 |
-
raise
|
| 31 |
-
except Exception as e:
|
| 32 |
-
logger.error(f"Error initializing OpenAI client: {e}")
|
| 33 |
-
raise
|
| 34 |
-
|
| 35 |
-
try:
|
| 36 |
-
geolocator = Nominatim(user_agent="geoapi_visualizemap") # More specific user agent
|
| 37 |
-
logger.info("Geolocator initialized.")
|
| 38 |
-
except Exception as e:
|
| 39 |
-
logger.error(f"Error initializing Geolocator: {e}")
|
| 40 |
-
raise
|
| 41 |
|
| 42 |
# Function to fetch coordinates
|
| 43 |
@spaces.GPU
|
| 44 |
def get_geo_coordinates(location_name):
|
| 45 |
-
logger.info(f"Attempting to fetch coordinates for: {location_name}")
|
| 46 |
try:
|
| 47 |
-
location = geolocator.geocode(location_name
|
| 48 |
if location:
|
| 49 |
-
logger.info(f"Coordinates found for {location_name}: {[location.longitude, location.latitude]}")
|
| 50 |
return [location.longitude, location.latitude]
|
| 51 |
-
logger.warning(f"No location data returned for {location_name}")
|
| 52 |
return None
|
| 53 |
except Exception as e:
|
| 54 |
logger.error(f"Error fetching coordinates for {location_name}: {e}")
|
|
@@ -57,14 +37,12 @@ def get_geo_coordinates(location_name):
|
|
| 57 |
# Function to process OpenAI chat response
|
| 58 |
@spaces.GPU
|
| 59 |
def process_openai_response(query):
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
"role": "system",
|
| 67 |
-
"content": """
|
| 68 |
You are an assistant that generates structured JSON output for geographical queries with city names. Your task is to generate a JSON object containing information about geographical features and their representation based on the user's query. Follow these rules:
|
| 69 |
|
| 70 |
1. The JSON should always have the following structure:
|
|
@@ -115,410 +93,207 @@ You are an assistant that generates structured JSON output for geographical quer
|
|
| 115 |
|
| 116 |
Generate similar JSON for the following query:
|
| 117 |
"""
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
logger.info(f"Raw OpenAI response content: {content}")
|
| 133 |
-
parsed_response = json.loads(content)
|
| 134 |
-
logger.info(f"Parsed OpenAI response: {json.dumps(parsed_response, indent=2)}")
|
| 135 |
-
return parsed_response
|
| 136 |
-
except Exception as e:
|
| 137 |
-
logger.error(f"Error processing OpenAI response for query '{query}': {e}")
|
| 138 |
-
# Consider returning a default error structure or re-raising
|
| 139 |
-
raise
|
| 140 |
|
| 141 |
# Generate GeoJSON from OpenAI response
|
| 142 |
@spaces.GPU
|
| 143 |
-
def generate_geojson(
|
| 144 |
-
logger.info(f"
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
|
|
|
| 153 |
coord = get_geo_coordinates(city)
|
| 154 |
if coord:
|
| 155 |
coordinates.append(coord)
|
| 156 |
else:
|
| 157 |
-
logger.warning(f"Coordinates not found for city: {city}
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
if
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
geojson_data = {
|
| 185 |
-
"type": "FeatureCollection",
|
| 186 |
-
"features": [
|
| 187 |
-
{
|
| 188 |
-
"type": "Feature",
|
| 189 |
-
"properties": properties,
|
| 190 |
-
"geometry": {
|
| 191 |
-
"type": feature_type,
|
| 192 |
-
"coordinates": final_coordinates,
|
| 193 |
-
},
|
| 194 |
-
}
|
| 195 |
-
],
|
| 196 |
-
}
|
| 197 |
-
logger.info(f"Generated GeoJSON: {json.dumps(geojson_data, indent=2)}")
|
| 198 |
-
return geojson_data
|
| 199 |
-
except KeyError as e:
|
| 200 |
-
logger.error(f"KeyError while generating GeoJSON: {e}. Response data: {json.dumps(response_data, indent=2)}")
|
| 201 |
-
raise
|
| 202 |
-
except ValueError as e:
|
| 203 |
-
logger.error(f"ValueError while generating GeoJSON: {e}. Coordinates: {coordinates if 'coordinates' in locals() else 'N/A'}")
|
| 204 |
-
raise
|
| 205 |
-
except Exception as e:
|
| 206 |
-
logger.error(f"Unexpected error in generate_geojson: {e}")
|
| 207 |
-
raise
|
| 208 |
|
| 209 |
# Sort coordinates for a simple polygon (Reduce intersection points)
|
| 210 |
def sort_coordinates_for_simple_polygon(geojson):
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
return math.atan2(dy, dx)
|
| 239 |
-
|
| 240 |
-
sorted_plot_coordinates = sorted(plot_coordinates, key=angle_from_centroid)
|
| 241 |
-
sorted_plot_coordinates.append(sorted_plot_coordinates[0]) # Close the polygon
|
| 242 |
-
|
| 243 |
-
geojson['features'][0]['geometry']['coordinates'][0] = sorted_plot_coordinates
|
| 244 |
-
logger.info(f"Sorted polygon coordinates: {sorted_plot_coordinates}")
|
| 245 |
-
return geojson
|
| 246 |
-
except Exception as e:
|
| 247 |
-
logger.error(f"Error sorting polygon coordinates: {e}")
|
| 248 |
-
return geojson # Return original on error
|
| 249 |
|
| 250 |
# Generate static map image
|
| 251 |
@spaces.GPU
|
| 252 |
def generate_static_map(geojson_data, invisible=False):
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
else
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
# Coords for MultiPoint is a list of [lon, lat]
|
| 271 |
-
for coord_pair in coords:
|
| 272 |
-
if coord_pair and len(coord_pair) == 2 and isinstance(coord_pair[0], (int, float)):
|
| 273 |
-
m.add_marker(CircleMarker((coord_pair[0], coord_pair[1]), color, 20 if invisible else 10))
|
| 274 |
-
else:
|
| 275 |
-
logger.warning(f"Skipping point in MultiPoint due to invalid coordinate structure: {coord_pair}")
|
| 276 |
-
elif geom_type == "LineString":
|
| 277 |
-
# Coords for LineString is a list of [lon, lat]
|
| 278 |
-
if len(coords) >=2:
|
| 279 |
-
m.add_line(Polygon([(c[0], c[1]) for c in coords], "blue", 3)) # For LineString, use add_line or thicker Polygon outline
|
| 280 |
-
else:
|
| 281 |
-
logger.warning(f"Skipping LineString, not enough points: {coords}")
|
| 282 |
-
elif geom_type == "Polygon":
|
| 283 |
-
# Coords for Polygon is a list containing one list of [lon, lat] (the exterior ring)
|
| 284 |
-
for polygon_ring in coords: # Should be only one for simple polygon
|
| 285 |
-
if len(polygon_ring) >= 3:
|
| 286 |
-
m.add_polygon(Polygon([(c[0], c[1]) for c in polygon_ring], color, '#0000AA' if not invisible else '#1C00ff00', 3 if not invisible else 0))
|
| 287 |
-
else:
|
| 288 |
-
logger.warning(f"Skipping polygon ring, not enough points: {polygon_ring}")
|
| 289 |
-
# Add handling for MultiLineString, MultiPolygon if your OpenAI might produce them
|
| 290 |
-
else:
|
| 291 |
-
logger.warning(f"Unsupported geometry type for static map: {geom_type}")
|
| 292 |
-
|
| 293 |
-
rendered_map = m.render(center=None, zoom=None) # Let it auto-center and zoom
|
| 294 |
-
logger.info(f"Static map rendered successfully. Invisible: {invisible}")
|
| 295 |
-
return rendered_map
|
| 296 |
-
except Exception as e:
|
| 297 |
-
logger.error(f"Error generating static map (invisible={invisible}): {e}")
|
| 298 |
-
# Return a placeholder or re-raise
|
| 299 |
-
return Image.new("RGB", (600, 600), color="grey") # Placeholder
|
| 300 |
|
| 301 |
# ControlNet pipeline setup
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
except Exception as e:
|
| 315 |
-
logger.error(f"Error initializing Stable Diffusion pipeline: {e}")
|
| 316 |
-
raise
|
| 317 |
-
|
| 318 |
-
# This function was for ControlNet, may not be needed as-is for StableDiffusionInpaintPipeline
|
| 319 |
-
# It expects init_image to be a NumPy array, and mask_image a NumPy array
|
| 320 |
@spaces.GPU
|
| 321 |
-
def make_inpaint_condition(
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
init_image_np = np.array(init_image_pil.convert("RGB")).astype(np.float32) / 255.0
|
| 325 |
-
mask_image_np = np.array(mask_image_pil.convert("L")).astype(np.float32) / 255.0 # Ensure mask is L
|
| 326 |
-
|
| 327 |
-
logger.info(f"Init image shape: {init_image_np.shape}, Mask image shape: {mask_image_np.shape}")
|
| 328 |
-
|
| 329 |
-
if init_image_np.shape[:2] != mask_image_np.shape[:2]:
|
| 330 |
-
logger.error(f"Image and mask dimensions mismatch: {init_image_np.shape[:2]} vs {mask_image_np.shape[:2]}")
|
| 331 |
-
# Resize mask to match image if necessary, or raise error
|
| 332 |
-
# For now, let's assume they should match and this is an error state
|
| 333 |
-
raise ValueError("Image and mask_image must have the same height and width.")
|
| 334 |
-
|
| 335 |
-
# This operation is specific to how some ControlNet inpainting expects masked areas.
|
| 336 |
-
# Standard SDInpaintPipeline might not need this.
|
| 337 |
-
# init_image_np[mask_image_np > 0.5] = -1.0 # set as masked pixel
|
| 338 |
-
|
| 339 |
-
# init_image_np = np.expand_dims(init_image_np, 0).transpose(0, 3, 1, 2)
|
| 340 |
-
# init_image_tensor = torch.from_numpy(init_image_np)
|
| 341 |
-
# logger.info(f"Processed init_image tensor shape: {init_image_tensor.shape}")
|
| 342 |
-
# return init_image_tensor
|
| 343 |
-
|
| 344 |
-
# For StableDiffusionInpaintPipeline, `image` and `mask_image` are passed directly as PIL Images or tensors.
|
| 345 |
-
# The `make_inpaint_condition` might be redundant if you are not using a ControlNet that specifically requires this format.
|
| 346 |
-
# If you were using ControlNet, this would be the control_image.
|
| 347 |
-
# For now, let's assume it's meant to be the 'image' input for SD Inpaint, preprocessed.
|
| 348 |
-
return init_image_pil # Or init_image_tensor if pipeline expects tensor
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
@spaces.GPU
|
| 352 |
-
def generate_satellite_image(
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
# prompt=prompt,
|
| 365 |
-
# image=base_image_pil, # or tensor version if pipeline prefers
|
| 366 |
-
# mask_image=mask_image_pil, # or tensor version
|
| 367 |
-
# control_image=control_image_tensor, # This is for ControlNet
|
| 368 |
-
# strength=0.47, # strength might be called differently or not used in SD Inpaint
|
| 369 |
-
# guidance_scale=9.5, # Adjusted scale
|
| 370 |
-
# num_inference_steps=50 # Adjusted steps
|
| 371 |
-
# ).images[0]
|
| 372 |
-
|
| 373 |
-
# For StableDiffusionInpaintPipeline:
|
| 374 |
-
result = pipeline(
|
| 375 |
-
prompt=prompt,
|
| 376 |
-
image=base_image_pil, # PIL Image or PyTorch tensor
|
| 377 |
-
mask_image=mask_image_pil, # PIL Image or PyTorch tensor
|
| 378 |
-
guidance_scale=9.5, # More reasonable default
|
| 379 |
-
num_inference_steps=50 # More reasonable default
|
| 380 |
-
).images[0]
|
| 381 |
-
|
| 382 |
-
logger.info("Satellite image generated successfully.")
|
| 383 |
-
return result
|
| 384 |
-
except Exception as e:
|
| 385 |
-
logger.error(f"Error generating satellite image: {e}")
|
| 386 |
-
return Image.new("RGB", base_image_pil.size, color="red") # Placeholder
|
| 387 |
|
| 388 |
# Gradio UI
|
| 389 |
@spaces.GPU
|
| 390 |
-
def handle_query(query
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
empty_map_image
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
threshold = 10 # May need adjustment
|
| 416 |
-
mask_array = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
|
| 417 |
-
mask_image = Image.fromarray(mask_array, mode="L")
|
| 418 |
-
logger.info(f"handle_query: Mask image generated: type={type(mask_image)}")
|
| 419 |
-
|
| 420 |
-
prompt_for_image = openai_response['output']['feature_representation']['properties']['description']
|
| 421 |
-
logger.info(f"handle_query: Prompt for satellite image: '{prompt_for_image}', type={type(prompt_for_image)}")
|
| 422 |
-
|
| 423 |
-
# Pass empty_map_image (which is the base map without visible markers)
|
| 424 |
-
# and the derived mask_image to the inpainting function
|
| 425 |
-
satellite_image = generate_satellite_image(
|
| 426 |
-
empty_map_image, mask_image, prompt_for_image
|
| 427 |
-
)
|
| 428 |
-
logger.info(f"handle_query: Satellite image generated: type={type(satellite_image)}")
|
| 429 |
-
|
| 430 |
-
# Ensure all returned image types are PIL Images
|
| 431 |
-
final_map_image = map_image if isinstance(map_image, Image.Image) else Image.new("RGB", (600,600), "grey")
|
| 432 |
-
final_satellite_image = satellite_image if isinstance(satellite_image, Image.Image) else Image.new("RGB", (600,600), "red")
|
| 433 |
-
final_empty_map_image = empty_map_image if isinstance(empty_map_image, Image.Image) else Image.new("RGB", (600,600), "grey")
|
| 434 |
-
final_mask_image = mask_image if isinstance(mask_image, Image.Image) else Image.new("L", (600,600), 0)
|
| 435 |
-
|
| 436 |
-
logger.info(f"handle_query: Returning types: {type(final_map_image)}, {type(final_satellite_image)}, {type(final_empty_map_image)}, {type(final_mask_image)}, {type(prompt_for_image)}")
|
| 437 |
-
return final_map_image, final_satellite_image, final_empty_map_image, final_mask_image, prompt_for_image
|
| 438 |
|
| 439 |
-
except Exception as e:
|
| 440 |
-
logger.error(f"--- Error in handle_query for query '{query}': {e} ---", exc_info=True)
|
| 441 |
-
# Return placeholder/error images and message
|
| 442 |
-
error_img = Image.new("RGB", (600, 600), "black")
|
| 443 |
-
error_text_img = ImageDraw.Draw(error_img)
|
| 444 |
-
error_text_img.text((10,10), f"Error: {e}", fill="white")
|
| 445 |
-
return error_img, error_img, error_img, error_img, f"Error processing query: {e}"
|
| 446 |
-
|
| 447 |
-
def update_query(selected_query_value: str) -> str: # Added type hints
|
| 448 |
-
logger.info(f"Dropdown changed. Selected query: '{selected_query_value}', type: {type(selected_query_value)}")
|
| 449 |
-
return selected_query_value
|
| 450 |
-
|
| 451 |
-
logger.info("Defining Gradio UI components.")
|
| 452 |
query_options = [
|
| 453 |
"Area covering south asian subcontinent",
|
| 454 |
-
"Mark a triangular area using New York, Boston, and Texas",
|
| 455 |
"Mark cities in India",
|
| 456 |
"Show me Lotus Tower in a Map",
|
| 457 |
"Mark the area of west germany",
|
| 458 |
"Mark the area of the Amazon rainforest",
|
| 459 |
"Mark the area of the Sahara desert"
|
| 460 |
-
]
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
query_input = gr.Textbox(label="Enter Query", value=str(query_options[-1])) # Ensure value is string
|
| 477 |
-
logger.info(f"query_input Textbox defined. Initial value: '{query_options[-1]}', type: {type(query_options[-1])}")
|
| 478 |
-
|
| 479 |
-
# The `change` event should not cause the schema error, but good to log
|
| 480 |
-
selected_query.change(fn=update_query, inputs=selected_query, outputs=query_input)
|
| 481 |
-
logger.info("selected_query.change event defined.")
|
| 482 |
-
|
| 483 |
-
submit_btn = gr.Button("Submit")
|
| 484 |
-
logger.info("submit_btn Button defined.")
|
| 485 |
-
|
| 486 |
-
with gr.Row():
|
| 487 |
-
logger.info("Defining second gr.Row for image outputs.")
|
| 488 |
-
map_output = gr.Image(label="Map Visualization") # No initial value needed here, will be populated by function
|
| 489 |
-
logger.info("map_output Image defined.")
|
| 490 |
-
satellite_output = gr.Image(label="Generated Map Image")
|
| 491 |
-
logger.info("satellite_output Image defined.")
|
| 492 |
-
|
| 493 |
-
with gr.Row():
|
| 494 |
-
logger.info("Defining third gr.Row for debug outputs.")
|
| 495 |
-
empty_map_output = gr.Image(label="Empty Visualization")
|
| 496 |
-
logger.info("empty_map_output Image defined.")
|
| 497 |
-
mask_output = gr.Image(label="Mask")
|
| 498 |
-
logger.info("mask_output Image defined.")
|
| 499 |
-
# For image_prompt, provide a default string value or None. An empty string is fine.
|
| 500 |
-
image_prompt_output = gr.Textbox(label="Image Prompt Used", value="") # Changed name to avoid conflict, ensure string value
|
| 501 |
-
logger.info(f"image_prompt_output Textbox defined. Initial value: '', type: str")
|
| 502 |
-
|
| 503 |
-
# The outputs list must match the number and expected types of what handle_query returns.
|
| 504 |
-
# handle_query returns: PIL.Image, PIL.Image, PIL.Image, PIL.Image, str
|
| 505 |
-
# Gradio components: gr.Image, gr.Image, gr.Image, gr.Image, gr.Textbox
|
| 506 |
-
# This mapping looks correct.
|
| 507 |
-
submit_btn.click(fn=handle_query,
|
| 508 |
-
inputs=[query_input],
|
| 509 |
-
outputs=[map_output, satellite_output, empty_map_output, mask_output, image_prompt_output])
|
| 510 |
-
logger.info("submit_btn.click event defined.")
|
| 511 |
-
logger.info("Gradio Blocks defined successfully.")
|
| 512 |
-
|
| 513 |
-
except Exception as e:
|
| 514 |
-
logger.error(f"Error during Gradio UI definition: {e}", exc_info=True)
|
| 515 |
-
raise
|
| 516 |
|
| 517 |
if __name__ == "__main__":
|
| 518 |
-
|
| 519 |
-
try:
|
| 520 |
-
demo.launch() # debug=True can sometimes give more frontend info, but not for this backend error
|
| 521 |
-
logger.info("Gradio demo launched.")
|
| 522 |
-
except Exception as e:
|
| 523 |
-
logger.error(f"Error launching Gradio demo: {e}", exc_info=True)
|
| 524 |
-
raise
|
|
|
|
| 12 |
import spaces
|
| 13 |
import logging
|
| 14 |
import math
|
| 15 |
+
from typing import List, Union
|
| 16 |
|
| 17 |
# Set up logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
|
|
|
|
|
|
| 21 |
# Initialize APIs
|
| 22 |
+
openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
|
| 23 |
+
geolocator = Nominatim(user_agent="geoapi")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Function to fetch coordinates
|
| 26 |
@spaces.GPU
|
| 27 |
def get_geo_coordinates(location_name):
|
|
|
|
| 28 |
try:
|
| 29 |
+
location = geolocator.geocode(location_name)
|
| 30 |
if location:
|
|
|
|
| 31 |
return [location.longitude, location.latitude]
|
|
|
|
| 32 |
return None
|
| 33 |
except Exception as e:
|
| 34 |
logger.error(f"Error fetching coordinates for {location_name}: {e}")
|
|
|
|
| 37 |
# Function to process OpenAI chat response
|
| 38 |
@spaces.GPU
|
| 39 |
def process_openai_response(query):
|
| 40 |
+
response = openai_client.chat.completions.create(
|
| 41 |
+
model="gpt-4o-mini",
|
| 42 |
+
messages=[
|
| 43 |
+
{
|
| 44 |
+
"role": "system",
|
| 45 |
+
"content": """
|
|
|
|
|
|
|
| 46 |
You are an assistant that generates structured JSON output for geographical queries with city names. Your task is to generate a JSON object containing information about geographical features and their representation based on the user's query. Follow these rules:
|
| 47 |
|
| 48 |
1. The JSON should always have the following structure:
|
|
|
|
| 93 |
|
| 94 |
Generate similar JSON for the following query:
|
| 95 |
"""
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"role": "user",
|
| 99 |
+
"content": query
|
| 100 |
+
}
|
| 101 |
+
],
|
| 102 |
+
temperature=1,
|
| 103 |
+
max_tokens=2048,
|
| 104 |
+
top_p=1,
|
| 105 |
+
frequency_penalty=0,
|
| 106 |
+
presence_penalty=0,
|
| 107 |
+
response_format={"type": "json_object"}
|
| 108 |
+
)
|
| 109 |
+
return json.loads(response.choices[0].message.content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
# Generate GeoJSON from OpenAI response
|
| 112 |
@spaces.GPU
|
| 113 |
+
def generate_geojson(response):
|
| 114 |
+
logger.info(f"OpenAI response: {response}")
|
| 115 |
+
feature_type = response['output']['feature_representation']['type']
|
| 116 |
+
city_names = response['output']['feature_representation']['cities']
|
| 117 |
+
properties = response['output']['feature_representation']['properties']
|
| 118 |
+
|
| 119 |
+
coordinates = []
|
| 120 |
|
| 121 |
+
# Fetch coordinates for cities
|
| 122 |
+
for city in city_names:
|
| 123 |
+
try:
|
| 124 |
coord = get_geo_coordinates(city)
|
| 125 |
if coord:
|
| 126 |
coordinates.append(coord)
|
| 127 |
else:
|
| 128 |
+
logger.warning(f"Coordinates not found for city: {city}")
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.error(f"Error fetching coordinates for {city}: {e}")
|
| 131 |
+
|
| 132 |
+
if feature_type == "Polygon":
|
| 133 |
+
if len(coordinates) < 3:
|
| 134 |
+
raise ValueError("Polygon requires at least 3 coordinates.")
|
| 135 |
+
# Close the polygon by appending the first point at the end
|
| 136 |
+
coordinates.append(coordinates[0])
|
| 137 |
+
coordinates = [coordinates] # Nest coordinates for Polygon
|
| 138 |
+
|
| 139 |
+
# Create the GeoJSON object
|
| 140 |
+
geojson_data = {
|
| 141 |
+
"type": "FeatureCollection",
|
| 142 |
+
"features": [
|
| 143 |
+
{
|
| 144 |
+
"type": "Feature",
|
| 145 |
+
"properties": properties,
|
| 146 |
+
"geometry": {
|
| 147 |
+
"type": feature_type,
|
| 148 |
+
"coordinates": coordinates,
|
| 149 |
+
},
|
| 150 |
+
}
|
| 151 |
+
],
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
return geojson_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
# Sort coordinates for a simple polygon (Reduce intersection points)
|
| 157 |
def sort_coordinates_for_simple_polygon(geojson):
|
| 158 |
+
# Extract coordinates from the GeoJSON
|
| 159 |
+
coordinates = geojson['features'][0]['geometry']['coordinates'][0]
|
| 160 |
+
|
| 161 |
+
# Remove the last point if it duplicates the first (GeoJSON convention for polygons)
|
| 162 |
+
if coordinates[0] == coordinates[-1]:
|
| 163 |
+
coordinates = coordinates[:-1]
|
| 164 |
+
|
| 165 |
+
# Calculate the centroid of the points
|
| 166 |
+
centroid_x = sum(point[0] for point in coordinates) / len(coordinates)
|
| 167 |
+
centroid_y = sum(point[1] for point in coordinates) / len(coordinates)
|
| 168 |
+
|
| 169 |
+
# Define a function to calculate the angle relative to the centroid
|
| 170 |
+
def angle_from_centroid(point):
|
| 171 |
+
dx = point[0] - centroid_x
|
| 172 |
+
dy = point[1] - centroid_y
|
| 173 |
+
return math.atan2(dy, dx)
|
| 174 |
+
|
| 175 |
+
# Sort points by their angle from the centroid
|
| 176 |
+
sorted_coordinates = sorted(coordinates, key=angle_from_centroid)
|
| 177 |
+
|
| 178 |
+
# Close the polygon by appending the first point to the end
|
| 179 |
+
sorted_coordinates.append(sorted_coordinates[0])
|
| 180 |
+
|
| 181 |
+
# Update the GeoJSON with sorted coordinates
|
| 182 |
+
geojson['features'][0]['geometry']['coordinates'][0] = sorted_coordinates
|
| 183 |
+
|
| 184 |
+
return geojson
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
# Generate static map image
|
| 187 |
@spaces.GPU
|
| 188 |
def generate_static_map(geojson_data, invisible=False):
|
| 189 |
+
m = StaticMap(600, 600)
|
| 190 |
+
logger.info(f"GeoJSON data: {geojson_data}")
|
| 191 |
+
|
| 192 |
+
for feature in geojson_data["features"]:
|
| 193 |
+
geom_type = feature["geometry"]["type"]
|
| 194 |
+
coords = feature["geometry"]["coordinates"]
|
| 195 |
+
|
| 196 |
+
if geom_type == "Point":
|
| 197 |
+
m.add_marker(CircleMarker((coords[0][0], coords[0][1]), '#1C00ff00' if invisible else '#42445A85', 100))
|
| 198 |
+
elif geom_type in ["MultiPoint", "LineString"]:
|
| 199 |
+
for coord in coords:
|
| 200 |
+
m.add_marker(CircleMarker((coord[0], coord[1]), '#1C00ff00' if invisible else '#42445A85', 100))
|
| 201 |
+
elif geom_type in ["Polygon", "MultiPolygon"]:
|
| 202 |
+
for polygon in coords:
|
| 203 |
+
m.add_polygon(Polygon([(c[0], c[1]) for c in polygon], '#1C00ff00' if invisible else '#42445A85', 3))
|
| 204 |
+
|
| 205 |
+
return m.render()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
# ControlNet pipeline setup
|
| 208 |
+
# controlnet = ControlNetModel.from_pretrained("stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16)
|
| 209 |
+
# pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
| 210 |
+
# "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
|
| 211 |
+
# )
|
| 212 |
+
# pipeline.to('cuda')
|
| 213 |
+
|
| 214 |
+
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 215 |
+
"stabilityai/stable-diffusion-2-inpainting",
|
| 216 |
+
torch_dtype=torch.float16,
|
| 217 |
+
)
|
| 218 |
+
pipeline.to("cuda")
|
| 219 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
@spaces.GPU
|
| 221 |
+
def make_inpaint_condition(init_image, mask_image):
|
| 222 |
+
init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
|
| 223 |
+
mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
assert init_image.shape[0:1] == mask_image.shape[0:1], "image and image_mask must have the same image size"
|
| 226 |
+
init_image[mask_image > 0.5] = -1.0 # set as masked pixel
|
| 227 |
+
init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2)
|
| 228 |
+
init_image = torch.from_numpy(init_image)
|
| 229 |
+
return init_image
|
| 230 |
|
| 231 |
@spaces.GPU
|
| 232 |
+
def generate_satellite_image(init_image, mask_image, prompt):
|
| 233 |
+
control_image = make_inpaint_condition(init_image, mask_image)
|
| 234 |
+
result = pipeline(
|
| 235 |
+
prompt=prompt,
|
| 236 |
+
image=init_image,
|
| 237 |
+
mask_image=mask_image,
|
| 238 |
+
control_image=control_image,
|
| 239 |
+
strength=0.47,
|
| 240 |
+
guidance_scale=95,
|
| 241 |
+
num_inference_steps=250
|
| 242 |
+
)
|
| 243 |
+
return result.images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
# Gradio UI
|
| 246 |
@spaces.GPU
|
| 247 |
+
def handle_query(query):
|
| 248 |
+
response = process_openai_response(query)
|
| 249 |
+
geojson_data = generate_geojson(response)
|
| 250 |
+
|
| 251 |
+
if geojson_data["features"][0]["geometry"]["type"] == 'Polygon':
|
| 252 |
+
geojson_data_coords = sort_coordinates_for_simple_polygon(geojson_data)
|
| 253 |
+
map_image = generate_static_map(geojson_data_coords)
|
| 254 |
+
else:
|
| 255 |
+
map_image = generate_static_map(geojson_data)
|
| 256 |
+
empty_map_image = generate_static_map(geojson_data, invisible=True)
|
| 257 |
+
|
| 258 |
+
difference = np.abs(np.array(map_image.convert("RGB")) - np.array(empty_map_image.convert("RGB")))
|
| 259 |
+
threshold = 10
|
| 260 |
+
mask = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
|
| 261 |
+
|
| 262 |
+
mask_image = Image.fromarray(mask, mode="L")
|
| 263 |
+
satellite_image = generate_satellite_image(
|
| 264 |
+
empty_map_image, mask_image, response['output']['feature_representation']['properties']['description']
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
return map_image, satellite_image, empty_map_image, mask_image, response
|
| 268 |
+
#return map_image, satellite_image, empty_map_image, mask_image, response['output']['feature_representation']['properties']['description']
|
| 269 |
+
|
| 270 |
+
def update_query(selected_query):
|
| 271 |
+
return [selected_query]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
query_options = [
|
| 274 |
"Area covering south asian subcontinent",
|
| 275 |
+
"Mark a triangular area using New York, Boston, and Texas",
|
| 276 |
"Mark cities in India",
|
| 277 |
"Show me Lotus Tower in a Map",
|
| 278 |
"Mark the area of west germany",
|
| 279 |
"Mark the area of the Amazon rainforest",
|
| 280 |
"Mark the area of the Sahara desert"
|
| 281 |
+
]
|
| 282 |
+
|
| 283 |
+
with gr.Blocks() as demo:
|
| 284 |
+
with gr.Row():
|
| 285 |
+
selected_query = gr.Dropdown(label="Select Query", choices=query_options, value=query_options[-1])
|
| 286 |
+
query_input = gr.Textbox(label="Enter Query", value=query_options[-1])
|
| 287 |
+
selected_query.change(update_query, inputs=selected_query, outputs=query_input)
|
| 288 |
+
submit_btn = gr.Button("Submit")
|
| 289 |
+
with gr.Row():
|
| 290 |
+
map_output = gr.Image(label="Map Visualization")
|
| 291 |
+
satellite_output = gr.Image(label="Generated Map Image")
|
| 292 |
+
with gr.Row():
|
| 293 |
+
empty_map_output = gr.Image(label="Empty Visualization")
|
| 294 |
+
mask_output = gr.Image(label="Mask")
|
| 295 |
+
image_prompt = gr.Textbox(label="Image Prompt Used")
|
| 296 |
+
submit_btn.click(handle_query, inputs=[query_input], outputs=[map_output, satellite_output, empty_map_output, mask_output, image_prompt])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
if __name__ == "__main__":
|
| 299 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|