# Author: LevyJonas
# Last change: commit 90d5913 "Update app.py" (verified)
# app.py (SpaceX-ish styling + background image + your exact logic)
import logging

import gradio as gr
from PIL import Image

# Pipeline (must support: run_search_and_generate(user_imgs=[...], user_prompt=...))
from pipeline import run_search_and_generate, load_from_hf
# --- Quick Starter file paths (must exist in your HF dataset repo) ---
# One example patch each from the LakeWater, DenseForest and ResidentialDense
# folders, given relative to the dataset repo root. A Quick Starter button
# passes its path to run_quickstarter(), which loads it with load_from_hf().
QS_1_PATH, QS_2_PATH, QS_3_PATH = (
    "images/LakeWater/LakeWater_000550.jpg",
    "images/DenseForest/DenseForest_000000.jpg",
    "images/ResidentialDense/ResidentialDense_001050.jpg",
)
# ----------------------------
# Styling
# ----------------------------
# Custom stylesheet passed to gr.Blocks(css=CSS). It defines a colour palette
# as CSS variables and makes the default Gradio background transparent. It also
# pins a fixed, full-viewport background layer (#bg-wrap / #bg-image /
# #bg-overlay) behind the app. Finally, it styles the hero card, the .panel
# columns, text inputs, buttons (#run-btn, #qs-row) and galleries. The
# selectors target the ids/classes emitted by the gr.HTML snippets and the
# elem_id / elem_classes arguments in the UI section.
# NOTE(review): the background artwork is loaded at page-load time from a
# hard-coded huggingface.co URL (see #bg-image); moving that Space or asset
# breaks the image layer.
# NOTE(review): `textarea, input` matches every <input> element (including
# slider, radio and number inputs), and `.gr-gallery` may be a legacy Gradio
# class name — confirm both still target what is intended in the installed
# Gradio version.
CSS = """
:root{
--bg: #070A0F;
--panel: rgba(15, 19, 29, 0.72);
--panel2: rgba(10, 12, 18, 0.55);
--line: rgba(255,255,255,0.12);
--text: rgba(255,255,255,0.94);
--muted: rgba(255,255,255,0.68);
--accent: #7DF9FF;
--accent2: #7C5CFF;
}
/* IMPORTANT: make the default Gradio background transparent so our image shows */
body, .gradio-container{
background: transparent !important;
}
/* Background image layer */
#bg-wrap{
position: fixed;
inset: 0;
z-index: -1; /* put behind everything */
pointer-events: none;
overflow: hidden;
}
#bg-image{
position: absolute;
inset: 0;
background-image:
radial-gradient(1200px 700px at 60% 25%, rgba(125,249,255,0.10), transparent 60%),
radial-gradient(900px 600px at 15% 10%, rgba(124,92,255,0.10), transparent 60%),
url("https://huggingface.co/spaces/LevyJonas/SurfaceChangePredictor/resolve/main/assets/image2.webp");
background-size: cover;
background-position: 65% 80%; /* pushes the Earth into view */
background-repeat: no-repeat;
opacity: 0.95;
filter: brightness(1.25) contrast(1.08) saturate(1.12);
}
/* Light vignette (not too dark) */
#bg-overlay{
position: absolute;
inset: 0;
background: radial-gradient(900px 700px at 50% 30%, rgba(0,0,0,0.05), rgba(0,0,0,0.35));
}
/* App container above background */
#app-wrap{
position: relative;
z-index: 1;
max-width: 1200px;
margin: 0 auto;
padding: 26px 16px 40px;
}
/* Hero card */
.hero{
border: 1px solid var(--line);
background: linear-gradient(180deg, rgba(15,19,29,0.85), rgba(10,12,18,0.55));
border-radius: 16px;
padding: 22px 22px;
box-shadow: 0 10px 40px rgba(0,0,0,0.45);
margin-bottom: 14px;
}
.hero h1{
margin: 0;
font-size: 28px;
letter-spacing: 0.6px;
}
.hero p{
margin: 8px 0 0;
color: var(--muted);
line-height: 1.45;
}
.badges{
margin-top: 12px;
display:flex;
flex-wrap:wrap;
gap:10px;
}
.badge{
border: 1px solid var(--line);
background: rgba(0,0,0,0.25);
padding: 6px 10px;
border-radius: 999px;
font-size: 12px;
color: var(--muted);
}
.badge b{ color: var(--text); }
/* Panels */
.panel{
border: 1px solid var(--line);
background: var(--panel);
border-radius: 16px;
padding: 16px;
box-shadow: 0 10px 40px rgba(0,0,0,0.35);
}
/* Inputs */
textarea, input{
background: rgba(0,0,0,0.25) !important;
color: var(--text) !important;
border: 1px solid var(--line) !important;
border-radius: 12px !important;
}
/* Buttons */
button, .gr-button{
border-radius: 12px !important;
border: 1px solid var(--line) !important;
}
#run-btn{
background: linear-gradient(90deg, rgba(125,249,255,0.95), rgba(124,92,255,0.95)) !important;
color: #05060A !important;
font-weight: 800 !important;
border: none !important;
}
#qs-row button{
background: rgba(255,255,255,0.06) !important;
color: var(--text) !important;
}
#qs-row button:hover{
border-color: rgba(125,249,255,0.45) !important;
}
/* Galleries */
.gr-gallery{
border-radius: 16px !important;
border: 1px solid var(--line) !important;
background: var(--panel2) !important;
}
"""
# ----------------------------
# Core helpers (unchanged logic)
# ----------------------------
def _files_to_pil_list(files, n_use: int):
"""Convert uploaded files to a list of PIL images (use first n_use, capped to 1..4)."""
if not files:
return []
n = max(1, min(4, int(n_use), len(files)))
imgs = []
for f in files[:n]:
path = f.name if hasattr(f, "name") else str(f)
imgs.append(Image.open(path).convert("RGB"))
return imgs
logger = logging.getLogger(__name__)


