milliyin commited on
Commit
0b6319d
Β·
verified Β·
1 Parent(s): 82d6b64

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import time
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ from gradio_client import Client
10
+ from PIL import Image
11
+
12
+ # ───────── Logging ─────────
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # ───────── Constants ─────────
17
+ PREDICT_TIMEOUT = 600 # ⏱ increase backend timeout (seconds)
18
+ GPU_WARM_WINDOW = 900 # πŸ•’ keep GPU alive 15β€―min (must match backend)
19
+ MAX_RETRIES = 1 # πŸ”„ avoid multiple backend calls per click
20
+
21
+ # ───────── Backend connection ─────────
22
+ # HF_TOKEN = os.getenv("HF_TOKEN")
23
+ # if not HF_TOKEN:
24
+ # raise ValueError("HF_TOKEN environment variable is required")
25
+ #-------
26
+
27
+ backend_status = {
28
+ "client": None,
29
+ "connected": False,
30
+ "last_check": None,
31
+ "error_message": ""
32
+ }
33
+
34
+
35
+ def check_backend_connection():
36
+ """Ping the HF Space and cache the client object."""
37
+ try:
38
+ test_client = Client("SnapwearAI/SAKS-backend", hf_token=HF_TOKEN)
39
+ backend_status.update(
40
+ {
41
+ "client": test_client,
42
+ "connected": True,
43
+ "error_message": "",
44
+ "last_check": time.time(),
45
+ }
46
+ )
47
+ logger.info("βœ… Backend connection established")
48
+ return True, "🟒 Model is ready"
49
+ except Exception as e:
50
+ backend_status.update(
51
+ {
52
+ "client": None,
53
+ "connected": False,
54
+ "last_check": time.time(),
55
+ "error_message": str(e),
56
+ }
57
+ )
58
+ err = str(e).lower()
59
+ if "timeout" in err or "read operation timed out" in err:
60
+ return False, "🟑 Model is starting up. Please wait 3‑4β€―min."
61
+ return False, f"πŸ”΄ Backend error: {e}"
62
+
63
+
64
+ # initial probe
65
+ # check_backend_connection()
66
+
67
+ # ───────── Helpers ─────────
68
+
69
+
70
+ def image_to_base64(image: Image.Image) -> str:
71
+ if image is None:
72
+ return ""
73
+ if image.mode != "RGB":
74
+ image = image.convert("RGB")
75
+ buf = io.BytesIO()
76
+ image.save(buf, format="PNG")
77
+ return base64.b64encode(buf.getvalue()).decode()
78
+
79
+
80
+ def base64_to_image(b64: str) -> Image.Image | None:
81
+ if not b64:
82
+ return None
83
+ try:
84
+ return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
85
+ except Exception as e:
86
+ logger.error(f"Failed to decode base64 β†’ image: {e}")
87
+ return None
88
+
89
+
90
+ # ───────── UI ↔ Backend bridge ─────────
91
+
92
+
93
+ def call_backend_with_retry(input_image: Image.Image, category: str, gender: str, *, max_retries: int = MAX_RETRIES):
94
+ """Single‑shot call (no more than `max_retries` times)."""
95
+
96
+ if input_image is None:
97
+ return None, None, "❌ Please upload an image."
98
+
99
+ # lazy reconnect
100
+ if not backend_status["connected"]:
101
+ ok, msg = check_backend_connection()
102
+ if not ok:
103
+ return None, None, msg
104
+
105
+ client: Client = backend_status["client"]
106
+ img_b64 = image_to_base64(input_image)
107
+
108
+ for attempt in range(max_retries):
109
+ try:
110
+ logger.info(f"Backend callΒ #{attempt+1}")
111
+ start = time.time()
112
+ result = client.predict(
113
+ img_b64,
114
+ category,
115
+ gender,
116
+ api_name="/predict",
117
+ )
118
+ dt = time.time() - start
119
+
120
+ if not result or len(result) < 4:
121
+ raise ValueError("Invalid response structure from backend")
122
+
123
+ _, overlay_b64, bg_b64, status = result
124
+ overlay_img = base64_to_image(overlay_b64)
125
+ bg_img = base64_to_image(bg_b64)
126
+
127
+ if overlay_img is None or bg_img is None:
128
+ raise ValueError("Failed to decode backend images")
129
+
130
+ if not status.startswith("βœ…"):
131
+ status = "βœ… " + status
132
+ status += f" (⏱ {dt:.1f}s)"
133
+ return overlay_img, bg_img, status
134
+
135
+ except Exception as e:
136
+ logger.error(f"AttemptΒ {attempt+1} failed: {e}")
137
+ if attempt == max_retries - 1:
138
+ return None, None, f"❌ {e}"
139
+ time.sleep(1)
140
+
141
+ return None, None, "❌ Unknown error"
142
+
143
+
144
+ # ───────── CSS ─────────
145
+ CSS_PATH = Path(__file__).with_name("style.css")
146
+ CUSTOM_CSS = CSS_PATH.read_text() if CSS_PATH.exists() else "" # external CSS if you prefer
147
+
148
+
149
+ # ───────── Gradio Blocks ─────────
150
+ with gr.Blocks(css=CUSTOM_CSS, title="SAKS Product Photography Demo") as demo:
151
+ # Hero
152
+ gr.HTML(
153
+ """
154
+ <div class="hero-section">
155
+ <h1 style="font-size:48px;margin:0;background:linear-gradient(45deg,#fff,#f0f8ff);-webkit-background-clip:text;-webkit-text-fill-color:transparent;">πŸ’Β SAKSΒ ProductΒ Photography</h1>
156
+ <p style="font-size:18px;margin:10px 0;opacity:0.8;">Detect jewellery β†’ segment β†’ generate pro model photo</p>
157
+ </div>
158
+ """
159
+ )
160
+
161
+ # Status banner
162
+ status_html = gr.HTML()
163
+
164
+ def _update_status():
165
+ ok, msg = check_backend_connection()
166
+ cls = "status-ready" if ok else ("status-starting" if "🟑" in msg else "status-error")
167
+ return f'<div class="status-banner {cls}">{msg}</div>'
168
+
169
+ status_html.value = _update_status()
170
+ gr.Button("πŸ”„ CheckΒ Status").click(fn=_update_status, outputs=status_html)
171
+
172
+ # 2‑column layout
173
+ with gr.Row():
174
+ # left inputs
175
+ with gr.Column():
176
+ input_img = gr.Image(label="Upload image", type="pil", height=400)
177
+ category = gr.Dropdown(label="Jewellery category", choices=["Rings", "Bracelets", "Watches", "Earrings"], value="Bracelets")
178
+ gender = gr.Dropdown(label="Model gender", choices=["male", "female"], value="female")
179
+ run_btn = gr.Button("🎯 Generate", elem_id="button")
180
+
181
+ # right outputs
182
+ with gr.Column():
183
+ with gr.Tabs():
184
+ with gr.TabItem("Detection overlay"):
185
+ out_overlay = gr.Image(height=400)
186
+ with gr.TabItem("Final result"):
187
+ out_bg = gr.Image(height=400)
188
+ out_status = gr.Text(label="Status", interactive=False)
189
+
190
+ # wire button β†’ backend
191
+ run_btn.click(
192
+ fn=call_backend_with_retry,
193
+ inputs=[input_img, category, gender],
194
+ outputs=[out_overlay, out_bg, out_status],
195
+ concurrency_limit=1,
196
+ show_progress=True,
197
+ )
198
+
199
+
200
+ # ───────── Launch ─────────
201
+ if __name__ == "__main__":
202
+ demo.queue(max_size=20, default_concurrency_limit=1).launch(share=False)