# Source: Hugging Face Space "myspace" / app.py
# Page header: milliyin's picture
# Commit message: Update app.py
# Revision: 3e216be (verified)
import os
import io
import base64
import time
import logging
import threading
import uuid
from datetime import datetime
from pathlib import Path
from collections import deque
from typing import Dict, Optional, Tuple
import gradio as gr
from gradio_client import Client
from PIL import Image
# ───────── Logging ─────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ───────── Queue System Configuration ─────────
MAX_QUEUE_SIZE: int = 50  # most requests allowed to wait in the queue at once
MAX_CONCURRENT_REQUESTS: int = 1  # GPU can only handle 1 request at a time
AVERAGE_PROCESSING_TIME: int = 15  # seconds; starting value for the running average
# ───────── Backend connection ─────────
# The token is mandatory: refuse to start when it is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("HF_TOKEN environment variable is required")
# ───────── Global Queue System ─────────
class QueueManager:
    """Lock-protected FIFO queue that tracks each request through its lifecycle.

    A request moves queued -> processing -> completed | failed. Every method
    takes ``self.lock`` before touching shared state.

    Finished results are kept so they can be polled repeatedly, but only the
    ``max_retained`` most recent entries of each kind (completed / failed) are
    held; older ones are evicted and then report ``not_found``. Without this
    cap both dicts would grow for the lifetime of the process.

    Args:
        max_queue_size: Most requests allowed to wait at once. Defaults to the
            module constant ``MAX_QUEUE_SIZE``.
        max_concurrent: Most requests allowed in ``processing`` at once.
            Defaults to ``MAX_CONCURRENT_REQUESTS``.
        initial_avg_time: Starting value (seconds) of the running average used
            for wait estimates. Defaults to ``AVERAGE_PROCESSING_TIME``.
        max_retained: How many completed and how many failed results to keep
            (values below 1 are treated as 1).
    """

    # Same logger object as the module-level ``logger`` (identical name), so
    # log output is unchanged; bound here so methods need no module global.
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        max_queue_size: Optional[int] = None,
        max_concurrent: Optional[int] = None,
        initial_avg_time: Optional[float] = None,
        max_retained: int = 100,
    ):
        self.max_queue_size = MAX_QUEUE_SIZE if max_queue_size is None else max_queue_size
        self.max_concurrent = MAX_CONCURRENT_REQUESTS if max_concurrent is None else max_concurrent
        self.max_retained = max(1, max_retained)

        self.queue = deque()   # (request_id, user_data, timestamp)
        self.processing = {}   # request_id -> processing_start_time
        self.completed = {}    # request_id -> result (insertion-ordered, bounded)
        self.failed = {}       # request_id -> error_message (insertion-ordered, bounded)
        self.lock = threading.Lock()
        self.stats = {
            'total_processed': 0,
            'total_failed': 0,
            'avg_processing_time': (
                AVERAGE_PROCESSING_TIME if initial_avg_time is None else initial_avg_time
            ),
        }

    def _retain(self, store: dict, request_id: str, value) -> None:
        """Store a finished result and evict the oldest entries beyond the cap.

        Caller must already hold ``self.lock``.
        """
        store[request_id] = value
        while len(store) > self.max_retained:
            # dicts iterate in insertion order, so the first key is the oldest
            del store[next(iter(store))]

    def add_request(self, request_id: str, user_data: dict) -> Tuple[int, float]:
        """Add request to queue. Returns (position, estimated_wait).

        ``position`` is 1-based. ``estimated_wait`` (seconds) counts every job
        that must finish first: those queued ahead plus those currently
        processing, each weighted by the running average processing time.

        Raises:
            Exception: when the queue already holds ``max_queue_size`` requests.
        """
        with self.lock:
            if len(self.queue) >= self.max_queue_size:
                raise Exception("Queue is full. Please try again later.")
            self.queue.append((request_id, user_data, time.time()))
            position = len(self.queue)
            # Jobs queued ahead still have to run even if the worker happens
            # to be idle at this instant, so they always count.
            jobs_ahead = (position - 1) + len(self.processing)
            estimated_wait = jobs_ahead * self.stats['avg_processing_time']
            self._log.info(f"Request {request_id} added to queue. Position: {position}, Est. wait: {estimated_wait:.0f}s")
            return position, estimated_wait

    def get_next_requests(self):
        """Pop the next request to process, or return [] if none may start.

        Returns a list with at most one ``(request_id, user_data)`` pair; the
        request is recorded in ``processing`` with its start time.
        """
        with self.lock:
            if len(self.processing) >= self.max_concurrent or not self.queue:
                return []
            request_id, user_data, _queued_at = self.queue.popleft()
            self.processing[request_id] = time.time()
            return [(request_id, user_data)]

    def complete_request(self, request_id: str, result):
        """Mark a processing request as completed and store its result.

        Does nothing when ``request_id`` is not currently processing.
        """
        with self.lock:
            if request_id not in self.processing:
                return
            processing_time = time.time() - self.processing.pop(request_id)
            self._retain(self.completed, request_id, result)
            self.stats['total_processed'] += 1
            # Exponential moving average: 80% history, 20% latest sample
            current_avg = self.stats['avg_processing_time']
            self.stats['avg_processing_time'] = (current_avg * 0.8) + (processing_time * 0.2)
            self._log.info(f"Request {request_id} completed in {processing_time:.1f}s")

    def fail_request(self, request_id: str, error_msg: str):
        """Mark a processing request as failed and store its error message.

        Does nothing when ``request_id`` is not currently processing.
        """
        with self.lock:
            if request_id not in self.processing:
                return
            del self.processing[request_id]
            self._retain(self.failed, request_id, error_msg)
            self.stats['total_failed'] += 1
            self._log.error(f"Request {request_id} failed: {error_msg}")

    def get_request_status(self, request_id: str) -> dict:
        """Return a status dict for ``request_id``.

        The ``'status'`` key is one of ``completed`` (with ``result``),
        ``failed`` (with ``error``), ``processing`` (with elapsed ``time`` in
        seconds), ``queued`` (with 1-based ``position``) or ``not_found``.
        """
        with self.lock:
            if request_id in self.completed:
                return {'status': 'completed', 'result': self.completed[request_id]}
            if request_id in self.failed:
                return {'status': 'failed', 'error': self.failed[request_id]}
            if request_id in self.processing:
                processing_time = time.time() - self.processing[request_id]
                return {'status': 'processing', 'time': processing_time}
            for index, (queued_id, _, _) in enumerate(self.queue):
                if queued_id == request_id:
                    return {'status': 'queued', 'position': index + 1}
            return {'status': 'not_found'}