def _error_outputs(message):
    """Return the four UI outputs for a failed run: three empty galleries plus *message*."""
    return [], [], [], message


def _format_results(retrieved, gen_i2i, gen_t2i, info):
    """Shape raw pipeline output into (gallery, gallery, gallery, summary) for the UI.

    Args:
        retrieved: Iterable of dicts with ``"img"``, ``"label"`` and ``"sim"`` keys.
        gen_i2i: Iterable of img2img images.
        gen_t2i: Iterable of txt2img images.
        info: Mapping rendered as one ``"key: value"`` line per item.
    """
    retr_gallery = [(r["img"], f"{r['label']} | cos={r['sim']:.3f}") for r in retrieved]
    i2i_gallery = [(im, f"img2img #{i}") for i, im in enumerate(gen_i2i, start=1)]
    t2i_gallery = [(im, f"txt2img #{i}") for i, im in enumerate(gen_t2i, start=1)]
    summary = "\n".join(f"{k}: {v}" for k, v in info.items())
    return retr_gallery, i2i_gallery, t2i_gallery, summary


def _search_and_format(user_imgs, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed):
    """Run retrieval + generation on ``user_imgs`` and format the result for the UI.

    Gradio widgets may deliver numeric values as floats (e.g. ``2.0`` from a
    slider). Count-like settings are therefore normalized to ``int``, and the
    strength to ``float``, before they reach the pipeline.
    """
    retrieved, gen_i2i, gen_t2i, info = run_search_and_generate(
        user_imgs=user_imgs,
        user_prompt=prompt,
        k_retrieve=int(k_retrieve),
        n_i2i=int(n_i2i),
        n_t2i=int(n_t2i),
        strength_i2i=float(strength_i2i),
        steps=int(steps),
        gen_size=int(gen_size),
        seed=int(seed),
    )
    return _format_results(retrieved, gen_i2i, gen_t2i, info)


def run_app(files, n_user_imgs, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed):
    """Gradio handler for the "Run" button.

    Validates the uploads, converts them to RGB PIL images and runs the
    retrieval + generation pipeline.

    Args:
        files: Items from the ``gr.Files`` input (may be ``None`` or empty).
        n_user_imgs: Slider value (1–4). Uploading more files than this is
            reported as an error.
        prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed:
            Generation settings forwarded to the pipeline.

    Returns:
        ``(retrieved_gallery, img2img_gallery, txt2img_gallery, summary)``. On
        any failure the galleries are empty and ``summary`` starts with
        ``"Error:"``.
    """
    try:
        if not files:
            return _error_outputs("Error: Please upload at least 1 image.")
        max_allowed = int(n_user_imgs)  # slider value (1–4)
        if len(files) > max_allowed:
            return _error_outputs(
                f"Error: You uploaded {len(files)} images, but the limit is {max_allowed}. Please remove extra files."
            )
        user_imgs = _files_to_pil_list(files, max_allowed)
        return _search_and_format(
            user_imgs, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed
        )
    except Exception as e:
        # UI boundary: show a short message to the user, but keep the full
        # traceback in the server log instead of silently dropping it.
        logger.exception("run_app failed")
        return _error_outputs(f"Error: {e}")


def run_quickstarter(rel_path, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed):
    """Gradio handler for a Quick Starter button.

    Loads the single example image at ``rel_path`` via ``load_from_hf`` and runs
    the same pipeline as ``run_app`` on it. The return shape and error
    convention are identical to ``run_app``.
    """
    try:
        img = load_from_hf(rel_path)
        return _search_and_format(
            [img], prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed
        )
    except Exception as e:
        logger.exception("Quick Starter run failed for %s", rel_path)
        return _error_outputs(f"Error: {e}")
