# AVE — app.py (Gradio demo for product attribute-value extraction)
# Page-header residue from the original source listing: author "Mandour",
# commit "remove time" (9bd6442).
import json
import os
import tempfile
from io import BytesIO
from typing import Tuple

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from rembg import remove
from torchvision import transforms

from models import (
    get_device, get_tokenizers, get_image_processor,
    load_merger_model, get_predicated_values
)
# Load environment variables (optional for local dev; Spaces use web UI for env vars)
if os.path.exists('.env'):
from dotenv import load_dotenv
load_dotenv()
# Global constants
# Attribute slots the pipeline reports for each product.
ATTRIBUTES_LIST = ['sleeve', 'color', 'type', 'pattern',
                   'material', 'style', 'neck', 'gender', 'brand']
# NOTE(review): the two sequence-length limits are not referenced anywhere in
# this file — presumably consumed by models.py; confirm before removing.
MAX_SEQ_LENGTH = 256
DECODER_MAX_SEQ_LENGTH = 64
# Global variables for model components; populated lazily on first request by
# initialize_model_and_tokenizers().
MODEL_COMPONENTS = None
MODEL_LOADED = False
def initialize_model_and_tokenizers():
    """Lazily load model and tokenizers once, caching them in module globals.

    Returns:
        dict: Components keyed by 'model', 'bert_tokenizer',
        'roberta_tokenizer', 'image_processor', and 'device'.

    Raises:
        Exception: Whatever the underlying loaders raise is re-raised
        unchanged after being logged.
    """
    global MODEL_COMPONENTS, MODEL_LOADED
    # Fast path: reuse the already-loaded components.
    if MODEL_LOADED and MODEL_COMPONENTS:
        return MODEL_COMPONENTS
    try:
        print("🔄 Loading AI model components...")
        device = get_device()
        bert_tokenizer, roberta_tokenizer = get_tokenizers()
        image_processor = get_image_processor()
        model = load_merger_model(bert_tokenizer, device)
        MODEL_COMPONENTS = {
            'model': model,
            'bert_tokenizer': bert_tokenizer,
            'roberta_tokenizer': roberta_tokenizer,
            'image_processor': image_processor,
            'device': device
        }
        MODEL_LOADED = True
        print("✅ Model loaded successfully!")
        return MODEL_COMPONENTS
    except Exception as e:
        print(f"❌ Failed to load model: {str(e)}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # re-anchor it at this line.
        raise
def validate_inputs(image, text_input: str, category: str) -> Tuple[bool, str]:
    """Check that every required input was supplied.

    Returns:
        (is_valid, message): message is user-facing and prefixed with an
        emoji marker used elsewhere for error detection.
    """
    # Guard clauses, one per missing input, checked in display order.
    if image is None:
        return False, "❌ Please upload an image file"
    if not text_input or not text_input.strip():
        return False, "❌ Please provide a product description"
    if not category:
        return False, "❌ Please select a product category"
    return True, "✅ Inputs validated successfully"
def resize_image_for_display(image: Image.Image, target_size=(512, 512)) -> Image.Image:
    """Remove the background and resize the image for consistent display.

    Despite the name, this also strips the background with rembg: the result
    is both shown to the user and fed on to preprocessing.

    Args:
        image: Input image; converted to RGBA first.
        target_size: (max_width, max_height) bounding box for the resize.

    Returns:
        RGBA image, background removed, scaled to fit inside target_size
        with the aspect ratio preserved.
    """
    # Ensure image is RGBA
    image = image.convert('RGBA')
    # rembg expects encoded bytes in/out, so round-trip through PNG buffers.
    with BytesIO() as inp, BytesIO() as out:
        image.save(inp, format="PNG")
        inp.seek(0)
        data = remove(inp.read())
        out.write(data)
        out.seek(0)
        # .convert() forces a full load before the buffer is closed.
        no_bg = Image.open(out).convert("RGBA")
    # Compute new size preserving aspect ratio.
    orig_w, orig_h = no_bg.size
    max_w, max_h = target_size
    scale = min(max_w / orig_w, max_h / orig_h)
    # max(1, ...) guards against a 0-pixel dimension for extreme aspect
    # ratios, which would make resize() raise.
    new_w = max(1, int(orig_w * scale))
    new_h = max(1, int(orig_h * scale))
    # Resize with high-quality resampling.
    return no_bg.resize((new_w, new_h), Image.Resampling.LANCZOS)
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a batched CHW tensor for model input.

    Returns:
        Tensor of shape (1, channels, height, width), dtype matching the
        underlying image array (uint8 for typical RGBA images).
    """
    rgba = image if image.mode == 'RGBA' else image.convert('RGBA')
    # HWC array -> CHW tensor, then prepend the batch dimension.
    chw = torch.tensor(np.array(rgba)).permute(2, 0, 1)
    return chw.unsqueeze(0)
def run_inference(image_tensor: torch.Tensor, description: str, category: str, model_components: dict) -> dict:
    """Run attribute prediction and summarize confidence statistics.

    Returns:
        dict with 'attributes' (raw prediction list), 'total_attributes'
        (count of non-"N/A" values) and 'avg_confidence' (their mean, 0 if
        nothing was extracted).
    """
    # Unpack the cached components.
    model = model_components['model']
    bert_tokenizer = model_components['bert_tokenizer']
    roberta_tokenizer = model_components['roberta_tokenizer']
    image_processor = model_components['image_processor']
    device = model_components['device']
    # The prediction API expects a PIL image, so convert back from the tensor.
    pil_img = transforms.ToPILImage()(image_tensor.squeeze(0).cpu())
    results = get_predicated_values(
        model, category, pil_img, description,
        image_processor, bert_tokenizer, roberta_tokenizer, device
    )
    # Only attributes with a real (truthy, non-"N/A") value feed the summary.
    valid = [a for a in results if a["value"] and a["value"] != "N/A"]
    avg_confidence = np.mean([a["confidence"] for a in valid]) if valid else 0
    return {
        "attributes": results,
        "total_attributes": len(valid),
        "avg_confidence": avg_confidence,
    }
def get_confidence_color(confidence: float) -> str:
    """Map a confidence score to a traffic-light hex color."""
    # Thresholds: >= 0.8 green, >= 0.6 yellow, otherwise red.
    for threshold, color in ((0.8, "#28a745"), (0.6, "#ffc107")):
        if confidence >= threshold:
            return color
    return "#dc3545"
def format_results_html(results: dict) -> str:
    """Render the extracted attributes as a dark-themed HTML fragment.

    Shows a placeholder card when nothing was extracted; otherwise one row
    per non-"N/A" attribute plus a summary footer.
    """
    if not results or results["total_attributes"] == 0:
        return """
        <div class="no-results">
            <h3>🔍 No attributes extracted</h3>
            <p>Try with a different image or more detailed description.</p>
        </div>
        """
    # Assemble the fragment as parts and join once at the end.
    parts = ["""
    <div class="results-container">
        <h3 class="results-title">📊 Extracted Attributes</h3>
    """]
    for attr in results["attributes"]:
        # Skip placeholders the model marked as "N/A".
        if attr["value"] == "N/A":
            continue
        confidence = attr["confidence"]
        color = get_confidence_color(confidence)
        parts.append(f"""
        <div class="attribute-item">
            <div class="attribute-name">{attr["name"].title()}</div>
            <div class="attribute-value">{attr["value"]}</div>
            <div class="confidence-badge" style="background-color: {color}">{confidence:.1%}</div>
        </div>
        """)
    parts.append(f"""
        <div class="summary-box">
            <h4>📈 Summary</h4>
            <p>
                <strong>{results["total_attributes"]}</strong> attributes extracted |
                <strong>{results["avg_confidence"]:.1%}</strong> avg confidence |
            </p>
        </div>
    </div>
    """)
    return "".join(parts)
def create_download_files(results: dict) -> Tuple[str, str]:
    """Write the results to JSON and CSV files for the download widgets.

    Args:
        results: Output of run_inference; must contain an "attributes" list.

    Returns:
        (json_path, csv_path), or (None, None) when there are no results.
    """
    if not results:
        return None, None
    # Write into a fresh temp directory instead of the CWD so concurrent
    # requests (e.g. on a shared Gradio Space) don't clobber each other's
    # files; the user-visible basenames are unchanged.
    out_dir = tempfile.mkdtemp(prefix="attributes_")
    json_file = os.path.join(out_dir, "attributes.json")
    with open(json_file, "w") as f:
        json.dump(results, f, indent=2)
    csv_file = os.path.join(out_dir, "attributes.csv")
    pd.DataFrame(results["attributes"]).to_csv(csv_file, index=False)
    return json_file, csv_file
def process_inputs(image, category, description, progress=gr.Progress()):
    """Validate inputs, run the full extraction pipeline, and collect UI outputs.

    Args:
        image: PIL image from the upload widget (or None).
        category: Selected product category string.
        description: Free-text product description.
        progress: Gradio progress tracker (instance default is the Gradio
            idiom for progress injection).

    Returns:
        (processed_image, status_message, results_html, json_path, csv_path);
        everything except the status message is None on failure.
    """
    global MODEL_COMPONENTS
    # Validate first: bad input should fail fast instead of triggering an
    # expensive model load (the original loaded the model before validating).
    is_valid, validation_message = validate_inputs(image, description, category)
    if not is_valid:
        return None, validation_message, None, None, None
    if not MODEL_LOADED:
        progress(0.1, desc="Loading AI model...")
        try:
            MODEL_COMPONENTS = initialize_model_and_tokenizers()
        except Exception as e:
            return None, f"❌ Failed to load model: {str(e)}", None, None, None
    try:
        progress(0.3, desc="📸 Preprocessing image...")
        resized_image = resize_image_for_display(image, (512, 512))
        image_tensor = preprocess_image(resized_image)
        progress(0.7, desc="🧠 Running AI inference...")
        results = run_inference(image_tensor, description, category, MODEL_COMPONENTS)
        progress(0.9, desc="📊 Formatting results...")
        results_html = format_results_html(results)
        json_file, csv_file = create_download_files(results)
        progress(1.0, desc="✅ Processing complete!")
        success_msg = f"🎉 Successfully extracted {results['total_attributes']} attributes!"
        return resized_image, success_msg, results_html, json_file, csv_file
    except Exception as e:
        return None, f"❌ Processing failed: {str(e)}", None, None, None
# Updated custom CSS for dark theme and refined layout.
# Injected into gr.Blocks(css=...) in create_interface(); class names here
# must match the elem_classes and the HTML emitted by format_results_html().
custom_css = """
/* Dark theme overrides */
body, .gradio-container {
background-color: #1a1a1a !important;
color: #e0e0e0 !important;
}
.gr-blocks, .gr-row, .gr-column {
background-color: #1a1a1a !important;
}
.input-section, .results-section {
background-color: #2a2a2a !important;
padding: 20px;
border-radius: 15px;
margin-bottom: 20px;
}
.gr-image, .gr-textbox, .gr-dropdown {
background-color: #333 !important;
color: #e0e0e0 !important;
border: 1px solid #444 !important;
}
.gr-button {
background-color: #444 !important;
color: #e0e0e0 !important;
border: none !important;
}
.gr-button:hover {
background-color: #555 !important;
}
/* Header styling */
.header {
text-align: center;
color: #e0e0e0;
margin-bottom: 30px;
}
.header h1 {
font-size: 2em;
}
.header p {
font-size: 1.1em;
color: #b0b0b0;
}
/* Results styling */
.results-container {
padding: 20px;
}
.results-title {
color: #e0e0e0;
margin-bottom: 20px;
font-size: 1.5em;
}
.attribute-item {
background-color: #333;
padding: 15px;
margin-bottom: 10px;
border-radius: 10px;
display: flex;
justify-content: space-between;
align-items: center;
}
.attribute-name {
color: #e0e0e0;
font-weight: bold;
font-size: 1.1em;
}
.attribute-value {
color: #b0b0b0;
margin-left: 10px;
}
.confidence-badge {
color: white;
padding: 4px 8px;
border-radius: 12px;
font-size: 0.8em;
font-weight: bold;
}
.summary-box {
background-color: #444;
color: #e0e0e0;
padding: 15px;
border-radius: 10px;
margin-top: 20px;
text-align: center;
}
.no-results {
padding: 20px;
text-align: center;
background-color: #333;
border-radius: 10px;
color: #e0e0e0;
}
/* Status message styling */
.status-positive {
background-color: #1a472a !important;
color: #e0e0e0 !important;
padding: 10px;
border-radius: 8px;
}
.status-negative {
background-color: #471a1a !important;
color: #e0e0e0 !important;
padding: 10px;
border-radius: 8px;
}
"""
def create_interface():
    """Create the main Gradio interface with dark theme and refined layout.

    Returns:
        gr.Blocks: The assembled demo, ready for .launch().
    """
    with gr.Blocks(css=custom_css, title="AI Attribute Extractor", theme=gr.themes.Soft()) as demo:
        # Header with dark theme styling
        gr.HTML("""
        <div class="header">
            <h1>🔍 AI Attribute Extractor</h1>
            <p>Upload an image and provide text to extract detailed attributes using AI</p>
        </div>
        """)
        with gr.Row():
            # Left column - Input section
            with gr.Column(scale=1):
                gr.HTML("<h2>📤 Input Section</h2>")
                # type="pil" so process_inputs receives a PIL image directly.
                image_input = gr.Image(
                    label="Upload Product Image",
                    type="pil",
                    height=300,
                    elem_classes=["input-section"]
                )
                category_input = gr.Dropdown(
                    choices=["clothing", "bags", "shoes", "accessories"],
                    label="Product Category",
                    value="clothing",
                    elem_classes=["input-section"]
                )
                text_input = gr.Textbox(
                    label="Product Description",
                    placeholder="Describe the product in detail...",
                    lines=4,
                    elem_classes=["input-section"]
                )
                process_btn = gr.Button(
                    "🚀 Extract Attributes",
                    variant="primary",
                    size="lg",
                    elem_classes=["primary-button"]
                )
                # Raw HTML so styled status <div>s can be injected.
                status_msg = gr.HTML(label="Status")
            # Right column - Results section
            with gr.Column(scale=1):
                gr.HTML("<h2>📊 Results Section</h2>")
                processed_image = gr.Image(
                    label="Processed Image",
                    height=300,
                    elem_classes=["results-section"]
                )
                results_html = gr.HTML(
                    label="Extracted Attributes",
                    elem_classes=["results-section"]
                )
                # Download widgets start hidden; revealed once files exist.
                with gr.Row():
                    json_download = gr.File(
                        label="📄 Download JSON",
                        visible=False
                    )
                    csv_download = gr.File(
                        label="📊 Download CSV",
                        visible=False
                    )
        # Event handlers
        def update_status(message: str, is_error: bool = False):
            """Wrap a status message in a green (ok) or red (error) div."""
            class_name = "status-negative" if is_error else "status-positive"
            return f'<div class="{class_name}">{message}</div>'
        def process_and_update(image, category, description):
            """Run the pipeline and map its results onto all output widgets."""
            processed_img, status, results, json_file, csv_file = process_inputs(
                image, category, description
            )
            # Error statuses from process_inputs carry a leading ❌ marker.
            is_error = status.startswith("❌")
            styled_status = update_status(status, is_error)
            # Only show the download widgets when a file was actually written.
            json_visible = json_file is not None
            csv_visible = csv_file is not None
            return (
                processed_img,
                styled_status,
                results,
                gr.update(value=json_file, visible=json_visible),
                gr.update(value=csv_file, visible=csv_visible)
            )
        process_btn.click(
            fn=process_and_update,
            inputs=[image_input, category_input, text_input],
            outputs=[processed_image, status_msg, results_html, json_download, csv_download]
        )
    return demo
if __name__ == "__main__":
    print("Initializing AI Attribute Extractor...")
    demo = create_interface()
    # Bind to all interfaces on port 7860 (the standard Hugging Face Spaces
    # port); share=False since Spaces already provides the public URL.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True,
        quiet=False
    )