Spaces:
Paused
Paused
| import gradio as gr | |
| import pandas as pd | |
| import json | |
| from typing import Tuple | |
| from PIL import Image | |
| import torch | |
| import numpy as np | |
| from torchvision import transforms | |
| import os | |
| from models import ( | |
| get_device, get_tokenizers, get_image_processor, | |
| load_merger_model, get_predicated_values | |
| ) | |
| from rembg import remove | |
| from io import BytesIO | |
# Optional .env support for local development. On Hugging Face Spaces the
# variables are configured in the Space settings, so the file is normally absent.
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()

# Attribute vocabulary -- presumably mirrors the attributes the model predicts
# (TODO confirm against the `models` package; not referenced elsewhere here).
ATTRIBUTES_LIST = [
    "sleeve",
    "color",
    "type",
    "pattern",
    "material",
    "style",
    "neck",
    "gender",
    "brand",
]

# Encoder / decoder token limits (not referenced in the visible code of this module).
MAX_SEQ_LENGTH = 256
DECODER_MAX_SEQ_LENGTH = 64

# Model cache, populated on first use by initialize_model_and_tokenizers().
MODEL_COMPONENTS = None
MODEL_LOADED = False
def initialize_model_and_tokenizers():
    """Load the model and tokenizers once and cache them in module globals.

    After a successful load, later calls return the cached
    ``MODEL_COMPONENTS`` dict without reloading anything.

    Returns:
        dict with keys 'model', 'bert_tokenizer', 'roberta_tokenizer',
        'image_processor' and 'device'.

    Raises:
        Whatever the loaders from ``models`` raise. The message is printed and
        the original exception is re-raised with its traceback intact; the
        cache globals are left untouched on failure.
    """
    global MODEL_COMPONENTS, MODEL_LOADED
    if MODEL_LOADED and MODEL_COMPONENTS:
        return MODEL_COMPONENTS

    print("π Loading AI model components...")
    try:
        device = get_device()
        bert_tokenizer, roberta_tokenizer = get_tokenizers()
        image_processor = get_image_processor()
        model = load_merger_model(bert_tokenizer, device)
    except Exception as e:
        print(f"β Failed to load model: {str(e)}")
        # Bare raise keeps the original traceback (``raise e`` adds this frame).
        raise

    MODEL_COMPONENTS = {
        'model': model,
        'bert_tokenizer': bert_tokenizer,
        'roberta_tokenizer': roberta_tokenizer,
        'image_processor': image_processor,
        'device': device,
    }
    MODEL_LOADED = True
    print("β Model loaded successfully!")
    return MODEL_COMPONENTS
def validate_inputs(image, text_input: str, category: str) -> Tuple[bool, str]:
    """Check that an image, a non-blank description and a category were given.

    The checks run in that order and stop at the first failure.

    Returns:
        ``(True, ok_message)`` when everything is present, otherwise
        ``(False, message)`` naming the first missing input.
    """
    if image is None:
        return False, "β Please upload an image file"
    description_blank = not (text_input and text_input.strip())
    if description_blank:
        return False, "β Please provide a product description"
    if category:
        return True, "β Inputs validated successfully"
    return False, "β Please select a product category"
def _remove_background(image: Image.Image) -> Image.Image:
    """Run rembg on *image* through an in-memory PNG round trip; return RGBA."""
    with BytesIO() as buf:
        image.convert('RGBA').save(buf, format="PNG")
        # Given bytes, rembg returns encoded image bytes.
        data = remove(buf.getvalue())
    with BytesIO(data) as out:
        # convert() forces the lazy decode while the buffer is still open.
        return Image.open(out).convert("RGBA")


def _fit_within(size, bounds):
    """Scale (w, h) to fit inside (max_w, max_h), preserving aspect ratio."""
    orig_w, orig_h = size
    max_w, max_h = bounds
    scale = min(max_w / orig_w, max_h / orig_h)
    # int() truncates; for very elongated images the short side could become
    # 0 px, which is not a valid resize target, so clamp to at least 1 px.
    return max(1, int(orig_w * scale)), max(1, int(orig_h * scale))


