import os

# Set before any `app.*` import below, so that code reading it at import time
# sees the flag. Presumably app.config / get_settings() uses it to switch to
# demo behaviour — TODO confirm.
os.environ["HUGGINGFACE_DEMO"] = "1"

from dotenv import load_dotenv

# Populate os.environ from a .env file *before* the project imports below,
# so any settings they read at import time can see those values.
load_dotenv()

import gradio as gr
import uuid
import shutil

from app.config import get_settings
from app.schemas.requests import Attribute
from app.request_handler import handle_extract
from app.services.factory import AIServiceFactory

# Project settings; this file reads SUPPORTED_MODELS, OPENAI_MODELS and
# ANTHROPIC_MODELS from it.
settings = get_settings()
# Longest image side, in pixels, kept by forward_request; larger uploads are
# downscaled proportionally before being saved for the AI service.
IMAGE_MAX_SIZE = 1536
|
|
|
|
async def forward_request(
    attributes, product_taxonomy, product_data, ai_model, pil_images
):
    """Extract product attributes from the uploaded images with the chosen model.

    Gradio callback for the "Extract Attributes" button. Images are written to
    a per-request temp folder that is always removed afterwards.

    Args:
        attributes: Python source of the attribute schema; it must define a
            pydantic ``Product`` model. ``import_for_schema`` is prepended.
        product_taxonomy: Free-text taxonomy hint forwarded to the service.
        product_data: Python dict-literal text with known product data; blank
            input is treated as ``{}``.
        ai_model: Selected model name; must be listed in
            ``settings.OPENAI_MODELS`` or ``settings.ANTHROPIC_MODELS``.
        pil_images: Gallery value; element ``[0]`` of each item is a PIL image.

    Returns:
        Whatever ``service.extract_attributes_with_validation`` returns
        (displayed in the JSON output component).

    Raises:
        gr.Error: invalid schema or product data, no images, unsupported
            model, or a failed extraction.
    """
    request_id = str(uuid.uuid4())
    request_temp_folder = os.path.join("gradio_temp", request_id)
    os.makedirs(request_temp_folder, exist_ok=True)

    try:
        product_model = _load_product_schema(attributes)
        product_data_object = _parse_product_data(product_data)
        img_paths = _save_images(pil_images, request_temp_folder)

        ai_vendor = _resolve_vendor(ai_model)
        service = AIServiceFactory.get_service(ai_vendor)

        try:
            json_attributes = await service.extract_attributes_with_validation(
                product_model,
                ai_model,
                None,
                product_taxonomy,
                product_data_object,
                img_paths=img_paths,
            )
        except Exception as err:
            raise gr.Error(
                "Failed to extract attributes. Something went wrong."
            ) from err
    finally:
        # Cleanup must never mask the real error raised above.
        shutil.rmtree(request_temp_folder, ignore_errors=True)

    gr.Info("Process completed!")
    return json_attributes


def _load_product_schema(schema_source):
    """Execute the user's schema code and return the ``Product`` class it defines.

    Raises:
        gr.Error: if the code fails to execute or does not define ``Product``.
    """
    # SECURITY: exec() runs arbitrary user-supplied Python inside this
    # process. Only tolerable because the demo is launched behind a login;
    # do not expose this publicly as-is.
    # A fresh namespace per request: exec'ing into globals() let one
    # request's `Product` leak into the next (and across concurrent users).
    # "__name__" matches this module so the generated classes get the same
    # __module__ they had when exec'd into globals().
    namespace = {"__name__": __name__}
    try:
        exec(import_for_schema + schema_source, namespace)
        return namespace["Product"]
    except Exception as err:
        raise gr.Error(
            "Invalid `Attribute Schema`. Please insert valid schema following the example."
        ) from err