# Global queue manager shared by the UI callbacks and the worker thread
queue_manager = QueueManager()

# Cached backend state, rewritten by check_backend_connection():
# the client object, whether the last probe succeeded, when it ran
# (time.time()), and the last error text.
backend_status = dict(
    client=None,
    connected=False,
    last_check=None,
    error_message="",
)
def check_backend_connection():
    """Ping the HF Space and cache the client object.

    Builds a new ``Client`` for ``milliyin/backend`` and records the outcome
    in ``backend_status`` (client, connected flag, probe time, error text).

    Returns:
        ``(ok, message)``: ``ok`` is True when the client was created;
        ``message`` is a user-facing status line. Timeout-like errors are
        reported as "starting up", anything else as a backend error.
    """
    try:
        # Only the client construction can realistically fail.
        test_client = Client("milliyin/backend", hf_token=HF_TOKEN)
    except Exception as e:
        backend_status.update({
            "client": None,
            "connected": False,
            "last_check": time.time(),
            "error_message": str(e),
        })
        # The success path logs; log failures too so outages show up in the logs.
        logger.warning("Backend connection failed: %s", e)
        err = str(e).lower()
        # "timed out" covers "read operation timed out" and other timeout phrasings
        if "timeout" in err or "timed out" in err:
            return False, "🟑 Model is starting up. Please wait 3‑4 min."
        return False, f"πŸ”΄ Backend error: {e}"

    backend_status.update({
        "client": test_client,
        "connected": True,
        "error_message": "",
        "last_check": time.time(),
    })
    logger.info("βœ… Backend connection established")
    return True, "🟒 Model is ready"