def resize_image_for_display(image: Image.Image, target_size=(512, 512)) -> Image.Image:
    """Remove the background from *image*, then scale it to fit *target_size*.

    Despite the name, this does more than resize: the image is first passed
    through rembg's ``remove`` (background removal). It is then scaled up or
    down to fit inside ``target_size`` while keeping its aspect ratio.

    Args:
        image: input PIL image in any mode.
        target_size: ``(max_width, max_height)`` bounding box in pixels.

    Returns:
        A new RGBA image no larger than ``target_size`` in either dimension
        and at least 1 px in each.
    """
    no_bg = _remove_background(image)
    return no_bg.resize(_fit_within(no_bg.size, target_size), Image.Resampling.LANCZOS)
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a batch-of-one, channel-first tensor.

    The image is coerced to RGBA (converted only when it is not already),
    read into a NumPy array of shape (H, W, 4) and rearranged to
    (1, 4, H, W). Pixel values are not scaled or normalised here.
    """
    rgba = image if image.mode == 'RGBA' else image.convert('RGBA')
    hwc = torch.tensor(np.array(rgba))
    chw = hwc.permute(2, 0, 1)
    return chw.unsqueeze(dim=0)
def run_inference(image_tensor: torch.Tensor, description: str, category: str, model_components: dict) -> dict:
    """Predict product attributes for one item and summarise the predictions.

    Args:
        image_tensor: batched image tensor as built by ``preprocess_image``.
            Presumably shape (1, C, H, W); the batch dim is squeezed away
            before conversion back to a PIL image.
        description: free-text product description passed to the model.
        category: product category selected in the UI.
        model_components: dict with keys 'model', 'bert_tokenizer',
            'roberta_tokenizer', 'image_processor' and 'device'.

    Returns:
        dict with
          * 'attributes': the list returned by ``get_predicated_values``
            (dicts read here via their 'value' and 'confidence' keys),
          * 'total_attributes': number of entries with a non-empty value
            other than "N/A",
          * 'avg_confidence': mean confidence of those entries as a plain
            Python float (0.0 when there are none).
    """
    model = model_components['model']
    bert_tokenizer = model_components['bert_tokenizer']
    roberta_tokenizer = model_components['roberta_tokenizer']
    image_processor = model_components['image_processor']
    device = model_components['device']

    pil_img = transforms.ToPILImage()(image_tensor.squeeze(0).cpu())
    results = get_predicated_values(
        model, category, pil_img, description,
        image_processor, bert_tokenizer, roberta_tokenizer, device
    )

    # Filter once. float() turns numpy/torch scalars into plain floats so the
    # summary stays JSON-serialisable (np.float32, for example, is not).
    confidences = [
        float(a["confidence"])
        for a in results
        if a["value"] and a["value"] != "N/A"
    ]
    avg_confidence = float(np.mean(confidences)) if confidences else 0.0

    return {
        "attributes": results,
        "total_attributes": len(confidences),
        "avg_confidence": avg_confidence,
    }
def get_confidence_color(confidence: float) -> str:
    """Map a confidence score to a traffic-light hex colour.

    Scores of at least 0.8 are green and scores of at least 0.6 are amber.
    Anything else, including NaN, is red.
    """
    for threshold, colour in ((0.8, "#28a745"), (0.6, "#ffc107")):
        if confidence >= threshold:
            return colour
    return "#dc3545"
def format_results_html(results: dict) -> str:
    """Render extraction results as the HTML fragment for the results panel.

    Args:
        results: dict as produced by ``run_inference`` (keys 'attributes',
            'total_attributes', 'avg_confidence'), or a falsy value.

    Returns:
        A "no results" card when *results* is empty or has no attributes.
        Otherwise, one row per attribute with a usable value plus a summary
        box. Attribute names and values are HTML-escaped because they come
        from model output, which can echo the user's description.
    """
    from html import escape

    if not results or results["total_attributes"] == 0:
        return """
<div class="no-results">
<h3>π No attributes extracted</h3>
<p>Try with a different image or more detailed description.</p>
</div>
"""

    parts = ["""
<div class="results-container">
<h3 class="results-title">π Extracted Attributes</h3>
"""]
    for attr in results["attributes"]:
        value = attr["value"]
        # Skip empty and "N/A" values -- the same rule run_inference uses when
        # counting total_attributes.
        if not value or value == "N/A":
            continue
        confidence = attr["confidence"]
        color = get_confidence_color(confidence)
        parts.append(f"""
<div class="attribute-item">
<div class="attribute-name">{escape(str(attr["name"]).title())}</div>
<div class="attribute-value">{escape(str(value))}</div>
<div class="confidence-badge" style="background-color: {color}">{confidence:.1%}</div>
</div>
""")

    parts.append(f"""