def _parse_product_data(product_data):
    """Evaluate the optional product-data text into a dict (blank -> ``{}``).

    Raises:
        gr.Error: if the text does not evaluate to a dict.
    """
    message = "Invalid `Product Data`. Please insert valid dictionary or leave it empty."
    # Strip so a leading newline cannot terminate the assignment early.
    source = (product_data or "").strip() or "{}"
    # SECURITY: exec() evaluates untrusted text. For plain dict literals,
    # ast.literal_eval would be the safe choice; flagged rather than swapped
    # because it would reject non-literal expressions accepted today.
    namespace = {}
    try:
        exec(f"product_data_object = {source}", namespace)
    except Exception as err:
        raise gr.Error(message) from err
    product_data_object = namespace["product_data_object"]
    if not isinstance(product_data_object, dict):
        raise gr.Error(message)
    return product_data_object


def _prepare_image(pil_image, max_size=IMAGE_MAX_SIZE):
    """Downscale *pil_image* to fit *max_size* and choose how to save it.

    Returns:
        ``(image, format, extension)``: PNG when some pixel is not fully
        opaque, otherwise JPEG (converted to RGB unless already RGB or L).
    """
    longest_side = max(pil_image.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # max(1, ...) keeps extreme aspect ratios from rounding a side to 0.
        new_size = (
            max(1, int(pil_image.width * ratio)),
            max(1, int(pil_image.height * ratio)),
        )
        pil_image = pil_image.resize(new_size)

    has_alpha = pil_image.mode in ("RGBA", "LA", "PA") or (
        pil_image.mode == "P" and "transparency" in pil_image.info
    )
    if has_alpha:
        pil_image = pil_image.convert("RGBA")
        if pil_image.getchannel("A").getextrema() != (255, 255):
            # Real transparency: keep it losslessly, with a matching suffix.
            return pil_image, "PNG", "png"
    # Pillow cannot write P/I/F/RGBA/... modes as JPEG; normalise to RGB
    # (this also covers fully opaque RGBA images from the branch above).
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    return pil_image, "JPEG", "jpg"


def _save_images(gallery_items, folder):
    """Write every gallery image into *folder* and return the file paths.

    Raises:
        gr.Error: if no images were uploaded (``None`` or empty).
    """
    if not gallery_items:
        raise gr.Error("Please upload image(s) of the product")
    img_paths = []
    for index, item in enumerate(gallery_items):
        # Each item is indexed with [0]; presumably (image, caption) pairs.
        image, image_format, extension = _prepare_image(item[0])
        img_path = os.path.join(folder, f"{index}.{extension}")
        if image_format == "JPEG":
            image.save(img_path, image_format, quality=100, subsampling=0)
        else:
            image.save(img_path, image_format)
        img_paths.append(img_path)
    return img_paths


def _resolve_vendor(ai_model):
    """Map a model name to its vendor key for ``AIServiceFactory``.

    Raises:
        gr.Error: if the model is in neither configured list (previously an
            UnboundLocalError).
    """
    if ai_model in settings.OPENAI_MODELS:
        return "openai"
    if ai_model in settings.ANTHROPIC_MODELS:
        return "anthropic"
    raise gr.Error("Please select a supported AI model")
|
|
|
|
def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
    """Append one attribute entry to the schema text and clear the input boxes.

    Builds a JSON-style ``"name": {...},`` fragment and concatenates it to
    ``attributes``.

    NOTE(review): the fragment is JSON-style, whereas forward_request execs the
    schema as Python defining a pydantic ``Product`` model, so appending this
    fragment to such a schema does not yield valid Python. No caller is wired
    up in this file — confirm the intended format before re-enabling it.

    Args:
        attributes: Current schema text.
        attr_name: Attribute key.
        attr_desc: Attribute description.
        attr_type: Value placed under ``"data_type"``.
        allowed_values: Comma-separated allowed values; surrounding whitespace
            is stripped and blank entries are dropped.

    Returns:
        ``(attributes + fragment, "", "", "", "")``: the extended schema
        followed by four empty strings.
    """
    import json  # only this helper needs json

    def quote(text):
        # json.dumps escapes embedded quotes/backslashes so user text such as
        # 'a "big" bag' cannot break the fragment; ensure_ascii=False keeps
        # non-ASCII characters as typed.
        return json.dumps(text, ensure_ascii=False)

    values = [value.strip() for value in allowed_values.split(",")]
    joined_values = ", ".join(quote(value) for value in values if value)
    schema = f"""
    {quote(attr_name)}: {{
        "description": {quote(attr_desc)},
        "data_type": {quote(attr_type)},
        "allowed_values": [
            {joined_values}
        ]
    }},
"""
    return attributes + schema, "", "", "", ""
|
|
# Import lines that forward_request prepends to the user's attribute schema
# before exec'ing it, so schemas can use Enum, BaseModel, Field and List
# without importing them.
import_for_schema = """
from enum import Enum
from pydantic import BaseModel, Field
from typing import List
"""
|
|
# Default content of the "Attribute Schema" box: an example schema for
# dresses. forward_request execs this text (after import_for_schema) and
# passes the resulting `Product` model to the AI service. Each option is an
# `int` field; presumably a 0/1 flag per option — confirm with the service.
# The "Single value" / "Can have multiple values" prefixes in Product's field
# descriptions appear intended as hints to the model.
sample_schema = """from pydantic import BaseModel, Field


class Length(BaseModel):
    maxi: int = Field(..., description="Maxi length dress")
    knee_length: int = Field(..., description="Knee length dress")
    mini: int = Field(..., description="Mini dress")
    midi: int = Field(..., description="Midi dress")


class Style(BaseModel):
    a_line: int = Field(..., description="A Line style")
    bodycon: int = Field(..., description="Bodycon style")
    column: int = Field(..., description="Column style")
    shirt_dress: int = Field(..., description="Shirt Dress")
    wrap_dress: int = Field(..., description="Wrap Dress")
    slip: int = Field(..., description="Slip dress")
    kaftan: int = Field(..., description="Kaftan")
    smock: int = Field(..., description="Smock")
    corset: int = Field(..., description="Corset bodice")
    pinafore: int = Field(..., description="Pinafore")
    jumper_dress: int = Field(..., description="Jumper Dress")
    blazer_dress: int = Field(..., description="Blazer Dress")
    tunic: int = Field(..., description="Tunic")


class SleeveLength(BaseModel):
    sleeveless: int = Field(..., description="Sleeveless")
    three_quarters_sleeve: int = Field(..., description="Three quarters Sleeve")
    long_sleeve: int = Field(..., description="Long Sleeve")
    short_sleeve: int = Field(..., description="Short Sleeve")
    strapless: int = Field(..., description="Strapless")


class Neckline(BaseModel):
    v_neck: int = Field(..., description="V Neck")
    sweetheart: int = Field(..., description="Sweetheart neckline")
    round_neck: int = Field(..., description="Round Neck")
    halter_neck: int = Field(..., description="Halter Neck")
    square_neck: int = Field(..., description="Square Neck")
    high_neck: int = Field(..., description="High Neck")
    crew_neck: int = Field(..., description="Crew Neck")
    cowl_neck: int = Field(..., description="Cowl Neck")
    turtle_neck: int = Field(..., description="Turtle Neck")
    off_the_shoulder: int = Field(..., description="Off the Shoulder")
    one_shoulder: int = Field(..., description="One Shoulder")


class Pattern(BaseModel):
    floral: int = Field(..., description="Floral pattern")
    stripe: int = Field(..., description="Stripe pattern")
    leopard_print: int = Field(..., description="Leopard print")
    spot: int = Field(..., description="Spot pattern")
    plain: int = Field(..., description="Plain")
    geometric: int = Field(..., description="Geometric pattern")
    logo: int = Field(..., description="Logo print")
    graphic_print: int = Field(..., description="Graphic print")
    check: int = Field(..., description="Check pattern")
    other: int = Field(..., description="Other pattern")


class Fabric(BaseModel):
    cotton: int = Field(..., description="Cotton")
    denim: int = Field(..., description="Denim")
    jersey: int = Field(..., description="Jersey")
    linen: int = Field(..., description="Linen")
    satin: int = Field(..., description="Satin")
    silk: int = Field(..., description="Silk")
    sequin: int = Field(..., description="Sequin")
    leather: int = Field(..., description="Leather")
    velvet: int = Field(..., description="Velvet")
    knit: int = Field(..., description="Knit")
    lace: int = Field(..., description="Lace")
    suede: int = Field(..., description="Suede")
    sheer: int = Field(..., description="Sheer")
    tulle: int = Field(..., description="Tulle")
    crepe: int = Field(..., description="Crepe")
    polyester: int = Field(..., description="Polyester")
    viscose: int = Field(..., description="Viscose")


class Features(BaseModel):
    pockets: int = Field(..., description="Has pockets")
    lined: int = Field(..., description="Lined")
    cut_out: int = Field(..., description="Cut out design")
    backless: int = Field(..., description="Backless")
    none: int = Field(..., description="No special features")


class Closure(BaseModel):
    button: int = Field(..., description="Button closure")
    zip: int = Field(..., description="Zip closure")
    press_stud: int = Field(..., description="Press stud closure")
    clasp: int = Field(..., description="Clasp closure")


class BodyFit(BaseModel):
    petite: int = Field(..., description="Petite fit")
    maternity: int = Field(..., description="Maternity fit")
    regular: int = Field(..., description="Regular fit")
    tall: int = Field(..., description="Tall fit")
    plus_size: int = Field(..., description="Plus size fit")


class Occasion(BaseModel):
    beach: int = Field(..., description="Suitable for beach")
    casual: int = Field(..., description="Casual wear")
    cocktail: int = Field(..., description="Cocktail event")
    day: int = Field(..., description="Day wear")
    evening: int = Field(..., description="Evening wear")
    mother_of_the_bride: int = Field(..., description="Mother of the bride dress")
    party: int = Field(..., description="Party wear")
    prom: int = Field(..., description="Prom dress")
    wedding_guest: int = Field(..., description="Wedding guest dress")
    work: int = Field(..., description="Work attire")
    sportswear: int = Field(..., description="Sportswear")


class Season(BaseModel):
    spring: int = Field(..., description="Spring season")
    summer: int = Field(..., description="Summer season")
    autumn: int = Field(..., description="Autumn season")
    winter: int = Field(..., description="Winter season")


class Product(BaseModel):
    length: Length = Field(..., description="Single value ,Length of the dress")
    style: Style = Field(..., description="Can have multiple values, Style of the dress")
    sleeve_length: SleeveLength = Field(..., description="Single value ,Sleeve length of the dress")
    neckline: Neckline = Field(..., description="Single value ,Neckline of the dress")
    pattern: Pattern = Field(..., description="Can have multiple values, Pattern of the dress")
    fabric: Fabric = Field(..., description="Can have multiple values, Fabric of the dress")
    features: Features = Field(..., description="Can have multiple values, Features of the dress")
    closure: Closure = Field(..., description="Can have multiple values ,Closure of the dress")
    body_fit: BodyFit = Field(..., description="Single value ,Body fit of the dress")
    occasion: Occasion = Field(..., description="Can have multiple values ,Occasion of the dress")
    season: Season = Field(..., description="Single value ,Season of the dress")
"""
# Markdown help text rendered under the page title. Steps 4 and 5 describe
# the inputs that actually exist in the UI: the "Product Data (Optional)" box
# and the "Attribute Schema" box (Python code defining a pydantic `Product`).
description = """
This is a simple demo for Attribution. Follow the steps below:

1. Upload image(s) of a product.
2. Enter the product taxonomy (e.g. 'upper garment', 'lower garment', 'bag'). If only one product is in the image, you can leave this field empty.
3. Select the AI model to use.
4. Enter known product data as a Python dictionary (optional).
5. Enter the attribute schema: Python code defining a pydantic `Product` model (start from the pre-filled example).
6. Click "Extract Attributes" to get the extracted attributes.
"""
|
|
# Greyed-out hint shown in the empty "Product Data (Optional)" box.
product_data_placeholder = """Example:
{
    "brand": "Leaf",
    "size": "M",
    "product_name": "Leaf T-shirt",
    "color": "red"
}
"""
# Initial value of the "Product Data" box (the UI .strip()s it). Unless the
# user edits it, this dict with empty values is sent as product data.
product_data_value = """
{
    "data1": "",
    "data2": ""
}
"""
|
|
# Gradio UI. The header row holds the title and help text. The main row has
# an input column (images, taxonomy, model, product data, schema, submit) and
# an output column (extracted JSON).
with gr.Blocks(title="Internal Demo for Attribution") as demo:
    with gr.Row():
        with gr.Column(scale=12):
            gr.Markdown(
                """<div style="text-align: center; font-size: 24px;"><strong>Internal Demo for Attribution</strong></div>"""
            )
            gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=12):
            with gr.Row():
                with gr.Column():
                    # type="pil": forward_request receives PIL images and
                    # indexes [0] on each gallery item.
                    gallery = gr.Gallery(
                        label="Upload images of your product here", type="pil"
                    )
                    # (sic) misspelled name; referenced by the click wiring below.
                    product_taxnomy = gr.Textbox(
                        label="Product Taxonomy",
                        placeholder="Enter product taxonomy here (e.g. 'upper garment', 'lower garment', 'bag')",
                        lines=1,
                        max_lines=1,
                    )
                    # forward_request maps the chosen name to a vendor via
                    # settings.OPENAI_MODELS / settings.ANTHROPIC_MODELS.
                    ai_model = gr.Dropdown(
                        label="AI Model",
                        choices=settings.SUPPORTED_MODELS,
                        interactive=True,
                    )
                    product_data = gr.TextArea(
                        label="Product Data (Optional)",
                        placeholder=product_data_placeholder,
                        value=product_data_value.strip(),
                        interactive=True,
                        lines=10,
                        max_lines=10,
                    )
                with gr.Column():
                    # Pre-filled with sample_schema; forward_request execs it.
                    # NOTE(review): the placeholder mentions an "Add Attributes"
                    # section, but none is built here and add_attribute_schema
                    # is not wired to any event — confirm.
                    attributes = gr.TextArea(
                        label="Attribute Schema",
                        value=sample_schema,
                        placeholder="Enter schema here or use Add Attributes below",
                        interactive=True,
                        lines=30,
                        max_lines=30,
                    )
            with gr.Row():
                submit_btn = gr.Button("Extract Attributes")
        with gr.Column(scale=6):
            output_json = gr.Json(
                label="Extracted Attributes", value={}, show_indices=False
            )

    # Input order must match forward_request(attributes, product_taxonomy,
    # product_data, ai_model, pil_images).
    submit_btn.click(
        forward_request,
        inputs=[attributes, product_taxnomy, product_data, ai_model, gallery],
        outputs=output_json,
    )
|
|
|
|
# Login credentials for the demo, read from the environment. The fallbacks
# ("1" / "a") are trivially guessable, so real deployments must set both
# ATTR_USER and ATTR_PASS.
attr_user = os.getenv("ATTR_USER", "1")
attr_pass = os.getenv("ATTR_PASS", "a")
auth = (attr_user, attr_pass)

if "ATTR_USER" not in os.environ or "ATTR_PASS" not in os.environ:
    # Surface the insecure fallback instead of silently accepting it.
    import warnings

    warnings.warn(
        "ATTR_USER and/or ATTR_PASS not set; falling back to default demo "
        "credentials, which are insecure.",
        stacklevel=1,
    )

demo.launch(auth=auth, debug=True, ssr_mode=False)
|
|