# Initial probe: attempt one connection at import time so `backend_status`
# holds a cached client before the first request. The (ok, message) return
# value is discarded here.
check_backend_connection()
# ───────── Queue Processing Worker ─────────
def queue_worker():
    """Run forever, handing queued requests to process_single_request one by one.

    Sleeps 1s when there is nothing to start, 0.5s after a processed request,
    and 5s after an unexpected error (which is logged and does not stop the loop).
    """
    while True:
        try:
            batch = queue_manager.get_next_requests()
            if batch:
                # Only the first entry is handled (GPU limitation)
                job_id, payload = batch[0]
                logger.info(f"Starting processing request {job_id}")
                process_single_request(job_id, payload)
                pause = 0.5
            else:
                pause = 1
            time.sleep(pause)
        except Exception as exc:
            logger.error(f"Queue worker error: {exc}")
            time.sleep(5)
def process_single_request(request_id: str, user_data: dict):
    """Send one request to the backend and record the outcome in the queue.

    Reads ``image_b64``, ``category`` and ``gender`` from ``user_data``, calls
    the backend's ``/predict`` endpoint, and stores either a result dict
    (``overlay_b64``, ``bg_b64``, ``status``, ``processing_time``) via
    ``queue_manager.complete_request`` or the error text via
    ``queue_manager.fail_request``. Never raises.
    """
    try:
        img_b64 = user_data['image_b64']
        category = user_data['category']
        gender = user_data['gender']

        # Reconnect lazily if the last probe (or the last predict call) failed.
        if not backend_status["connected"]:
            check_backend_connection()
        if not backend_status["connected"]:
            raise RuntimeError("Backend not available")
        client = backend_status["client"]

        start_time = time.time()
        try:
            result = client.predict(
                img_b64,
                category,
                gender,
                api_name="/predict",
            )
        except Exception:
            # The cached client may be stale; force a fresh probe for the
            # next request instead of reusing it indefinitely.
            backend_status["connected"] = False
            raise
        processing_time = time.time() - start_time

        # Must be a sequence with at least four items; extra trailing items
        # are ignored rather than breaking the unpack below.
        if not isinstance(result, (list, tuple)) or len(result) < 4:
            raise ValueError("Invalid response structure from backend")
        _, overlay_b64, bg_b64, status = result[:4]

        final_result = {
            'overlay_b64': overlay_b64,
            'bg_b64': bg_b64,
            'status': status,
            'processing_time': processing_time
        }
        queue_manager.complete_request(request_id, final_result)
    except Exception as e:
        queue_manager.fail_request(request_id, str(e))
# Start queue worker: one background thread running `queue_worker`. It is a
# daemon thread, so it does not block interpreter exit.
worker_thread = threading.Thread(target=queue_worker, daemon=True)
worker_thread.start()
# ───────── Helpers ─────────
def image_to_base64(image: Image.Image) -> str:
    """Encode an image as a base64 PNG string; return "" when given None.

    Images in any mode other than RGB are converted to RGB before saving.
    """
    if image is None:
        return ""
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    png_buffer = io.BytesIO()
    rgb_image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
def base64_to_image(b64: str) -> Optional[Image.Image]:
    """Decode a base64 string into an RGB image.

    Returns None for empty input, and also (after logging the error) when
    decoding or opening the image fails.
    """
    if b64:
        try:
            raw_bytes = base64.b64decode(b64)
            return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as exc:
            logger.error(f"Failed to decode base64 β†’ image: {exc}")
    return None
