roamify / scripts / warmup_fast.py
jofaichow's picture
v0.0.9 — Full cache sweep + adaptive radius fix
83adb51
#!/usr/bin/env python3
"""
Fast warmup — generates LLM data for missing combos only.
Skips the slow sequential image fix; get_recommendations already does parallel enrichment.
"""
import os, sys, time, json
from datetime import datetime
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from dotenv import load_dotenv
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), "..", ".env"), override=True)
from services.recommender import (
get_recommendations_cached,
_LLM_CACHE,
_IMAGE_CACHE,
_GEOCODE_CACHE,
)
# Destinations to pre-warm, listed one per line so additions diff cleanly.
CITIES = [
    "Paris",
    "London",
    "Rome",
    "Barcelona",
    "New York",
    "Tokyo",
    "Bangkok",
    "Sydney",
    "Cape Town",
    "Rio de Janeiro",
    "Istanbul",
    "Dubai",
    "Seoul",
    "Bali",
    "Prague",
    "San Francisco",
    "Marrakech",
    "Kyoto",
]

# Recommendation categories; the warmup sweeps every (city, category) pair.
CATEGORIES = [
    "Landmark",
    "Culture",
    "Nature",
    "Gems",
    "Photo",
    "Food",
    "Shopping",
]

# Progress record lives one directory above this script (the project root).
_SCRIPT_DIR = os.path.dirname(__file__)
PROGRESS_FILE = os.path.join(_SCRIPT_DIR, "..", ".warmup_progress.json")
def cat_dict(cat_name: str) -> dict:
    """Return a category-selection map with only *cat_name* switched on.

    Every entry of CATEGORIES becomes a key (in CATEGORIES order); its value
    is True exactly when it equals *cat_name*. A name not in CATEGORIES
    yields an all-False map.
    """
    flags = (name == cat_name for name in CATEGORIES)
    return dict(zip(CATEGORIES, flags))
def cat_hash(cat_name: str) -> str:
    """Serialize the single-category selection for *cat_name* to canonical JSON.

    Keys are sorted, so the string does not depend on CATEGORIES order. The
    script pairs it with the city name when probing ``_LLM_CACHE``.
    """
    selection = cat_dict(cat_name)
    canonical = json.dumps(selection, sort_keys=True)
    return canonical
def load_progress(path: "str | None" = None) -> dict:
    """Load the warmup progress record, falling back to an empty one.

    Args:
        path: JSON file to read; defaults to PROGRESS_FILE.

    Returns:
        The stored record, guaranteed to be a dict whose ``"combos"`` value is
        a dict. A missing, unreadable, non-UTF-8, malformed, or wrongly-shaped
        file yields a fresh ``{"version": 1, "combos": {}}``, so callers can
        always index ``progress["combos"]``.
    """
    if path is None:
        path = PROGRESS_FILE
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # FileNotFoundError is an OSError; JSONDecodeError and
        # UnicodeDecodeError are both ValueErrors.
        return {"version": 1, "combos": {}}
    # Valid JSON of the wrong shape (e.g. a list) would otherwise crash the
    # caller's progress["combos"] lookups.
    if not isinstance(data, dict) or not isinstance(data.get("combos"), dict):
        return {"version": 1, "combos": {}}
    return data
