Spaces:
Runtime error
Runtime error
Arnab Dey commited on
Commit ·
4680c6a
1
Parent(s): 282197c
Add home point and dot size features to poster generation
Browse files- app.py +65 -1
- create_map_poster.py +30 -0
app.py
CHANGED
|
@@ -26,6 +26,9 @@ MAX_DISTANCE_M = 20000
|
|
| 26 |
DEFAULT_DPI = 300
|
| 27 |
MIN_DPI = 150
|
| 28 |
MAX_DPI = 600
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
class NetworkType(str, Enum):
|
|
@@ -57,6 +60,8 @@ class GenerateRequest(BaseModel):
|
|
| 57 |
dpi: int
|
| 58 |
network_type: NetworkType
|
| 59 |
dist_type: DistanceType
|
|
|
|
|
|
|
| 60 |
|
| 61 |
@field_validator("city", "country", "theme")
|
| 62 |
@classmethod
|
|
@@ -106,6 +111,33 @@ class GenerateRequest(BaseModel):
|
|
| 106 |
except ValueError as exc:
|
| 107 |
raise ValueError("Invalid distance type.") from exc
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
def _load_readme_example_posters() -> list[tuple[str, str]]:
|
| 111 |
"""Return (absolute_path, caption) pairs for example posters referenced in README."""
|
|
@@ -214,6 +246,8 @@ def generate(
|
|
| 214 |
dpi: int,
|
| 215 |
network_type: str,
|
| 216 |
dist_type: str,
|
|
|
|
|
|
|
| 217 |
) -> str:
|
| 218 |
try:
|
| 219 |
request = GenerateRequest(
|
|
@@ -224,10 +258,17 @@ def generate(
|
|
| 224 |
dpi=dpi,
|
| 225 |
network_type=network_type,
|
| 226 |
dist_type=dist_type,
|
|
|
|
|
|
|
| 227 |
)
|
| 228 |
except ValidationError as exc:
|
| 229 |
raise gr.Error(str(exc))
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
available_themes = maptoposter.get_available_themes()
|
| 232 |
if request.theme not in available_themes:
|
| 233 |
raise gr.Error(f"Unknown theme: {request.theme}")
|
|
@@ -251,6 +292,8 @@ def generate(
|
|
| 251 |
network_type=request.network_type.value,
|
| 252 |
dist_type=request.dist_type.value,
|
| 253 |
dpi=request.dpi,
|
|
|
|
|
|
|
| 254 |
)
|
| 255 |
return output_path
|
| 256 |
|
|
@@ -426,6 +469,17 @@ def build_demo() -> gr.Blocks:
|
|
| 426 |
choices=DIST_TYPES,
|
| 427 |
value="bbox",
|
| 428 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
btn = gr.Button("Generate poster", elem_classes=["mtp-primary"])
|
| 431 |
gr.HTML(
|
|
@@ -438,7 +492,17 @@ def build_demo() -> gr.Blocks:
|
|
| 438 |
|
| 439 |
btn.click(
|
| 440 |
generate,
|
| 441 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
outputs=[out],
|
| 443 |
)
|
| 444 |
|
|
|
|
| 26 |
DEFAULT_DPI = 300
|
| 27 |
MIN_DPI = 150
|
| 28 |
MAX_DPI = 600
|
| 29 |
+
DEFAULT_DOT_SIZE = 60
|
| 30 |
+
MIN_DOT_SIZE = 10
|
| 31 |
+
MAX_DOT_SIZE = 300
|
| 32 |
|
| 33 |
|
| 34 |
class NetworkType(str, Enum):
|
|
|
|
| 60 |
dpi: int
|
| 61 |
network_type: NetworkType
|
| 62 |
dist_type: DistanceType
|
| 63 |
+
home_point: str | None = None
|
| 64 |
+
dot_size: float | None = None
|
| 65 |
|
| 66 |
@field_validator("city", "country", "theme")
|
| 67 |
@classmethod
|
|
|
|
| 111 |
except ValueError as exc:
|
| 112 |
raise ValueError("Invalid distance type.") from exc
|
| 113 |
|
| 114 |
+
@field_validator("dot_size", mode="before")
|
| 115 |
+
@classmethod
|
| 116 |
+
def _validate_dot_size(cls, value: str | float | None) -> float | None:
|
| 117 |
+
if value is None or value == "":
|
| 118 |
+
return None
|
| 119 |
+
value = float(value)
|
| 120 |
+
if value < MIN_DOT_SIZE or value > MAX_DOT_SIZE:
|
| 121 |
+
raise ValueError(f"Dot size must be between {MIN_DOT_SIZE} and {MAX_DOT_SIZE}.")
|
| 122 |
+
return value
|
| 123 |
+
|
| 124 |
+
def _parse_home_point(value: str | None) -> tuple[float, float] | None:
|
| 125 |
+
if not value:
|
| 126 |
+
return None
|
| 127 |
+
cleaned = value.strip()
|
| 128 |
+
if not cleaned:
|
| 129 |
+
return None
|
| 130 |
+
parts = [p.strip() for p in cleaned.split(",")]
|
| 131 |
+
if len(parts) != 2:
|
| 132 |
+
raise ValueError("Home point must be in 'lat, lon' format.")
|
| 133 |
+
lat = float(parts[0])
|
| 134 |
+
lon = float(parts[1])
|
| 135 |
+
if lat < -90 or lat > 90:
|
| 136 |
+
raise ValueError("Home point latitude must be between -90 and 90.")
|
| 137 |
+
if lon < -180 or lon > 180:
|
| 138 |
+
raise ValueError("Home point longitude must be between -180 and 180.")
|
| 139 |
+
return (lat, lon)
|
| 140 |
+
|
| 141 |
|
| 142 |
def _load_readme_example_posters() -> list[tuple[str, str]]:
|
| 143 |
"""Return (absolute_path, caption) pairs for example posters referenced in README."""
|
|
|
|
| 246 |
dpi: int,
|
| 247 |
network_type: str,
|
| 248 |
dist_type: str,
|
| 249 |
+
home_point: str | None,
|
| 250 |
+
dot_size: float | None,
|
| 251 |
) -> str:
|
| 252 |
try:
|
| 253 |
request = GenerateRequest(
|
|
|
|
| 258 |
dpi=dpi,
|
| 259 |
network_type=network_type,
|
| 260 |
dist_type=dist_type,
|
| 261 |
+
home_point=home_point,
|
| 262 |
+
dot_size=dot_size,
|
| 263 |
)
|
| 264 |
except ValidationError as exc:
|
| 265 |
raise gr.Error(str(exc))
|
| 266 |
|
| 267 |
+
try:
|
| 268 |
+
dot_coords = _parse_home_point(request.home_point)
|
| 269 |
+
except ValueError as exc:
|
| 270 |
+
raise gr.Error(str(exc))
|
| 271 |
+
|
| 272 |
available_themes = maptoposter.get_available_themes()
|
| 273 |
if request.theme not in available_themes:
|
| 274 |
raise gr.Error(f"Unknown theme: {request.theme}")
|
|
|
|
| 292 |
network_type=request.network_type.value,
|
| 293 |
dist_type=request.dist_type.value,
|
| 294 |
dpi=request.dpi,
|
| 295 |
+
dot=dot_coords,
|
| 296 |
+
dot_size=request.dot_size if request.dot_size is not None else DEFAULT_DOT_SIZE,
|
| 297 |
)
|
| 298 |
return output_path
|
| 299 |
|
|
|
|
| 469 |
choices=DIST_TYPES,
|
| 470 |
value="bbox",
|
| 471 |
)
|
| 472 |
+
home_point = gr.Textbox(
|
| 473 |
+
label="Home point (lat, lon)",
|
| 474 |
+
placeholder="31.3, 2.3",
|
| 475 |
+
)
|
| 476 |
+
dot_size = gr.Slider(
|
| 477 |
+
label="Dot size",
|
| 478 |
+
minimum=MIN_DOT_SIZE,
|
| 479 |
+
maximum=MAX_DOT_SIZE,
|
| 480 |
+
step=5,
|
| 481 |
+
value=DEFAULT_DOT_SIZE,
|
| 482 |
+
)
|
| 483 |
|
| 484 |
btn = gr.Button("Generate poster", elem_classes=["mtp-primary"])
|
| 485 |
gr.HTML(
|
|
|
|
| 492 |
|
| 493 |
btn.click(
|
| 494 |
generate,
|
| 495 |
+
inputs=[
|
| 496 |
+
city,
|
| 497 |
+
country,
|
| 498 |
+
theme,
|
| 499 |
+
distance,
|
| 500 |
+
dpi,
|
| 501 |
+
network_type,
|
| 502 |
+
dist_type,
|
| 503 |
+
home_point,
|
| 504 |
+
dot_size,
|
| 505 |
+
],
|
| 506 |
outputs=[out],
|
| 507 |
)
|
| 508 |
|
create_map_poster.py
CHANGED
|
@@ -369,6 +369,8 @@ def create_poster(
|
|
| 369 |
network_type: str = "all",
|
| 370 |
dist_type: str = "bbox",
|
| 371 |
dpi: int = 300,
|
|
|
|
|
|
|
| 372 |
) -> None:
|
| 373 |
print(f"\nGenerating map for {city}, {country}...")
|
| 374 |
theme = _require_theme()
|
|
@@ -431,6 +433,34 @@ def create_poster(
|
|
| 431 |
edge_linewidth=edge_widths,
|
| 432 |
show=False, close=False
|
| 433 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
# Layer 3: Gradients (Top and Bottom)
|
| 436 |
create_gradient_fade(ax, theme.gradient_color, location="bottom", zorder=10)
|
|
|
|
| 369 |
network_type: str = "all",
|
| 370 |
dist_type: str = "bbox",
|
| 371 |
dpi: int = 300,
|
| 372 |
+
dot: Coordinates | Sequence[float] | None = None,
|
| 373 |
+
dot_size: float = 60,
|
| 374 |
) -> None:
|
| 375 |
print(f"\nGenerating map for {city}, {country}...")
|
| 376 |
theme = _require_theme()
|
|
|
|
| 433 |
edge_linewidth=edge_widths,
|
| 434 |
show=False, close=False
|
| 435 |
)
|
| 436 |
+
|
| 437 |
+
# Optional highlight pin
|
| 438 |
+
if dot is not None:
|
| 439 |
+
pin_coords = _coerce_coordinates(dot)
|
| 440 |
+
ylim = ax.get_ylim()
|
| 441 |
+
y_range = ylim[1] - ylim[0]
|
| 442 |
+
offset = y_range * 0.0025
|
| 443 |
+
|
| 444 |
+
head_lat = pin_coords.lat + offset * 0.6
|
| 445 |
+
tip_lat = pin_coords.lat - offset * 0.6
|
| 446 |
+
|
| 447 |
+
ax.scatter(
|
| 448 |
+
[pin_coords.lon],
|
| 449 |
+
[head_lat],
|
| 450 |
+
s=float(dot_size),
|
| 451 |
+
c="#FF2D2D",
|
| 452 |
+
edgecolors="none",
|
| 453 |
+
zorder=8,
|
| 454 |
+
)
|
| 455 |
+
ax.scatter(
|
| 456 |
+
[pin_coords.lon],
|
| 457 |
+
[tip_lat],
|
| 458 |
+
s=float(dot_size) * 0.9,
|
| 459 |
+
c="#FF2D2D",
|
| 460 |
+
marker="v",
|
| 461 |
+
edgecolors="none",
|
| 462 |
+
zorder=7,
|
| 463 |
+
)
|
| 464 |
|
| 465 |
# Layer 3: Gradients (Top and Bottom)
|
| 466 |
create_gradient_fade(ax, theme.gradient_color, location="bottom", zorder=10)
|