|
|
import os |
|
|
import io |
|
|
import base64 |
|
|
import time |
|
|
import logging |
|
|
import threading |
|
|
import uuid |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from collections import deque |
|
|
from typing import Dict, Optional, Tuple |
|
|
|
|
|
import gradio as gr |
|
|
from gradio_client import Client |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Configure root logging once at import time; this module logs queue activity
# and backend connectivity through the module-level logger below.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

# Queue tuning knobs.
MAX_QUEUE_SIZE = 50  # max pending requests before new submissions are rejected
MAX_CONCURRENT_REQUESTS = 1  # GPU backend handles one job at a time
AVERAGE_PROCESSING_TIME = 15  # seconds; seed value for the moving-average wait estimate

# Token used to reach the backend Hugging Face Space; fail fast if it is absent.
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")
|
|
|
|
|
|
|
|
class QueueManager:
    """Thread-safe FIFO manager for generation requests.

    Tracks four request states (queued, processing, completed, failed) and
    maintains an exponential moving average of processing time used for wait
    estimates. Every public method takes ``self.lock``, so instances are safe
    to share between Gradio request threads and the background worker.

    NOTE(review): ``completed`` and ``failed`` are never pruned, so memory
    grows with the number of requests served — fine for a short-lived Space,
    but worth a size/TTL cap for long uptimes.
    """

    def __init__(self):
        self.queue = deque()  # pending: (request_id, user_data, enqueue_time)
        self.processing = {}  # request_id -> processing start time
        self.completed = {}   # request_id -> result dict
        self.failed = {}      # request_id -> error message
        self.lock = threading.Lock()
        self.stats = {
            'total_processed': 0,
            'total_failed': 0,
            'avg_processing_time': AVERAGE_PROCESSING_TIME
        }

    def add_request(self, request_id: str, user_data: dict) -> Tuple[int, float]:
        """Add request to queue. Returns (position, estimated_wait).

        Raises:
            Exception: if the queue already holds MAX_QUEUE_SIZE entries.
        """
        with self.lock:
            if len(self.queue) >= MAX_QUEUE_SIZE:
                raise Exception("Queue is full. Please try again later.")

            self.queue.append((request_id, user_data, time.time()))
            position = len(self.queue)

            processing_count = len(self.processing)
            queue_ahead = position - 1

            # Bug fix: the old code reported 0s for *every* request whenever
            # nothing was processing, even with a backlog ahead of us. When
            # idle, only the requests ahead contribute to the wait; when busy,
            # the in-flight request counts as one extra slot.
            if processing_count == 0:
                estimated_wait = queue_ahead * self.stats['avg_processing_time']
            else:
                estimated_wait = (queue_ahead + 1) * self.stats['avg_processing_time']

            logger.info(f"Request {request_id} added to queue. Position: {position}, Est. wait: {estimated_wait:.0f}s")
            return position, estimated_wait

    def get_next_requests(self):
        """Get next request to process, or [] when busy/empty (GPU limit: 1)."""
        with self.lock:
            if len(self.processing) >= MAX_CONCURRENT_REQUESTS or len(self.queue) == 0:
                return []

            request_id, user_data, timestamp = self.queue.popleft()
            self.processing[request_id] = time.time()
            return [(request_id, user_data)]

    def complete_request(self, request_id: str, result):
        """Mark request as completed and fold its duration into the average."""
        with self.lock:
            if request_id in self.processing:
                processing_time = time.time() - self.processing[request_id]
                del self.processing[request_id]
                self.completed[request_id] = result
                self.stats['total_processed'] += 1

                # Exponential moving average: 80% history, 20% latest sample.
                current_avg = self.stats['avg_processing_time']
                self.stats['avg_processing_time'] = (current_avg * 0.8) + (processing_time * 0.2)

                logger.info(f"Request {request_id} completed in {processing_time:.1f}s")

    def fail_request(self, request_id: str, error_msg: str):
        """Mark request as failed and record the error message."""
        with self.lock:
            if request_id in self.processing:
                del self.processing[request_id]
            self.failed[request_id] = error_msg
            self.stats['total_failed'] += 1
            logger.error(f"Request {request_id} failed: {error_msg}")

    def get_request_status(self, request_id: str) -> dict:
        """Return a status dict: completed/failed/processing/queued/not_found."""
        with self.lock:
            if request_id in self.completed:
                return {'status': 'completed', 'result': self.completed[request_id]}
            elif request_id in self.failed:
                return {'status': 'failed', 'error': self.failed[request_id]}
            elif request_id in self.processing:
                processing_time = time.time() - self.processing[request_id]
                return {'status': 'processing', 'time': processing_time}
            else:
                # Linear scan is fine: the queue is capped at MAX_QUEUE_SIZE.
                for i, (rid, _, _) in enumerate(self.queue):
                    if rid == request_id:
                        return {'status': 'queued', 'position': i + 1}
                return {'status': 'not_found'}
