|
|
import os

# Set before the `app` imports below so they can observe the flag
# (presumably read while `app.config` builds its settings - TODO confirm).
os.environ["HUGGINGFACE_DEMO"] = "1"

from dotenv import load_dotenv

# Pull variables from a local .env file into the environment before the
# settings object is created further down.
load_dotenv()

import gradio as gr
import uuid
import shutil

from app.config import get_settings
from app.schemas.requests import Attribute
from app.request_handler import handle_extract
from app.services.factory import AIServiceFactory

# Application settings; this module reads OPENAI_MODELS, ANTHROPIC_MODELS and
# SUPPORTED_MODELS from it.
settings = get_settings()

# Longest image side, in pixels, allowed before an upload is downscaled.
IMAGE_MAX_SIZE = 1536
|
|
|
|
|
|
|
|
def _save_request_image(pil_image, index, folder):
    """Downscale, normalise and write one uploaded image; return its path.

    Args:
        pil_image: A PIL image taken from the gallery upload.
        index: Position of the image in the upload, used as the file stem.
        folder: Directory the file is written into.

    Returns:
        The path of the written file. Images with real transparency are
        stored as PNG, everything else as JPEG.
    """
    longest_side = max(pil_image.size)
    if longest_side > IMAGE_MAX_SIZE:
        ratio = IMAGE_MAX_SIZE / longest_side
        # Clamp to 1px so extremely elongated images never get a zero dimension.
        pil_image = pil_image.resize(
            (
                max(1, int(pil_image.width * ratio)),
                max(1, int(pil_image.height * ratio)),
            )
        )

    has_alpha = pil_image.mode in ("RGBA", "LA") or (
        pil_image.mode == "P" and "transparency" in pil_image.info
    )
    image_format = "JPEG"
    if has_alpha:
        pil_image = pil_image.convert("RGBA")
        if pil_image.getchannel("A").getextrema() == (255, 255):
            # Alpha channel is fully opaque, so it can be dropped safely.
            pil_image = pil_image.convert("RGB")
        else:
            image_format = "PNG"
    elif pil_image.mode not in ("RGB", "L", "CMYK"):
        # The JPEG writer rejects palette / integer / float modes.
        pil_image = pil_image.convert("RGB")

    # Keep the extension in sync with the real format so that consumers which
    # guess the media type from the file name are not misled.
    extension = "png" if image_format == "PNG" else "jpg"
    img_path = os.path.join(folder, f"{index}.{extension}")
    pil_image.save(img_path, image_format, quality=100, subsampling=0)
    return img_path