# ───────── Request Management ─────────
# Intended mapping: session_id -> request_id.
# NOTE(review): nothing in the visible code reads or writes this dict; the UI
# keeps the current request id in a gr.State instead. Confirm before removing.
active_requests = {}
def submit_request(input_image: Image.Image, category: str, gender: str):
    """Queue a new generation request for the uploaded image.

    Returns a 5-tuple for the UI: (overlay image, background image, status
    text, button update, request id). Both images are always None here. On
    success the button is disabled and the new request id is returned; when
    no image was given or queuing fails, the button is enabled and the id
    is "".
    """
    if input_image is None:
        return None, None, "❌ Please upload an image.", gr.update(interactive=True), ""

    try:
        request_id = str(uuid.uuid4())
        payload = {
            'image_b64': image_to_base64(input_image),
            'category': category,
            'gender': gender,
            'timestamp': time.time(),
        }
        position, estimated_wait = queue_manager.add_request(request_id, payload)

        # Build the status line from its base text plus at most one suffix.
        parts = [f"πŸš€ Request submitted! Position in queue: #{position}"]
        if position == 1 and not queue_manager.processing:
            parts.append(" | Starting processing now...")
        elif estimated_wait > 0:
            parts.append(f" | Estimated wait: {estimated_wait:.0f}s")

        return None, None, "".join(parts), gr.update(interactive=False), request_id
    except Exception as exc:
        return None, None, f"❌ {str(exc)}", gr.update(interactive=True), ""
def check_request_status(request_id: str):
    """Translate a request's queue status into UI values.

    Returns a 4-tuple: (overlay image, background image, status text, button
    update). Images are only set for a completed request. The button is
    disabled while the request is queued or processing and enabled otherwise.
    """
    if not request_id:
        return None, None, "No active request", gr.update(interactive=True)

    info = queue_manager.get_request_status(request_id)
    state = info['status']

    if state == 'completed':
        payload = info['result']
        overlay_img = base64_to_image(payload['overlay_b64'])
        bg_img = base64_to_image(payload['bg_b64'])
        done_msg = f"βœ… {payload['status']} (⏱ {payload['processing_time']:.1f}s)"
        return overlay_img, bg_img, done_msg, gr.update(interactive=True)

    if state == 'failed':
        return None, None, f"❌ {info['error']}", gr.update(interactive=True)

    if state == 'processing':
        elapsed = info['time']
        return None, None, f"⚑ Processing... ({elapsed:.1f}s)", gr.update(interactive=False)

    if state == 'queued':
        position = info['position']
        estimated_wait = position * queue_manager.stats['avg_processing_time']
        # The estimate is only shown once it exceeds 30 seconds.
        if estimated_wait > 30:
            wait_msg = f" | Est. wait: {int(estimated_wait/60)}m {int(estimated_wait%60)}s"
        else:
            wait_msg = ""
        return None, None, f"⏳ In queue, position #{position}{wait_msg}", gr.update(interactive=False)

    return None, None, "❓ Request not found", gr.update(interactive=True)
def disable_button():
    """Return a component update that makes the target non-interactive."""
    locked = gr.update(interactive=False)
    return locked