|
|
|
|
|
|
|
|
# Singleton queue shared by the Gradio handlers and the background worker.
queue_manager = QueueManager()

# Cached connection state for the backend Space; mutated in place by
# check_backend_connection().
backend_status = {
    "client": None,  # gradio_client.Client instance once connected
    "connected": False,
    "last_check": None,  # unix timestamp of the last connection attempt
    "error_message": ""
}
|
|
|
|
|
def check_backend_connection():
    """Ping the HF Space and cache the client object.

    Updates the module-level ``backend_status`` dict in place and returns a
    ``(ok, user_message)`` tuple suitable for the status banner. The message
    prefixes are matched elsewhere in the UI, so they must not change.
    """
    try:
        test_client = Client("milliyin/backend", hf_token=HF_TOKEN)
        backend_status.update({
            "client": test_client,
            "connected": True,
            "error_message": "",
            "last_check": time.time(),
        })
        # Bug fix: this log message was split across two source lines by a
        # garbled emoji, which is not valid Python; rebuilt as a single line.
        logger.info("Backend connection established")
        return True, "π’ Model is ready"
    except Exception as e:
        backend_status.update({
            "client": None,
            "connected": False,
            "last_check": time.time(),
            "error_message": str(e),
        })
        err = str(e).lower()
        # A timeout usually means the Space is cold-starting, not broken.
        if "timeout" in err or "read operation timed out" in err:
            return False, "π‘ Model is starting up. Please wait 3β4 min."
        return False, f"π΄ Backend error: {e}"
|
|
|
|
|
|
|
|
# Warm up the backend connection at import time (best effort; the worker and
# the UI's "Check Status" button retry later if this fails).
check_backend_connection()
|
|
|
|
|
|
|
|
def queue_worker():
    """Daemon loop: pull one request at a time off the queue and process it."""
    while True:
        try:
            pending = queue_manager.get_next_requests()
            if pending:
                request_id, user_data = pending[0]
                logger.info(f"Starting processing request {request_id}")
                process_single_request(request_id, user_data)
                # Small breather between jobs so status polls can catch up.
                time.sleep(0.5)
            else:
                # Idle: poll the queue once per second.
                time.sleep(1)
        except Exception as e:
            # Never let the worker die; back off briefly and keep looping.
            logger.error(f"Queue worker error: {e}")
            time.sleep(5)
|
|
|
|
|
def process_single_request(request_id: str, user_data: dict):
    """Process a single queued request against the backend Space.

    On success the result is stored via ``queue_manager.complete_request``;
    on any error the request is marked failed with the exception text.
    ``user_data`` is the dict built by ``submit_request`` (image_b64,
    category, gender, timestamp).
    """
    try:
        img_b64 = user_data['image_b64']
        category = user_data['category']
        gender = user_data['gender']

        # Lazily (re)connect if the cached backend client is stale.
        if not backend_status["connected"]:
            check_backend_connection()
            if not backend_status["connected"]:
                raise Exception("Backend not available")

        client = backend_status["client"]
        start_time = time.time()

        result = client.predict(
            img_b64,
            category,
            gender,
            api_name="/predict",
        )

        processing_time = time.time() - start_time

        if not result or len(result) < 4:
            raise ValueError("Invalid response structure from backend")

        # Bug fix: the length check above admits results with MORE than four
        # fields, but a strict 4-way unpack raises ValueError on them; slice
        # defensively so extra trailing fields are tolerated.
        _, overlay_b64, bg_b64, status = result[:4]

        final_result = {
            'overlay_b64': overlay_b64,
            'bg_b64': bg_b64,
            'status': status,
            'processing_time': processing_time
        }

        queue_manager.complete_request(request_id, final_result)

    except Exception as e:
        queue_manager.fail_request(request_id, str(e))