async def forward_request(
    attributes, product_taxonomy, product_data, ai_model, pil_images
):
    """Run attribute extraction for one submission of the demo form.

    Args:
        attributes: Python source of the attribute schema; it must define a
            class named ``Product``.
        product_taxonomy: Free-text taxonomy passed through to the AI service.
        product_data: Text of a Python/JSON-style dict literal with known
            product data; empty or blank means ``{}``.
        ai_model: Model name; must be listed in ``settings.OPENAI_MODELS`` or
            ``settings.ANTHROPIC_MODELS``.
        pil_images: Gallery value; each item is indexed with ``[0]`` to get
            the PIL image.

    Returns:
        Whatever ``service.extract_attributes_with_validation`` returns
        (shown in the JSON output widget).

    Raises:
        gr.Error: On an invalid schema, invalid product data, missing images,
            an unsupported model, or a failure inside the AI service.
    """
    request_id = str(uuid.uuid4())
    request_temp_folder = os.path.join("gradio_temp", request_id)
    os.makedirs(request_temp_folder, exist_ok=True)

    try:
        # SECURITY NOTE(review): the schema and the product data are executed
        # as Python code. Anyone who can reach this form can run arbitrary
        # code in this process; only expose the demo to trusted users.
        #
        # The code runs in a *copy* of the module globals so that it can still
        # see the module's names but cannot overwrite them, and so that a
        # `Product` left behind by an earlier request is never reused.
        namespace = dict(globals())
        namespace.pop("Product", None)
        namespace.pop("product_data_object", None)

        try:
            exec(import_for_schema + (attributes or ""), namespace)
            product_model = namespace["Product"]
        except Exception as exc:
            raise gr.Error(
                "Invalid `Attribute Schema`. Please insert valid schema following the example."
            ) from exc

        # Blank input means "no known data"; stripping also avoids a syntax
        # error when the text starts with a newline.
        product_data = (product_data or "").strip() or "{}"
        try:
            exec(f"product_data_object = {product_data}", namespace)
            product_data_object = namespace["product_data_object"]
        except Exception as exc:
            raise gr.Error(
                "Invalid `Product Data`. Please insert valid dictionary or leave it empty."
            ) from exc

        # Covers both "nothing uploaded" (None) and an emptied gallery ([]).
        if not pil_images:
            raise gr.Error("Please upload image(s) of the product")
        img_paths = [
            _save_request_image(item[0], index, request_temp_folder)
            for index, item in enumerate(pil_images)
        ]

        if ai_model in settings.OPENAI_MODELS:
            ai_vendor = "openai"
        elif ai_model in settings.ANTHROPIC_MODELS:
            ai_vendor = "anthropic"
        else:
            # Previously fell through with `ai_vendor` unbound.
            raise gr.Error("Please select a supported AI model.")
        service = AIServiceFactory.get_service(ai_vendor)

        try:
            json_attributes = await service.extract_attributes_with_validation(
                product_model,
                ai_model,
                None,
                product_taxonomy,
                product_data_object,
                img_paths=img_paths,
            )
        except Exception as exc:
            # `Exception` (not a bare except) lets task cancellation propagate.
            raise gr.Error(
                "Failed to extract attributes. Something went wrong."
            ) from exc
    finally:
        # Best-effort cleanup; a failure here must not mask the real outcome.
        shutil.rmtree(request_temp_folder, ignore_errors=True)

    gr.Info("Process completed!")
    return json_attributes
|
|
|
|
|
|
|
|
def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
    """Append one JSON-style attribute entry to the schema text.

    Args:
        attributes: Schema text accumulated so far.
        attr_name: Name of the attribute to add.
        attr_desc: Human-readable description of the attribute.
        attr_type: Data type label of the attribute.
        allowed_values: Comma-separated allowed values; may be empty.

    Returns:
        A 5-tuple: the extended schema text followed by four empty strings
        (presumably used to clear the four input widgets - verify against the
        caller).
    """
    # Strip each entry and drop blanks so inputs such as "a,,b", "a, " or "  "
    # do not produce empty-string allowed values.
    cleaned = (value.strip() for value in (allowed_values or "").split(","))
    values_text = ", ".join(f'"{value}"' for value in cleaned if value)

    schema = f"""
"{attr_name}": {{
    "description": "{attr_desc}",
    "data_type": "{attr_type}",
    "allowed_values": [
        {values_text}
    ]
}},
"""
    return attributes + schema, "", "", "", ""
