nemoooooooooo commited on
Commit
43376d9
Β·
1 Parent(s): bbb8f23
Files changed (2) hide show
  1. app.py +202 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import time
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ from gradio_client import Client
10
+ from PIL import Image
11
+
12
+ # ───────── Logging ─────────
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # ───────── Constants ─────────
17
+ PREDICT_TIMEOUT = 600 # ⏱ increase backend timeout (seconds)
18
+ GPU_WARM_WINDOW = 900 # πŸ•’ keep GPU alive 15β€―min (must match backend)
19
+ MAX_RETRIES = 1 # πŸ”„ avoid multiple backend calls per click
20
+
21
+ # ───────── Backend connection ─────────
22
+ HF_TOKEN = os.getenv("HF_TOKEN")
23
+ if not HF_TOKEN:
24
+ raise ValueError("HF_TOKEN environment variable is required")
25
+
26
+ backend_status = {
27
+ "client": None,
28
+ "connected": False,
29
+ "last_check": None,
30
+ "error_message": ""
31
+ }
32
+
33
+
34
+ def check_backend_connection():
35
+ """Ping the HF Space and cache the client object."""
36
+ try:
37
+ test_client = Client("SnapwearAI/SAKS-backend", hf_token=HF_TOKEN)
38
+ backend_status.update(
39
+ {
40
+ "client": test_client,
41
+ "connected": True,
42
+ "error_message": "",
43
+ "last_check": time.time(),
44
+ }
45
+ )
46
+ logger.info("βœ… Backend connection established")
47
+ return True, "🟒 Model is ready"
48
+ except Exception as e:
49
+ backend_status.update(
50
+ {
51
+ "client": None,
52
+ "connected": False,
53
+ "last_check": time.time(),
54
+ "error_message": str(e),
55
+ }
56
+ )
57
+ err = str(e).lower()
58
+ if "timeout" in err or "read operation timed out" in err:
59
+ return False, "🟑 Model is starting up. Please wait 3‑4β€―min."
60
+ return False, f"πŸ”΄ Backend error: {e}"
61
+
62
+
63
+ # initial probe
64
+ check_backend_connection()
65
+
66
+ # ───────── Helpers ─────────
67
+
68
+
69
+ def image_to_base64(image: Image.Image) -> str:
70
+ if image is None:
71
+ return ""
72
+ if image.mode != "RGB":
73
+ image = image.convert("RGB")
74
+ buf = io.BytesIO()
75
+ image.save(buf, format="PNG")
76
+ return base64.b64encode(buf.getvalue()).decode()
77
+
78
+
79
+ def base64_to_image(b64: str) -> Image.Image | None:
80
+ if not b64:
81
+ return None
82
+ try:
83
+ return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
84
+ except Exception as e:
85
+ logger.error(f"Failed to decode base64 β†’ image: {e}")
86
+ return None
87
+
88
+
89
+ # ───────── UI ↔ Backend bridge ─────────
90
+
91
+
92
+ def call_backend_with_retry(input_image: Image.Image, category: str, gender: str, *, max_retries: int = MAX_RETRIES):
93
+ """Single‑shot call (no more than `max_retries` times)."""
94
+
95
+ if input_image is None:
96
+ return None, None, "❌ Please upload an image."
97
+
98
+ # lazy reconnect
99
+ if not backend_status["connected"]:
100
+ ok, msg = check_backend_connection()
101
+ if not ok:
102
+ return None, None, msg
103
+
104
+ client: Client = backend_status["client"]
105
+ img_b64 = image_to_base64(input_image)
106
+
107
+ for attempt in range(max_retries):
108
+ try:
109
+ logger.info(f"Backend call #{attempt+1}")
110
+ start = time.time()
111
+ result = client.predict(
112
+ img_b64,
113
+ category,
114
+ gender,
115
+ api_name="/predict",
116
+ )
117
+ dt = time.time() - start
118
+
119
+ if not result or len(result) < 4:
120
+ raise ValueError("Invalid response structure from backend")
121
+
122
+ _, overlay_b64, bg_b64, status = result
123
+ overlay_img = base64_to_image(overlay_b64)
124
+ bg_img = base64_to_image(bg_b64)
125
+
126
+ if overlay_img is None or bg_img is None:
127
+ raise ValueError("Failed to decode backend images")
128
+
129
+ if not status.startswith("βœ…"):
130
+ status = "βœ… " + status
131
+ status += f" (⏱ {dt:.1f}s)"
132
+ return overlay_img, bg_img, status
133
+
134
+ except Exception as e:
135
+ logger.error(f"Attempt {attempt+1} failed: {e}")
136
+ if attempt == max_retries - 1:
137
+ return None, None, f"❌ {e}"
138
+ time.sleep(1)
139
+
140
+ return None, None, "❌ Unknown error"
141
+
142
+
143
+ # ───────── CSS ─────────
144
+ CSS_PATH = Path(__file__).with_name("style.css")
145
+ CUSTOM_CSS = CSS_PATH.read_text() if CSS_PATH.exists() else "" # external CSS if you prefer
146
+
147
+
148
+ # ───────── Gradio Blocks ─────────
149
+ with gr.Blocks(css=CUSTOM_CSS, title="SAKS Product Photography Demo") as demo:
150
+ # Hero
151
+ gr.HTML(
152
+ """
153
+ <div class="hero-section">
154
+ <h1 style="font-size:48px;margin:0;background:linear-gradient(45deg,#fff,#f0f8ff);-webkit-background-clip:text;-webkit-text-fill-color:transparent;">πŸ’ SAKS Product Photography</h1>
155
+ <p style="font-size:18px;margin:10px 0;opacity:0.8;">Detect jewellery β†’ segment β†’ generate pro model photo</p>
156
+ </div>
157
+ """
158
+ )
159
+
160
+ # Status banner
161
+ status_html = gr.HTML()
162
+
163
+ def _update_status():
164
+ ok, msg = check_backend_connection()
165
+ cls = "status-ready" if ok else ("status-starting" if "🟑" in msg else "status-error")
166
+ return f'<div class="status-banner {cls}">{msg}</div>'
167
+
168
+ status_html.value = _update_status()
169
+ gr.Button("πŸ”„ Check Status").click(fn=_update_status, outputs=status_html)
170
+
171
+ # 2‑column layout
172
+ with gr.Row():
173
+ # left inputs
174
+ with gr.Column():
175
+ input_img = gr.Image(label="Upload image", type="pil", height=400)
176
+ category = gr.Dropdown(label="Jewellery category", choices=["Rings", "Bracelets", "Watches", "Earrings"], value="Bracelets")
177
+ gender = gr.Dropdown(label="Model gender", choices=["male", "female"], value="female")
178
+ run_btn = gr.Button("🎯 Generate", elem_id="button")
179
+
180
+ # right outputs
181
+ with gr.Column():
182
+ with gr.Tabs():
183
+ with gr.TabItem("Detection overlay"):
184
+ out_overlay = gr.Image(height=400)
185
+ with gr.TabItem("Final result"):
186
+ out_bg = gr.Image(height=400)
187
+ out_status = gr.Text(label="Status", interactive=False)
188
+
189
+ # wire button β†’ backend
190
+ run_btn.click(
191
+ fn=call_backend_with_retry,
192
+ inputs=[input_img, category, gender],
193
+ outputs=[out_overlay, out_bg, out_status],
194
+ concurrency_limit=1,
195
+ show_progress=True,
196
+ )
197
+
198
+
199
+ # ───────── Launch ─────────
200
+ if __name__ == "__main__":
201
+ demo.queue(max_size=20, default_concurrency_limit=1).launch(server_name="0.0.0.0", server_port=7860, share=False)
202
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ gradio-client
3
+ Pillow
4
+ numpy