# ───────── CSS ─────────
# Stylesheet passed to gr.Blocks(css=...). The final four rules style the
# status banner markup (`status-banner` plus one of `status-ready`,
# `status-starting`, `status-error`) that the status callback emits; without
# them the banner renders unstyled.
custom_css = """
.gradio-container {
background: linear-gradient(135deg, #3b4371 0%, #2d1b69 25%, #673ab7 50%, #8e24aa 75%, #6a1b9a 100%);
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
min-height: 100vh;
}
.contain {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 25px;
margin: 15px;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
backdrop-filter: blur(10px);
}
.title-container {
text-align: center;
margin-bottom: 25px;
padding: 20px;
background: linear-gradient(135deg, #673ab7, #8e24aa);
border-radius: 12px;
box-shadow: 0 5px 20px rgba(103, 58, 183, 0.4);
}
.title-container h1 {
color: white;
font-size: 2.2em;
font-weight: bold;
margin: 0;
text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.3);
}
.info-bar {
background: linear-gradient(135deg, #7c4dff, #6a1b9a);
padding: 12px;
border-radius: 8px;
margin-bottom: 20px;
color: white;
text-align: center;
font-weight: 500;
box-shadow: 0 3px 12px rgba(124, 77, 255, 0.3);
}
.section-header {
background: linear-gradient(135deg, #e1bee7, #d1c4e9);
padding: 12px;
border-radius: 8px;
margin-bottom: 15px;
border-left: 4px solid #673ab7;
}
.section-header h3 {
margin: 0;
color: #333;
font-weight: 600;
}
.input-group {
background: rgba(255, 255, 255, 0.85);
padding: 18px;
border-radius: 12px;
margin-bottom: 15px;
border: 1px solid rgba(103, 58, 183, 0.2);
box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.result-section {
background: rgba(255, 255, 255, 0.9);
padding: 18px;
border-radius: 12px;
border: 1px solid rgba(103, 58, 183, 0.2);
box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.tip-box {
background: linear-gradient(135deg, #f3e5f5, #e8eaf6);
padding: 10px;
border-radius: 6px;
margin: 8px 0;
border-left: 3px solid #673ab7;
color: #4a148c;
font-weight: 500;
}
button.primary {
background: linear-gradient(135deg, #673ab7, #8e24aa) !important;
border: none !important;
border-radius: 20px !important;
padding: 12px 25px !important;
color: white !important;
font-weight: bold !important;
font-size: 15px !important;
box-shadow: 0 5px 15px rgba(103, 58, 183, 0.4) !important;
}
button.primary:hover {
box-shadow: 0 8px 25px rgba(103, 58, 183, 0.6) !important;
opacity: 0.9 !important;
transform: translateY(-2px) !important;
}
label {
color: #4a148c !important;
font-weight: 600 !important;
}
input, textarea, select {
border: 1px solid rgba(103, 58, 183, 0.3) !important;
border-radius: 6px !important;
}
input:focus, textarea:focus, select:focus {
border-color: #673ab7 !important;
box-shadow: 0 0 0 2px rgba(103, 58, 183, 0.2) !important;
}
.gr-slider input[type="range"] {
accent-color: #673ab7 !important;
}
input[type="checkbox"] {
accent-color: #673ab7 !important;
}
.preserve-aspect-ratio img {
object-fit: contain !important;
width: auto !important;
max-height: 512px !important;
}
.social-links {
text-align: center;
margin: 20px 0;
}
.social-links a {
margin: 0 10px;
padding: 8px 16px;
background: #667eea;
color: white;
text-decoration: none;
border-radius: 8px;
transition: all 0.3s ease;
}
.social-links a:hover {
background: #764ba2;
transform: translateY(-2px);
}
.feature-box {
background: #f8fafc;
border: 1px solid #e2e8f0;
padding: 20px;
border-radius: 12px;
margin: 10px 0;
}
.status-banner {
padding: 12px;
border-radius: 8px;
margin-bottom: 20px;
text-align: center;
font-weight: 600;
}
.status-ready {
background: #e8f5e9;
border: 1px solid #a5d6a7;
color: #1b5e20;
}
.status-starting {
background: #fff8e1;
border: 1px solid #ffe082;
color: #8d6e00;
}
.status-error {
background: #ffebee;
border: 1px solid #ef9a9a;
color: #b71c1c;
}
"""
# ───────── Gradio Blocks ─────────
# Builds the UI: hero header, backend status banner, upload column, result
# tabs, settings, footer, and the wiring between the Generate button, the
# request queue and a 2-second polling timer.
# NOTE(review): container nesting below was reconstructed from a source whose
# indentation had been stripped — confirm it against the deployed layout.
with gr.Blocks(css=custom_css, title="Jewellery Photography Preview") as demo:
    # Hero
    gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
<h1 style="font-size: 2.5em;">🎨 Raresence: AI-Powered Jewellery Photo Preview</h1>
<p style="color: #666;">Upload a jewellery image, select model, and get professional photos instantly</p>
</div>
""")

    # Status banner
    status_html = gr.HTML()

    def _update_status():
        """Re-probe the backend and render the result as a banner <div>."""
        ok, msg = check_backend_connection()
        # The "starting up" message is recognised by its leading marker.
        cls = "status-ready" if ok else ("status-starting" if "🟑" in msg else "status-error")
        return f'<div class="status-banner {cls}">{msg}</div>'

    # Populate the banner once at build time; the button refreshes it on demand.
    status_html.value = _update_status()
    gr.Button("πŸ”„ Check Status").click(fn=_update_status, outputs=status_html)

    with gr.Column():
        with gr.Row():
            with gr.Column(scale=0.4):
                # Fixed: the class attribute previously had a stray second quote.
                gr.HTML("""