def save_progress(progress: dict, path: "str | None" = None) -> None:
    """Persist *progress* as indented JSON, atomically.

    The data is written to a sibling ``.tmp`` file, which ``os.replace`` then
    moves over the target. An interruption or serialization error mid-write
    therefore leaves the previous file intact. Without this, the file would
    be truncated, and a truncated file is treated as malformed and discarded
    on the next load.

    Args:
        progress: The progress record to write (must be JSON-serializable).
        path: Destination file; defaults to PROGRESS_FILE.

    Raises:
        TypeError / ValueError: if *progress* is not JSON-serializable.
        OSError: if the file cannot be written or replaced.
    """
    if path is None:
        path = PROGRESS_FILE
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(progress, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file (best effort), then let the caller
        # see the original failure.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def combo_id(city: str, cat: str) -> str:
    """Build the progress-file key for a (city, category) pair, e.g. ``"Paris::Food"``."""
    return "{}::{}".format(city, cat)
def is_done(progress: dict, cid: str) -> bool:
    """Return True if combo *cid* is recorded as successful in *progress*.

    Always returns a real bool. A missing or non-dict ``"combos"`` section, a
    missing entry, or a non-dict entry all count as "not done" rather than
    raising.
    """
    combos = progress.get("combos")
    if not isinstance(combos, dict):
        return False
    entry = combos.get(cid)
    return isinstance(entry, dict) and entry.get("status") == "success"
def _warm_combo(city: str, cat: str) -> dict:
    """Run one (city, category) warmup call and return its progress-file entry.

    Prints each provider attempt recorded in ``provider_log`` and a final
    outcome line. Any exception from the call, or from reporting on it, is
    caught and recorded as a failure so that one bad combo does not abort
    the whole sweep.
    """
    start = time.time()
    provider_log = []
    try:
        result = get_recommendations_cached(
            city=city, num_attractions=19,
            categories=cat_dict(cat),
            temperature=0,
            provider_log=provider_log,
        )
        elapsed = time.time() - start
        for step in provider_log:
            label = step.get("provider", "?")
            mark = "✅" if step.get("status") == "success" else "❌"
            step_items = step.get("items", 0)
            dur = step.get("elapsed", "?")
            print(f"\n {label} {mark} {dur}s ({step_items}it)", end="", flush=True)
        if result:
            items = len(result)
            print(f"\n✅ {items} items, {elapsed:.0f}s total")
            return {
                "status": "success", "items": items,
                "elapsed": round(elapsed, 1),
                "provider_chain": provider_log,
                "timestamp": datetime.now().isoformat(),
            }
        # Distinguish "nothing came back" from "an empty result came back" so
        # the log and progress file do not misreport an empty list as None.
        if result is None:
            reason, error = "returned None", "all providers returned None"
        else:
            reason, error = "returned no items", "no items returned"
        print(f"\n❌ {reason}, {elapsed:.0f}s total")
        return {
            "status": "failed", "elapsed": round(elapsed, 1),
            "provider_chain": provider_log,
            "error": error,
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as e:
        elapsed = time.time() - start
        print(f"\n❌ {elapsed:.0f}s — {e}")
        return {
            "status": "failed", "elapsed": round(elapsed, 1),
            # Whatever provider_log captured before the exception (may be empty).
            "provider_chain": provider_log,
            "error": str(e), "timestamp": datetime.now().isoformat(),
        }


progress = load_progress()
# Tolerate a progress record that lacks a "combos" section.
combos = progress.setdefault("combos", {})
llm_before = len(_LLM_CACHE)

# Only process combos that actually need LLM generation.
todo = []
cached_marked = 0
for city in CITIES:
    for cat in CATEGORIES:
        cid = combo_id(city, cat)
        if is_done(progress, cid):
            continue
        if (city, cat_hash(cat)) in _LLM_CACHE:
            # In cache but not in progress: record it as done so the progress
            # file and the summary below account for it.
            combos[cid] = {
                "status": "success",
                "source": "cache",
                "timestamp": datetime.now().isoformat(),
            }
            cached_marked += 1
            continue
        todo.append((city, cat))

if cached_marked:
    save_progress(progress)
    print(f"Marked {cached_marked} already-cached combos as done")

total = len(todo)
print(f"Missing combos needing API calls: {total}")
print()

for i, (city, cat) in enumerate(todo, 1):
    cid = combo_id(city, cat)
    print(f"[{i}/{total}] 🔍 {city} / {cat}...", end=" ", flush=True)
    combos[cid] = _warm_combo(city, cat)
    # Persist after every combo; successful combos are skipped on the next
    # run, so an interrupted sweep resumes where it stopped.
    save_progress(progress)
    if i < total:
        time.sleep(1.5)  # Nominatim-friendly pause

# Summary: counts cover every combo in the progress file, not just this run.
success = sum(1 for v in combos.values() if v.get("status") == "success")
failed = sum(1 for v in combos.values() if v.get("status") == "failed")
new_llm = len(_LLM_CACHE) - llm_before
print("\n" + "=" * 50)
print(f"Done! {success} success, {failed} failed, {new_llm} new cache entries")
failed_combos = [k for k, v in combos.items() if v.get("status") == "failed"]
if failed_combos:
    print("Failed combos:")
    for c in failed_combos:
        print(f" ❌ {c.replace('::', ' / ')}")