|
|
|
|
|
# Import preamble that forward_request prepends to the user-supplied schema
# source before executing it.
import_for_schema = """
from enum import Enum
from pydantic import BaseModel, Field
from typing import List
"""
|
|
|
|
|
# Default content of the "Attribute Schema" text area. forward_request executes
# this source and hands the resulting `Product` class to the AI service, so the
# class bodies must stay indented as valid Python.
sample_schema = """from pydantic import BaseModel, Field


class Length(BaseModel):
    maxi: int = Field(..., description="Maxi: Dress extends to the ankles or floor.")
    knee_length: int = Field(..., description="Knee Length: Dress ends around the knees.")
    mini: int = Field(..., description="Mini: Short dress that ends well above the knees.")
    midi: int = Field(..., description="Midi: Dress falls between the knee and ankle.")


class Style(BaseModel):
    a_line: int = Field(..., description="A Line: Fitted at the top and gradually flares toward the hem, forming an 'A' shape.")
    bodycon: int = Field(..., description="Bodycon: Tight-fitting and figure-hugging, usually made with stretchy fabric.")
    column: int = Field(..., description="Column: Straight silhouette from top to bottom, with minimal shaping or flare.")
    shirt_dress: int = Field(..., description="Shirt Dress: Structured like a shirt with buttons, collar, and sleeves; may include a belt.")
    wrap_dress: int = Field(..., description="Wrap Dress: Features a front closure that wraps and ties at the side or back.")
    slip: int = Field(..., description="Slip: Lightweight, spaghetti-strap dress with minimal structure, often bias-cut.")
    smock: int = Field(..., description="Smock: Loose-fitting with gathered or shirred sections, usually on bodice or neckline.")
    corset: int = Field(..., description="Corset: Structured bodice with boning or lacing that shapes the waist.")
    jumper_dress: int = Field(..., description="Jumper Dress: Layered dress style similar to a pinafore, often more casual or thick-strapped.")
    blazer_dress: int = Field(..., description="Blazer Dress: Tailored like a blazer or suit jacket, often double-breasted or lapelled.")
    tunic: int = Field(..., description="Tunic: Loose and straight-cut, often worn short or over pants/leggings.")
    asymmetric: int = Field(..., description="Asymmetric: Dress with a non-symmetrical hem, neckline, or sleeve design.")
    shift: int = Field(..., description="Shift: Simple, straight dress with no defined waist, typically above the knee.")
    drop_waist: int = Field(..., description="Drop waist: Waistline sits low on the hips, usually with a loose top and flared skirt.")
    empire: int = Field(..., description="Empire: High waistline just below the bust, flowing skirt from there downward.")
    modest: int = Field(..., description="Modest: Covers most of the body, with high neckline, long sleeves, and longer hemline.")


class SleeveLength(BaseModel):
    sleeveless: int = Field(..., description="Sleeveless: No sleeves.")
    three_quarters_sleeve: int = Field(..., description="Three quarters Sleeve: Sleeves that end between the elbow and wrist.")
    long_sleeve: int = Field(..., description="Long Sleeve: Sleeves that extend to the wrist.")
    short_sleeve: int = Field(..., description="Short Sleeve: Sleeves that end above the elbow.")
    strapless: int = Field(..., description="Strapless: No shoulder straps or sleeves.")


class Neckline(BaseModel):
    v_neck: int = Field(..., description="V Neck: Neckline dips down in the shape of a 'V', varying from shallow to deep.")
    sweetheart: int = Field(..., description="Sweetheart: A heart-shaped neckline, often curving over the bust and dipping in the center.")
    round_neck: int = Field(..., description="Round Neck: Circular neckline sitting around the base of the neck.")
    halter_neck: int = Field(..., description="Halter Neck: Straps go around the neck, leaving shoulders and upper back exposed.")
    square_neck: int = Field(..., description="Square Neck: Straight horizontal cut across the chest with vertical sides, forming a square.")
    high_neck: int = Field(..., description="High Neck: Extends up the neck slightly but not folded like a turtle neck.")
    crew_neck: int = Field(..., description="Crew Neck: High, rounded neckline that sits close to the neck.")
    cowl_neck: int = Field(..., description="Cowl Neck: Draped or folded neckline that hangs in soft folds.")
    turtle_neck: int = Field(..., description="Turtle Neck: High neckline that folds over and covers the neck completely.")
    off_the_shoulder: int = Field(..., description="Off the Shoulder: Sits below the shoulders, exposing the shoulders and collarbone.")
    one_shoulder: int = Field(..., description="One Shoulder: Covers one shoulder only, leaving the other bare.")
    boat_neck: int = Field(..., description="Boat Neck: Wide, shallow neckline that runs almost horizontally from shoulder to shoulder.")
    scoop_neck: int = Field(..., description="Scoop Neck: U-shaped neckline, typically deeper than a round neck.")


class Pattern(BaseModel):
    floral: int = Field(..., description="Floral pattern")
    stripe: int = Field(..., description="Stripe pattern")
    leopard_print: int = Field(..., description="Leopard print")
    spot: int = Field(..., description="Spot pattern")
    plain: int = Field(..., description="Plain")
    geometric: int = Field(..., description="Geometric pattern")
    logo: int = Field(..., description="Logo print")
    graphic_print: int = Field(..., description="Graphic print")
    check: int = Field(..., description="Check pattern")
    other: int = Field(..., description="Other pattern")


class Fabric(BaseModel):
    cotton: int = Field(..., description="Cotton")
    denim: int = Field(..., description="Denim")
    jersey: int = Field(..., description="Jersey")
    linen: int = Field(..., description="Linen")
    satin: int = Field(..., description="Satin")
    silk: int = Field(..., description="Silk")
    sequin: int = Field(..., description="Sequin")
    leather: int = Field(..., description="Leather")
    velvet: int = Field(..., description="Velvet")
    knit: int = Field(..., description="Knit")
    lace: int = Field(..., description="Lace")
    suede: int = Field(..., description="Suede")
    sheer: int = Field(..., description="Sheer")
    tulle: int = Field(..., description="Tulle")
    crepe: int = Field(..., description="Crepe")
    polyester: int = Field(..., description="Polyester")
    viscose: int = Field(..., description="Viscose")


class Features(BaseModel):
    pockets: int = Field(..., description="Has pockets")
    lined: int = Field(..., description="Lined")
    cut_out: int = Field(..., description="Cut out design")
    backless: int = Field(..., description="Backless")
    none: int = Field(..., description="No special features")


class Closure(BaseModel):
    button: int = Field(..., description="Button closure")
    zip: int = Field(..., description="Zip closure")
    press_stud: int = Field(..., description="Press stud closure")
    clasp: int = Field(..., description="Clasp closure")


class BodyFit(BaseModel):
    petite: int = Field(..., description="Petite fit")
    maternity: int = Field(..., description="Maternity fit")
    regular: int = Field(..., description="Regular fit")
    tall: int = Field(..., description="Tall fit")
    plus_size: int = Field(..., description="Plus size fit")


class Occasion(BaseModel):
    beach: int = Field(..., description="Suitable for beach")
    casual: int = Field(..., description="Casual wear")
    cocktail: int = Field(..., description="Cocktail event")
    day: int = Field(..., description="Day wear")
    evening: int = Field(..., description="Evening wear")
    mother_of_the_bride: int = Field(..., description="Mother of the bride dress")
    party: int = Field(..., description="Party wear")
    prom: int = Field(..., description="Prom dress")
    wedding_guest: int = Field(..., description="Wedding guest dress")
    work: int = Field(..., description="Work attire")
    sportswear: int = Field(..., description="Sportswear")


class Season(BaseModel):
    spring: int = Field(..., description="Spring season")
    summer: int = Field(..., description="Summer season")
    autumn: int = Field(..., description="Autumn season")
    winter: int = Field(..., description="Winter season")


class Product(BaseModel):
    length: Length = Field(..., description="Single value ,Length of the dress")
    style: Style = Field(..., description="Can have multiple values, Style of the dress")
    sleeve_length: SleeveLength = Field(..., description="Single value ,Sleeve length of the dress")
    neckline: Neckline = Field(..., description="Single value ,Neckline of the dress")
    pattern: Pattern = Field(..., description="Can have multiple values, Pattern of the dress")
    fabric: Fabric = Field(..., description="Can have multiple values, Fabric of the dress")
    features: Features = Field(..., description="Can have multiple values, Features of the dress")
    closure: Closure = Field(..., description="Can have multiple values ,Closure of the dress")
    body_fit: BodyFit = Field(..., description="Single value ,Body fit of the dress")
    occasion: Occasion = Field(..., description="Can have multiple values ,Occasion of the dress")
    season: Season = Field(..., description="Single value ,Season of the dress")
"""
|
|
# Markdown shown at the top of the page, under the title.
description = """
This is a simple demo for Attribution. Follow the steps below:

1. Upload image(s) of a product.
2. Enter the product taxonomy (e.g. 'upper garment', 'lower garment', 'bag'). If only one product is in the image, you can leave this field empty.
3. Select the AI model to use.
4. Enter known attributes (optional).
5. Enter the attribute schema or use the "Add Attributes" section to add attributes.
6. Click "Extract Attributes" to get the extracted attributes.
"""
|
|
|
|
|
# Placeholder hint for the "Product Data" text area.
product_data_placeholder = """Example:
{
    "brand": "Leaf",
    "size": "M",
    "product_name": "Leaf T-shirt",
    "color": "red"
}
"""