<div class="feature-box">
<h3>πŸ–ΌοΈ Upload Jewellery Image</h3>
<p style="color: #666; font-size: 14px;">Select a clear jewellery image for best results</p>
</div>
""")
                # Two visually empty Markdown blocks (presumably vertical spacers).
                gr.Markdown("β€Ž")
                gr.Markdown("β€Ž")
                input_img = gr.Image(label="Upload image", type="pil", height=400)
            with gr.Column():
                gr.HTML("""
<div class="feature-box">
<h3>🎨 AI Generated Results</h3>
<p style="color: #666; font-size: 14px;">Preview overlay detection and final professional background</p>
</div>
""")
                with gr.Tabs():
                    with gr.TabItem("Final result"):
                        info2 = gr.Markdown(value="### Final result")
                        out_bg = gr.Image(height=400)
                    with gr.TabItem("Detection overlay"):
                        info1 = gr.Markdown(value="### Detection overlay")
                        out_overlay = gr.Image(height=400)
                run_btn = gr.Button("🎯 Generate", elem_id="button", variant="primary")
        with gr.Row():
            with gr.Column(scale=0.4):
                gr.Markdown(value="Setting")
                category = gr.Dropdown(label="Jewellery category", choices=["Rings", "Bracelets", "Watches", "Earrings"], value="Bracelets")
                gender = gr.Dropdown(label="Model gender", choices=["male", "female"], value="female")
                out_status = gr.Text(label="Status", interactive=False)

    # ──────── Footer ────────
    gr.HTML("""
<div style="text-align:center;padding:40px 20px;background:#f8fafc;border:1px solid #e2e8f0;border-radius:16px;margin:30px 0;">
<h3 style="color:#333;">πŸš€ Powered by Snapwear AI</h3>
<p style="color:#666;">
Experience the future of virtual fashion and garment visualization.
</p>
<div class="social-links">
<a href="https://snapwear.io" target="_blank">🌐 Website</a>
<a href="https://www.instagram.com/snapwearai/" target="_blank">πŸ“Έ Instagram</a>
<a href="https://huggingface.co/spaces/SnapwearAI/Snapwear-Texture-Transfer" target="_blank">🎨 Pattern Transfer</a>
</div>
<p style="font-size:12px;color:#999;margin-top:20px;">
Β© 2024 Snapwear AI. Professional AI tools for fashion and design.
</p>
</div>
""")

    # Hidden state for request tracking
    current_request_id = gr.State("")

    # Wire button -> queue system: disable the button first, then submit.
    run_btn.click(
        fn=disable_button,
        inputs=None,
        outputs=run_btn
    ).then(
        fn=submit_request,
        inputs=[input_img, category, gender],
        outputs=[out_overlay, out_bg, out_status, run_btn, current_request_id],
        show_progress=True,
    )

    def auto_status_check(request_id):
        """Poll the tracked request, or report the idle state when there is none."""
        if request_id:
            return check_request_status(request_id)
        return None, None, "Ready to generate", gr.update(interactive=True)

    # Set up periodic status checking
    demo.load(lambda: None)  # Initial load (no-op callback)

    # Timer that re-checks the tracked request every 2 seconds
    timer = gr.Timer(2)
    timer.tick(
        fn=auto_status_check,
        inputs=[current_request_id],
        outputs=[out_overlay, out_bg, out_status, run_btn]
    )
# ───────── Launch ─────────
if __name__ == "__main__":
    # Gradio's own event queue is sized slightly above the request queue and
    # its default concurrency limit is set to 1.
    queued_app = demo.queue(max_size=MAX_QUEUE_SIZE + 10, default_concurrency_limit=1)
    queued_app.launch(share=False)