File size: 4,068 Bytes
5a68eda
3d6f7e4
5e9ccec
2d3d064
5a68eda
 
e9f9002
5a68eda
5e9ccec
79a0824
e9f9002
b671a67
79a0824
e9f9002
 
 
b671a67
cd0b2d8
acf539b
e9f9002
 
 
 
79a0824
5a68eda
2d3d064
5a68eda
 
 
 
b671a67
e9f9002
cd0b2d8
5a68eda
e9f9002
cd0b2d8
 
 
e9f9002
 
b671a67
e9f9002
5e9ccec
cd0b2d8
 
e9f9002
 
 
2d3d064
79a0824
e9f9002
 
 
 
79a0824
e9f9002
b671a67
e9f9002
 
 
 
 
 
cd0b2d8
acf539b
e9f9002
 
 
 
 
 
 
 
 
 
b671a67
acf539b
e9f9002
6f71a2a
36d4766
 
e9f9002
ab8319c
e9f9002
 
6f71a2a
5a68eda
e9f9002
 
6f71a2a
e9f9002
 
 
cd0b2d8
e9f9002
cd0b2d8
e9f9002
acf539b
5a68eda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import io
import json

import gradio as gr
import torch
import open_clip
import numpy as np
import pandas as pd
import requests
from PIL import Image
from huggingface_hub import hf_hub_download

# ============================= CARREGA TUDO =============================
REPO = "eulerLabs/amazon-br-taxonomy-jfactor"

print("Carregando arquivos do dataset...")
# Fetch (or reuse from the local HF cache) the four dataset artifacts.
tax_path   = hf_hub_download(REPO, "amazon-br-taxonomy.json", repo_type="dataset")
flat_path  = hf_hub_download(REPO, "categories-flat.json", repo_type="dataset")
cache_path = hf_hub_download(REPO, "marqo-ecommerce-B-cache-crad-K32-aligned.npy", repo_type="dataset")
proj_path  = hf_hub_download(REPO, "projection_matrix_B_to_32.npy", repo_type="dataset")

# JSON is UTF-8 by specification; read it as such explicitly so accented
# category names decode the same way regardless of the platform locale.
with open(tax_path, encoding="utf-8") as f:
    TAXONOMY = json.load(f)  # nested dict walked by the beam search in classify()
with open(flat_path, encoding="utf-8") as f:
    CATS = json.load(f)      # category names; CATS[i] labels row i of the cache

CACHE_32D = np.load(cache_path).astype(np.float32)
# Row i of the cache is paired with CATS[i]. A length mismatch means the two
# files are out of sync, so fail loudly instead of silently truncating or
# raising an opaque IndexError.
if len(CACHE_32D) != len(CATS):
    raise ValueError(
        f"Embedding cache has {len(CACHE_32D)} rows but "
        f"categories-flat.json lists {len(CATS)} categories"
    )
CACHE_DICT = dict(zip(CATS, CACHE_32D))  # category name -> cached (projected) embedding row
PROJ = torch.from_numpy(np.load(proj_path)).float()  # projects image embeddings into the cache space

model, _, preprocess = open_clip.create_model_and_transforms("hf-hub:Marqo/marqo-ecommerce-embeddings-B")
model.eval()  # inference only

# ============================= CLASSIFICAÇÃO =============================
def _load_image_from_url(url):
    """Download ``url`` and decode it as an RGB PIL image.

    Raises ``requests.RequestException`` on network/HTTP errors and ``OSError``
    (which includes ``PIL.UnidentifiedImageError``) when the body is not a
    decodable image.
    """
    response = requests.get(url, timeout=10)
    # Without this, a 404/500 HTML error page would be handed to PIL and the
    # real cause hidden behind an "unidentified image" error.
    response.raise_for_status()
    # Decode from the fully read body rather than ``response.raw``: ``raw``
    # bypasses Content-Encoding decoding and was never closed here.
    return Image.open(io.BytesIO(response.content)).convert("RGB")


def _mean_score(entry):
    """Average per-level similarity of a beam entry (0 for the empty root path)."""
    _, path, total, _ = entry
    return total / len(path) if path else 0.0


def _expand(node, path, total, emb32):
    """Return scored child entries of ``node``; empty when it cannot be expanded.

    Only children whose name has a cached embedding in ``CACHE_DICT`` are
    scored. Non-dict nodes have no children.
    """
    children = []
    if isinstance(node, dict):
        for cat, child in node.items():
            ref = CACHE_DICT.get(cat)
            if ref is None:
                continue
            # as_tensor shares memory with the numpy row instead of copying it.
            sim_raw = torch.cosine_similarity(emb32, torch.as_tensor(ref), dim=0).item()
            # Map cosine similarity from [-1, 1] to [0, 1].
            sim = round((sim_raw + 1) / 2, 4)
            children.append((child, path + [(cat, sim)], total + sim, False))
    return children