# Initial content of the "Product Data" text area (stripped before use).
product_data_value = """
{
    "data1": "",
    "data2": ""
}
"""
|
|
|
|
|
# Page layout: a header row, then a row holding the input form (left) and the
# JSON result panel (right).
with gr.Blocks(title="Internal Demo for Attribution") as demo:
    with gr.Row():
        with gr.Column(scale=12):
            gr.Markdown(
                """<div style="text-align: center; font-size: 24px;"><strong>Internal Demo for Attribution</strong></div>"""
            )
            gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=12):
            with gr.Row():
                # Left half of the form: images, taxonomy, model, known data.
                with gr.Column():
                    gallery = gr.Gallery(
                        label="Upload images of your product here", type="pil"
                    )
                    # NOTE: the variable name keeps its original spelling
                    # ("taxnomy") because it is a module-level name.
                    product_taxnomy = gr.Textbox(
                        label="Product Taxonomy",
                        placeholder="Enter product taxonomy here (e.g. 'upper garment', 'lower garment', 'bag')",
                        lines=1,
                        max_lines=1,
                    )
                    ai_model = gr.Dropdown(
                        label="AI Model",
                        choices=settings.SUPPORTED_MODELS,
                        interactive=True,
                    )
                    product_data = gr.TextArea(
                        label="Product Data (Optional)",
                        placeholder=product_data_placeholder,
                        value=product_data_value.strip(),
                        interactive=True,
                        lines=10,
                        max_lines=10,
                    )

                # Right half of the form: the attribute schema source.
                with gr.Column():
                    attributes = gr.TextArea(
                        label="Attribute Schema",
                        value=sample_schema,
                        placeholder="Enter schema here or use Add Attributes below",
                        interactive=True,
                        lines=30,
                        max_lines=30,
                    )

            with gr.Row():
                submit_btn = gr.Button("Extract Attributes")

        with gr.Column(scale=6):
            output_json = gr.Json(
                label="Extracted Attributes", value={}, show_indices=False
            )

    # The input order must match forward_request's parameter order.
    submit_btn.click(
        forward_request,
        inputs=[attributes, product_taxnomy, product_data, ai_model, gallery],
        outputs=output_json,
    )
|
|
|
|
|
|
|
|
# Basic-auth credentials for the demo, taken from the environment.
# NOTE(review): the fallbacks ("1" / "a") are trivially guessable, and the form
# executes user-supplied Python; set ATTR_USER / ATTR_PASS in any shared
# deployment.
attr_user = os.getenv("ATTR_USER", "1")
attr_pass = os.getenv("ATTR_PASS", "a")
auth = (attr_user, attr_pass)

# NOTE(review): debug=True is kept from the original; confirm it is intended
# outside local development.
demo.launch(auth=auth, debug=True, ssr_mode=False)
|
|
|