|
|
|
|
|
|
|
|
# Start the queue worker in the background; daemon=True so it never blocks
# interpreter shutdown.
worker_thread = threading.Thread(target=queue_worker, daemon=True)

worker_thread.start()
|
|
|
|
|
|
|
|
def image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded PNG string ("" for None)."""
    if image is None:
        return ""
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    buffer = io.BytesIO()
    rgb.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()
|
|
|
|
|
def base64_to_image(b64: str) -> Optional[Image.Image]:
    """Decode a base64 PNG payload back into an RGB PIL image (None on failure)."""
    if not b64:
        return None
    try:
        raw = base64.b64decode(b64)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        logger.error(f"Failed to decode base64 β image: {e}")
        return None
|
|
|
|
|
|
|
|
# NOTE(review): appears unused anywhere in this file — verify callers before
# removing; kept as-is to preserve the module's public surface.
active_requests = {}
|
|
|
|
|
def submit_request(input_image: Image.Image, category: str, gender: str):
    """Queue a generation job; returns the five Gradio click-handler outputs.

    Outputs: (overlay_image, background_image, status_text, run_button_update,
    request_id). Images stay None here — results arrive via status polling.
    """
    if input_image is None:
        return None, None, "β Please upload an image.", gr.update(interactive=True), ""

    try:
        request_id = str(uuid.uuid4())

        payload = {
            'image_b64': image_to_base64(input_image),
            'category': category,
            'gender': gender,
            'timestamp': time.time()
        }

        position, estimated_wait = queue_manager.add_request(request_id, payload)

        # Build the user-facing status line piece by piece.
        parts = [f"π Request submitted! Position in queue: #{position}"]
        worker_idle = len(queue_manager.processing) == 0
        if position == 1 and worker_idle:
            parts.append(" | Starting processing now...")
        elif estimated_wait > 0:
            parts.append(f" | Estimated wait: {estimated_wait:.0f}s")

        return None, None, "".join(parts), gr.update(interactive=False), request_id

    except Exception as e:
        return None, None, f"β {str(e)}", gr.update(interactive=True), ""
|
|
|
|
|
def check_request_status(request_id: str):
    """Check the status of a request and map it to Gradio output updates.

    Returns (overlay_image, background_image, status_text, run_button_update).
    The run button is re-enabled only on terminal states (completed, failed,
    not_found) or when there is no active request.
    """
    if not request_id:
        return None, None, "No active request", gr.update(interactive=True)

    status_info = queue_manager.get_request_status(request_id)

    if status_info['status'] == 'completed':
        result = status_info['result']
        overlay_img = base64_to_image(result['overlay_b64'])
        bg_img = base64_to_image(result['bg_b64'])
        # Bug fix: this f-string was split across two source lines by a
        # garbled emoji, which is not valid Python; rebuilt as one line.
        status_msg = f"{result['status']} (β± {result['processing_time']:.1f}s)"
        return overlay_img, bg_img, status_msg, gr.update(interactive=True)

    elif status_info['status'] == 'failed':
        return None, None, f"β {status_info['error']}", gr.update(interactive=True)

    elif status_info['status'] == 'processing':
        processing_time = status_info['time']
        return None, None, f"β‘ Processing... ({processing_time:.1f}s)", gr.update(interactive=False)

    elif status_info['status'] == 'queued':
        position = status_info['position']
        avg_time = queue_manager.stats['avg_processing_time']
        estimated_wait = position * avg_time
        # Only show the minutes/seconds breakdown for non-trivial waits.
        wait_msg = f" | Est. wait: {int(estimated_wait/60)}m {int(estimated_wait%60)}s" if estimated_wait > 30 else ""
        return None, None, f"β³ In queue, position #{position}{wait_msg}", gr.update(interactive=False)

    else:
        return None, None, "β Request not found", gr.update(interactive=True)
|
|
|
|
|
def disable_button():
    """Return a Gradio update that greys out the Generate button."""
    button_update = gr.update(interactive=False)
    return button_update
|
|
|
|
|
|
|
|
# Purple-gradient theme for the Gradio UI; injected via gr.Blocks(css=...).
# This is a plain CSS string — keep Python comments outside of it.
custom_css = """
.gradio-container {
    background: linear-gradient(135deg, #3b4371 0%, #2d1b69 25%, #673ab7 50%, #8e24aa 75%, #6a1b9a 100%);
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    min-height: 100vh;
}
.contain {
    background: rgba(255, 255, 255, 0.95);
    border-radius: 15px;
    padding: 25px;
    margin: 15px;
    box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
    backdrop-filter: blur(10px);
}
.title-container {
    text-align: center;
    margin-bottom: 25px;
    padding: 20px;
    background: linear-gradient(135deg, #673ab7, #8e24aa);
    border-radius: 12px;
    box-shadow: 0 5px 20px rgba(103, 58, 183, 0.4);
}
.title-container h1 {
    color: white;
    font-size: 2.2em;
    font-weight: bold;
    margin: 0;
    text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.3);
}
.info-bar {
    background: linear-gradient(135deg, #7c4dff, #6a1b9a);
    padding: 12px;
    border-radius: 8px;
    margin-bottom: 20px;
    color: white;
    text-align: center;
    font-weight: 500;
    box-shadow: 0 3px 12px rgba(124, 77, 255, 0.3);
}
.section-header {
    background: linear-gradient(135deg, #e1bee7, #d1c4e9);
    padding: 12px;
    border-radius: 8px;
    margin-bottom: 15px;
    border-left: 4px solid #673ab7;
}
.section-header h3 {
    margin: 0;
    color: #333;
    font-weight: 600;
}
.input-group {
    background: rgba(255, 255, 255, 0.85);
    padding: 18px;
    border-radius: 12px;
    margin-bottom: 15px;
    border: 1px solid rgba(103, 58, 183, 0.2);
    box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.result-section {
    background: rgba(255, 255, 255, 0.9);
    padding: 18px;
    border-radius: 12px;
    border: 1px solid rgba(103, 58, 183, 0.2);
    box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.tip-box {
    background: linear-gradient(135deg, #f3e5f5, #e8eaf6);
    padding: 10px;
    border-radius: 6px;
    margin: 8px 0;
    border-left: 3px solid #673ab7;
    color: #4a148c;
    font-weight: 500;
}
button.primary {
    background: linear-gradient(135deg, #673ab7, #8e24aa) !important;
    border: none !important;
    border-radius: 20px !important;
    padding: 12px 25px !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 15px !important;
    box-shadow: 0 5px 15px rgba(103, 58, 183, 0.4) !important;
}
button.primary:hover {
    box-shadow: 0 8px 25px rgba(103, 58, 183, 0.6) !important;
    opacity: 0.9 !important;
    transform: translateY(-2px) !important;
}
label {
    color: #4a148c !important;
    font-weight: 600 !important;
}
input, textarea, select {
    border: 1px solid rgba(103, 58, 183, 0.3) !important;
    border-radius: 6px !important;
}
input:focus, textarea:focus, select:focus {
    border-color: #673ab7 !important;
    box-shadow: 0 0 0 2px rgba(103, 58, 183, 0.2) !important;
}
.gr-slider input[type="range"] {
    accent-color: #673ab7 !important;
}
input[type="checkbox"] {
    accent-color: #673ab7 !important;
}
.preserve-aspect-ratio img {
    object-fit: contain !important;
    width: auto !important;
    max-height: 512px !important;
}
.social-links {
    text-align: center;
    margin: 20px 0;
}
.social-links a {
    margin: 0 10px;
    padding: 8px 16px;
    background: #667eea;
    color: white;
    text-decoration: none;
    border-radius: 8px;
    transition: all 0.3s ease;
}
.social-links a:hover {
    background: #764ba2;
    transform: translateY(-2px);
}
.feature-box {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    padding: 20px;
    border-radius: 12px;
    margin: 10px 0;
}
"""
|
|
|
|
|
|
|
|
with gr.Blocks(css=custom_css, title="Jewellery Photography Preview") as demo:

    # ---- Page header ----
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1 style="font-size: 2.5em;">π¨ Raresence: AI-Powered Jewellery Photo Preview</h1>
        <p style="color: #666;">Upload a jewellery image, select model, and get professional photos instantly</p>
    </div>
    """)

    # ---- Backend status banner + manual refresh ----
    status_html = gr.HTML()

    def _update_status():
        """Re-ping the backend and render the result as a status banner."""
        ok, msg = check_backend_connection()
        cls = "status-ready" if ok else ("status-starting" if "π‘" in msg else "status-error")
        return f'<div class="status-banner {cls}">{msg}</div>'

    status_html.value = _update_status()
    gr.Button("π Check Status").click(fn=_update_status, outputs=status_html)

    with gr.Column():
        with gr.Row():
            with gr.Column(scale=0.4):
                # Bug fix: the div's class attribute had a doubled closing
                # quote (class="feature-box""), producing malformed HTML.
                gr.HTML("""
                <div class="feature-box">
                    <h3>πΌοΈ Upload Jewellery Image</h3>
                    <p style="color: #666; font-size: 14px;">Select a clear jewellery image for best results</p>
                </div>
                """)
                gr.Markdown("β")
                gr.Markdown("β")
                input_img = gr.Image(label="Upload image", type="pil", height=400)

            with gr.Column():
                gr.HTML("""
                <div class="feature-box">
                    <h3>π¨ AI Generated Results</h3>
                    <p style="color: #666; font-size: 14px;">Preview overlay detection and final professional background</p>
                </div>
                """)

                with gr.Tabs():
                    with gr.TabItem("Final result"):
                        info2 = gr.Markdown(value="### Final result")
                        out_bg = gr.Image(height=400)
                    with gr.TabItem("Detection overlay"):
                        info1 = gr.Markdown(value="### Detection overlay")
                        out_overlay = gr.Image(height=400)
                run_btn = gr.Button("π― Generate", elem_id="button", variant="primary")

    with gr.Row():
        with gr.Column(scale=0.4):
            gr.Markdown(value="Setting")
            category = gr.Dropdown(label="Jewellery category", choices=["Rings", "Bracelets", "Watches", "Earrings"], value="Bracelets")
            gender = gr.Dropdown(label="Model gender", choices=["male", "female"], value="female")

    out_status = gr.Text(label="Status", interactive=False)

    # ---- Footer ----
    gr.HTML("""
    <div style="text-align:center;padding:40px 20px;background:#f8fafc;border:1px solid #e2e8f0;border-radius:16px;margin:30px 0;">
        <h3 style="color:#333;">π Powered by Snapwear AI</h3>
        <p style="color:#666;">
            Experience the future of virtual fashion and garment visualization.
        </p>
        <div class="social-links">
            <a href="https://snapwear.io" target="_blank">π Website</a>
            <a href="https://www.instagram.com/snapwearai/" target="_blank">πΈ Instagram</a>
            <a href="https://huggingface.co/spaces/SnapwearAI/Snapwear-Texture-Transfer" target="_blank">π¨ Pattern Transfer</a>
        </div>
        <p style="font-size:12px;color:#999;margin-top:20px;">
            Β© 2024 Snapwear AI. Professional AI tools for fashion and design.
        </p>
    </div>
    """)

    # Hidden state carrying the id of the request this browser session owns.
    current_request_id = gr.State("")

    # Disable the button immediately, then enqueue the job.
    run_btn.click(
        fn=disable_button,
        inputs=None,
        outputs=run_btn
    ).then(
        fn=submit_request,
        inputs=[input_img, category, gender],
        outputs=[out_overlay, out_bg, out_status, run_btn, current_request_id],
        show_progress=True,
    )

    def auto_status_check(request_id):
        """Timer handler: poll the queue for this session's request status."""
        if request_id:
            return check_request_status(request_id)
        return None, None, "Ready to generate", gr.update(interactive=True)

    demo.load(lambda: None)

    # Poll request status every 2 seconds; drives images, status text and the
    # re-enabling of the Generate button.
    timer = gr.Timer(2)
    timer.tick(
        fn=auto_status_check,
        inputs=[current_request_id],
        outputs=[out_overlay, out_bg, out_status, run_btn]
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Give the Gradio event queue headroom over the internal request cap;
    # one concurrent event matches the single-GPU backend.
    app = demo.queue(max_size=MAX_QUEUE_SIZE + 10, default_concurrency_limit=1)
    app.launch(share=False)