# ----------------------------
# UI (SpaceX-ish)
# ----------------------------
with gr.Blocks(title="Satellite Patch: Retrieve + Generate", css=CSS) as demo:
    # OPTIONAL background image:
    # If you don't have it, comment this HTML block out.
    gr.HTML("""
<div id="bg-wrap">
<div id="bg-image"></div>
<div id="bg-overlay"></div>
</div>
""")
    # Hero section
    # NOTE(review): this wrapper uses id="app-wrap", and the main gr.Row below
    # uses the same elem_id, so the page contains a duplicate id (invalid HTML).
    # Both currently rely on the #app-wrap CSS rules, so renaming one of them
    # needs a matching stylesheet change.
    gr.HTML("""
<div id="app-wrap">
<div class="hero">
<h1>Satellite Patch — Retrieve + Generate</h1>
<p>
Upload up to <b>4 context patches</b> + write a prompt.
The system returns <b>Top-K similar patches</b> from the dataset and <b>new generated variants</b>
(img2img + txt2img). Built for “missing patch” exploration.
</p>
<div class="badges">
<div class="badge"><b>Embeddings:</b> DINOv2-Small</div>
<div class="badge"><b>Generator:</b> SD-Turbo</div>
<div class="badge"><b>Dataset:</b> HF Hub</div>
<div class="badge"><b>Limits:</b> 0–5 outputs</div>
</div>
</div>
</div>
""")
    # Quick Starters + main layout
    with gr.Row(elem_id="app-wrap"):
        # Left panel (controls)
        with gr.Column(scale=1, elem_classes=["panel"]):
            gr.Markdown("### Quick Starters (1-click examples)")
            with gr.Row(elem_id="qs-row"):
                qs1 = gr.Button("LakeWater")
                qs2 = gr.Button("DenseForest")
                qs3 = gr.Button("ResidentialDense")
            gr.Markdown("### Upload + Prompt")
            files = gr.Files(
                label="Upload up to 4 satellite patch images",
                file_types=["image"],
                file_count="multiple",
            )
            n_user_imgs = gr.Slider(1, 4, value=1, step=1, label="How many uploaded images to use (1–4)")
            prompt = gr.Textbox(
                label="Prompt (required for generation)",
                value="Satellite-like RGB patch, realistic remote sensing, top-down view",
                lines=2,
            )
            gr.Markdown("### Output Controls")
            k_retrieve = gr.Slider(0, 5, value=2, step=1, label="# Retrieved images (0–5)")
            n_i2i = gr.Slider(0, 5, value=2, step=1, label="# img2img generated (0–5)")
            n_t2i = gr.Slider(0, 5, value=2, step=1, label="# txt2img generated (0–5)")
            strength_i2i = gr.Slider(0.25, 0.80, value=0.35, step=0.01, label="img2img strength (lower = closer)")
            steps = gr.Slider(1, 2, value=1, step=1, label="steps (1–2)")
            gen_size = gr.Radio([384, 512], value=512, label="generation size")
            seed = gr.Number(value=42, precision=0, label="seed")
            btn = gr.Button("Run", elem_id="run-btn")
        # Right panel (outputs)
        with gr.Column(scale=2, elem_classes=["panel"]):
            gr.Markdown("### Results")
            out_retr = gr.Gallery(label="Retrieved from Dataset", columns=5, height=260)
            out_i2i = gr.Gallery(label="Generated (img2img)", columns=5, height=260)
            out_t2i = gr.Gallery(label="Generated (txt2img)", columns=5, height=260)
            out_txt = gr.Textbox(label="Summary", lines=8)

    # Generation settings and result widgets shared by every run handler.
    gen_controls = [prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed]
    result_outputs = [out_retr, out_i2i, out_t2i, out_txt]

    # Normal run: uploaded files + "how many to use" + shared settings.
    btn.click(
        run_app,
        inputs=[files, n_user_imgs, *gen_controls],
        outputs=result_outputs,
    )

    # Quick Starter runs (1 click). Each button feeds its fixed dataset path
    # through a gr.State, in place of the uploaded files. The State is created
    # eagerly per button, so there is no late-binding issue in the loop.
    for qs_button, qs_path in ((qs1, QS_1_PATH), (qs2, QS_2_PATH), (qs3, QS_3_PATH)):
        qs_button.click(
            run_quickstarter,
            inputs=[gr.State(qs_path), *gen_controls],
            outputs=result_outputs,
        )

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (tests, `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()