<div class="summary-box">
<h4>π Summary</h4>
<p>
<strong>{results["total_attributes"]}</strong> attributes extracted |
<strong>{results["avg_confidence"]:.1%}</strong> avg confidence
</p>
</div>
</div>
""")
    return "".join(parts)
def create_download_files(results: dict) -> Tuple[str, str]:
    """Write *results* to JSON and CSV files for the download widgets.

    Each call writes into its own fresh temporary directory, so concurrent
    requests cannot overwrite or serve each other's files. The user-visible
    file names stay ``attributes.json`` / ``attributes.csv``. The directories
    are not cleaned up here.

    Args:
        results: dict as returned by ``run_inference``. Its ``"attributes"``
            entry must be accepted by ``pd.DataFrame`` (e.g. a list of dicts).

    Returns:
        ``(json_path, csv_path)``, or ``(None, None)`` when *results* is empty.
    """
    import tempfile

    if not results:
        return None, None

    def _json_default(obj):
        # numpy / torch scalars (e.g. float32 confidences) are not JSON
        # serialisable; unwrap them to plain Python numbers.
        item = getattr(obj, "item", None)
        if callable(item):
            return item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Serialise before touching the filesystem so a failure leaves no partial file.
    json_content = json.dumps(results, indent=2, default=_json_default)

    out_dir = tempfile.mkdtemp(prefix="attributes_")
    json_file = os.path.join(out_dir, "attributes.json")
    with open(json_file, "w", encoding="utf-8") as f:
        f.write(json_content)

    csv_file = os.path.join(out_dir, "attributes.csv")
    pd.DataFrame(results["attributes"]).to_csv(csv_file, index=False)
    return json_file, csv_file
def process_inputs(image, category, description, progress=gr.Progress()):
    """Run the full extraction pipeline for one request.

    Inputs are validated first, then the model is loaded if needed. The image
    then goes through background removal, resizing, inference, HTML
    formatting and export.

    Returns:
        ``(processed_image, status_message, results_html, json_path, csv_path)``.
        On every failure path all elements except ``status_message`` are None.
    """
    import traceback

    global MODEL_COMPONENTS

    # Validate before loading: it is cheap, and it avoids a slow model load for
    # a request that would be rejected anyway.
    is_valid, validation_message = validate_inputs(image, description, category)
    if not is_valid:
        return None, validation_message, None, None, None

    if not MODEL_LOADED:
        progress(0.1, desc="Loading AI model...")
        try:
            MODEL_COMPONENTS = initialize_model_and_tokenizers()
        except Exception as e:
            traceback.print_exc()
            error_msg = f"β Failed to load model: {str(e)}"
            return None, error_msg, None, None, None

    try:
        progress(0.3, desc="πΈ Preprocessing image...")
        resized_image = resize_image_for_display(image, (512, 512))
        image_tensor = preprocess_image(resized_image)

        progress(0.7, desc="π§ Running AI inference...")
        results = run_inference(image_tensor, description, category, MODEL_COMPONENTS)

        progress(0.9, desc="π Formatting results...")
        results_html = format_results_html(results)
        json_file, csv_file = create_download_files(results)

        progress(1.0, desc="β Processing complete!")
        success_msg = f"π Successfully extracted {results['total_attributes']} attributes!"
        return resized_image, success_msg, results_html, json_file, csv_file
    except Exception as e:
        # UI boundary: show the failure to the user, but keep the traceback in
        # the server log instead of discarding it.
        traceback.print_exc()
        error_msg = f"β Processing failed: {str(e)}"
        return None, error_msg, None, None, None
# Custom CSS passed to gr.Blocks(css=...) in create_interface: dark-theme
# overrides layered on top of gr.themes.Soft(), plus classes used by the
# generated HTML (header, results rows, confidence badges, summary box,
# status messages).
# NOTE(review): the `.gr-*` selectors (gr-blocks, gr-button, ...) are assumed to
# match the installed Gradio version's DOM -- confirm they still apply.
# NOTE(review): `primary-button` is set via elem_classes on the extract button
# but has no rule below.
custom_css = """
/* Dark theme overrides */
body, .gradio-container {
background-color: #1a1a1a !important;
color: #e0e0e0 !important;
}
.gr-blocks, .gr-row, .gr-column {
background-color: #1a1a1a !important;
}
.input-section, .results-section {
background-color: #2a2a2a !important;
padding: 20px;
border-radius: 15px;
margin-bottom: 20px;
}
.gr-image, .gr-textbox, .gr-dropdown {
background-color: #333 !important;
color: #e0e0e0 !important;
border: 1px solid #444 !important;
}
.gr-button {
background-color: #444 !important;
color: #e0e0e0 !important;
border: none !important;
}
.gr-button:hover {
background-color: #555 !important;
}
/* Header styling */
.header {
text-align: center;
color: #e0e0e0;
margin-bottom: 30px;
}
.header h1 {
font-size: 2em;
}
.header p {
font-size: 1.1em;
color: #b0b0b0;
}
/* Results styling */
.results-container {
padding: 20px;
}
.results-title {
color: #e0e0e0;
margin-bottom: 20px;
font-size: 1.5em;
}
.attribute-item {
background-color: #333;
padding: 15px;
margin-bottom: 10px;
border-radius: 10px;
display: flex;
justify-content: space-between;
align-items: center;
}
.attribute-name {
color: #e0e0e0;
font-weight: bold;
font-size: 1.1em;
}
.attribute-value {
color: #b0b0b0;
margin-left: 10px;
}
.confidence-badge {
color: white;
padding: 4px 8px;
border-radius: 12px;
font-size: 0.8em;
font-weight: bold;
}
.summary-box {
background-color: #444;
color: #e0e0e0;
padding: 15px;
border-radius: 10px;
margin-top: 20px;
text-align: center;
}
.no-results {
padding: 20px;
text-align: center;
background-color: #333;
border-radius: 10px;
color: #e0e0e0;
}
/* Status message styling */
.status-positive {
background-color: #1a472a !important;
color: #e0e0e0 !important;
padding: 10px;
border-radius: 8px;
}
.status-negative {
background-color: #471a1a !important;
color: #e0e0e0 !important;
padding: 10px;
border-radius: 8px;
}
"""
def create_interface():
    """Build the Gradio Blocks UI and wire the extract button to the pipeline.

    Layout: a header, then a two-column row.
      * Left: inputs (image, category, description), the extract button and
        the status message.
      * Right: outputs (processed image, attribute HTML, JSON/CSV downloads).

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(css=custom_css, title="AI Attribute Extractor", theme=gr.themes.Soft()) as demo:
        # Page header
        gr.HTML("""
<div class="header">
<h1>π AI Attribute Extractor</h1>
<p>Upload an image and provide text to extract detailed attributes using AI</p>
</div>
""")
        with gr.Row():
            # Left column - Input section
            with gr.Column(scale=1):
                gr.HTML("<h2>π€ Input Section</h2>")
                image_input = gr.Image(
                    label="Upload Product Image",
                    type="pil",
                    height=300,
                    elem_classes=["input-section"]
                )
                category_input = gr.Dropdown(
                    choices=["clothing", "bags", "shoes", "accessories"],
                    label="Product Category",
                    value="clothing",
                    elem_classes=["input-section"]
                )
                text_input = gr.Textbox(
                    label="Product Description",
                    placeholder="Describe the product in detail...",
                    lines=4,
                    elem_classes=["input-section"]
                )
                process_btn = gr.Button(
                    "π Extract Attributes",
                    variant="primary",
                    size="lg",
                    elem_classes=["primary-button"]
                )
                status_msg = gr.HTML(label="Status")

            # Right column - Results section
            with gr.Column(scale=1):
                gr.HTML("<h2>π Results Section</h2>")
                processed_image = gr.Image(
                    label="Processed Image",
                    height=300,
                    elem_classes=["results-section"]
                )
                results_html = gr.HTML(
                    label="Extracted Attributes",
                    elem_classes=["results-section"]
                )
                with gr.Row():
                    json_download = gr.File(
                        label="π Download JSON",
                        visible=False
                    )
                    csv_download = gr.File(
                        label="π Download CSV",
                        visible=False
                    )

        # Event handlers
        def update_status(message: str, is_error: bool = False):
            """Wrap a status message in a styled div (error or success style)."""
            from html import escape

            class_name = "status-negative" if is_error else "status-positive"
            # Messages can embed exception text; escape so it renders as text.
            return f'<div class="{class_name}">{escape(message)}</div>'

        def process_and_update(image, category, description, progress=gr.Progress()):
            """Run the pipeline and map its outputs onto the UI components.

            The ``progress`` parameter lets Gradio inject a live tracker. It
            is forwarded to process_inputs so its progress calls are displayed.
            """
            processed_img, status, results, json_file, csv_file = process_inputs(
                image, category, description, progress=progress
            )
            # process_inputs returns no image on every failure path, which is
            # more reliable than sniffing the message's leading emoji.
            is_error = processed_img is None
            return (
                processed_img,
                update_status(status, is_error),
                results,
                gr.update(value=json_file, visible=json_file is not None),
                gr.update(value=csv_file, visible=csv_file is not None),
            )

        process_btn.click(
            fn=process_and_update,
            inputs=[image_input, category_input, text_input],
            outputs=[processed_image, status_msg, results_html, json_download, csv_download]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces at port
    # 7860, without a public share link.
    print("Initializing AI Attribute Extractor...")
    demo = create_interface()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
        "show_error": True,
        "quiet": False,
    }
    demo.launch(**launch_options)