def _beam_search(emb32, beam_width):
    """Beam-search TAXONOMY and return the best ``[(category, score), ...]`` path.

    Each beam entry is ``(node, path, total_score, finished)``. An entry is
    finished once its node has no scorable children. Finished entries keep
    competing for beam slots. The search stops when every entry in the beam is
    finished, which always happens because each step goes one level deeper in
    a finite tree. Entries are ranked by mean per-level score. Returns ``[]``
    when not even the root could be expanded.
    """
    beam = [(TAXONOMY, [], 0.0, False)]
    while not all(finished for *_, finished in beam):
        candidates = []
        for node, path, total, finished in beam:
            children = [] if finished else _expand(node, path, total, emb32)
            if children:
                candidates.extend(children)
            else:
                # Leaf (or nothing scorable below): keep the path as complete.
                candidates.append((node, path, total, True))
        # Higher mean score is better; sort is stable, so ties keep taxonomy order.
        candidates.sort(key=_mean_score, reverse=True)
        beam = candidates[:beam_width]
    return beam[0][1]


@torch.no_grad()
def classify(image=None, url="", beam_width=5):
    """Classify a product image into the taxonomy and return the best path.

    Args:
        image: PIL image from the upload widget, or None.
        url: image URL used when ``image`` is None. Surrounding whitespace is
            ignored and None is treated as empty.
        beam_width: number of partial paths kept per level. It is coerced to an
            int (the UI slider may deliver a float) and is at least 1.

    Returns:
        A DataFrame with columns ``Nível``/``Categoria``/``Score``, one row per
        taxonomy level. On failure it instead has a single ``Erro`` or
        ``Resultado`` column holding a user-facing message.
    """
    url = (url or "").strip()
    if image is None and url:
        try:
            image = _load_image_from_url(url)
        except (requests.RequestException, OSError, Image.DecompressionBombError):
            return pd.DataFrame({"Erro": ["Não foi possível carregar a imagem da URL"]})

    if image is None:
        return pd.DataFrame({"Erro": ["Faça upload ou cole uma URL"]})

    emb768 = model.encode_image(preprocess(image).unsqueeze(0))
    # Project into the space of the cached category embeddings, then L2-normalize.
    emb32 = (emb768 @ PROJ).squeeze(0)
    emb32 = emb32 / emb32.norm(dim=-1, keepdim=True)

    best_path = _beam_search(emb32, max(1, int(beam_width)))
    if not best_path:
        return pd.DataFrame({"Resultado": ["Nenhum caminho encontrado"]})

    rows = [
        [level, cat, f"{score:.4f}"]
        for level, (cat, score) in enumerate(best_path, start=1)
    ]
    return pd.DataFrame(rows, columns=["Nível", "Categoria", "Score"])

# ============================= INTERFACE =============================
with gr.Blocks(theme=gr.themes.Soft(), title="Jobim Visual Guru") as demo:
    # Page heading now matches the window title (it was truncated to "obim Visual Gu").
    gr.Markdown("# Jobim Visual Guru")
    # NOTE(review): this text looks truncated/garbled; intended content unknown.
    gr.Markdown("**!**")

    with gr.Row():
        # Left column: image inputs (upload takes precedence over URL in classify()).
        with gr.Column():
            img_input = gr.Image(type="pil", label="Foto do produto", height=480)
            url_input = gr.Textbox(label="ou cole a URL da imagem", placeholder="https://...")
            gr.Examples([
                "https://i.pinimg.com/736x/82/e7/47/82e747617f37fd4600634d7a27a1e561.jpg",
                "https://images.unsplash.com/photo-1542291026-7eec264c27ff?w=800",
            ], inputs=img_input)

        # Right column: search settings and the resulting category path.
        with gr.Column():
            beam_slider = gr.Slider(1, 15, value=5, step=1, label="Precisão (Beam Width)")
            btn = gr.Button("Classificar", variant="primary", size="lg")
            output = gr.Dataframe(headers=["Nível", "Categoria", "Score"], row_count=12, height=560)

    btn.click(classify, inputs=[img_input, url_input, beam_slider], outputs=output)

# Launch only when run as a script, so importing the module (e.g. for tests or
# `gradio` reload mode, which picks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()