import gradio as gr
import torch
import open_clip
import numpy as np
import pandas as pd
import requests
import json
from PIL import Image
from huggingface_hub import hf_hub_download
# ============================= LOAD EVERYTHING =============================
# Download all inference artifacts from the Hub dataset repo at import time:
# the hierarchical taxonomy, the flat category list, the precomputed 32-d
# category-embedding cache, and the 768-d -> 32-d projection matrix.
REPO = "eulerLabs/amazon-br-taxonomy-jfactor"
print("Carregando arquivos do dataset...")
tax_path = hf_hub_download(REPO, "amazon-br-taxonomy.json", repo_type="dataset")
flat_path = hf_hub_download(REPO, "categories-flat.json", repo_type="dataset")
cache_path = hf_hub_download(REPO, "marqo-ecommerce-B-cache-crad-K32-aligned.npy", repo_type="dataset")
proj_path = hf_hub_download(REPO, "projection_matrix_B_to_32.npy", repo_type="dataset")
# Nested dict of category name -> sub-tree (walked by classify's beam search).
with open(tax_path) as f:
    TAXONOMY = json.load(f)
# Flat list of category names; assumed index-aligned with the cache rows
# (the CACHE_DICT comprehension below relies on this) — TODO confirm upstream.
with open(flat_path) as f:
    CATS = json.load(f)
CACHE_32D = np.load(cache_path).astype(np.float32)
# Category name -> its cached 32-d embedding row.
CACHE_DICT = {cat: CACHE_32D[i] for i, cat in enumerate(CATS)}
# Projection matrix applied to the model's image embedding (768-d -> 32-d).
PROJ = torch.from_numpy(np.load(proj_path)).float()
# Marqo e-commerce OpenCLIP model; eval mode since this app only does inference.
model, _, preprocess = open_clip.create_model_and_transforms("hf-hub:Marqo/marqo-ecommerce-embeddings-B")
model.eval()
# ============================= CLASSIFICAÇÃO =============================
@torch.no_grad()
def classify(image=None, url="", beam_width=5):
    """Classify a product image into the Amazon-BR taxonomy via beam search.

    Args:
        image: optional PIL image uploaded by the user.
        url: optional image URL; used only when no image was uploaded.
        beam_width: number of candidate paths kept at each taxonomy level.

    Returns:
        pandas.DataFrame with one row per taxonomy level of the best path,
        or a single-column error/result frame on failure.
    """
    if image is None and url.strip():
        try:
            resp = requests.get(url, stream=True, timeout=10)
            resp.raise_for_status()  # surface HTTP errors instead of decoding an error page
            image = Image.open(resp.raw).convert("RGB")
        except Exception:  # narrow enough for a UI boundary; never bare `except:`
            return pd.DataFrame({"Erro": ["Não foi possível carregar a imagem da URL"]})
    if image is None:
        return pd.DataFrame({"Erro": ["Faça upload ou cole uma URL"]})

    # 768-d image embedding projected into the 32-d space of the cached
    # category embeddings, then L2-normalised.
    emb768 = model.encode_image(preprocess(image).unsqueeze(0))
    emb32 = (emb768 @ PROJ).squeeze(0)
    emb32 = emb32 / emb32.norm(dim=-1, keepdim=True)

    # Beam entries: (taxonomy node, [(category, score), ...], summed score).
    beam = [(TAXONOMY, [], 0.0)]
    # BUGFIX: the original appended finished (non-dict) leaves as 2-tuples into
    # `candidates`, which crashed the 3-element sort/unpack on the next pass and
    # dropped any path that finished while siblings kept expanding. Completed
    # paths are now collected separately.
    finished = []  # list of (path, summed score) for fully-walked paths
    while beam:
        candidates = []
        for node, path, total in beam:
            if not isinstance(node, dict):
                if path:
                    finished.append((path, total))
                continue
            expanded = False
            for cat, child in node.items():
                if cat not in CACHE_DICT:
                    continue
                sim_raw = torch.cosine_similarity(
                    emb32, torch.tensor(CACHE_DICT[cat]), dim=0
                ).item()
                sim = round((sim_raw + 1) / 2, 4)  # map cosine [-1, 1] -> [0, 1]
                candidates.append((child, path + [(cat, sim)], total + sim))
                expanded = True
            # A dict with no cached children is also terminal: keep its path.
            if not expanded and path:
                finished.append((path, total))
        if not candidates:
            break
        # Keep the beam_width paths with the best mean per-level score.
        candidates.sort(key=lambda c: c[2] / len(c[1]), reverse=True)
        beam = candidates[: int(beam_width)]

    if not finished:
        return pd.DataFrame({"Resultado": ["Nenhum caminho encontrado"]})
    best_path = max(finished, key=lambda c: c[1] / len(c[0]))[0]
    rows = [[i + 1, cat, f"{score:.4f}"] for i, (cat, score) in enumerate(best_path)]
    return pd.DataFrame(rows, columns=["Nível", "Categoria", "Score"])
# ============================= INTERFACE =============================
# Gradio UI: image upload or URL on the left, beam-width control and the
# resulting taxonomy-path table on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Jobim Visual Guru") as demo:
    # BUGFIX: the header markdown was truncated ("# obim Visual Gu"); restore
    # it to match the declared page title above.
    gr.Markdown("# Jobim Visual Guru")
    gr.Markdown("**!**")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Foto do produto", height=480)
            url_input = gr.Textbox(label="ou cole a URL da imagem", placeholder="https://...")
            gr.Examples([
                "https://i.pinimg.com/736x/82/e7/47/82e747617f37fd4600634d7a27a1e561.jpg",
                "https://images.unsplash.com/photo-1542291026-7eec264c27ff?w=800",
            ], inputs=img_input)
        with gr.Column():
            beam_slider = gr.Slider(1, 15, value=5, step=1, label="Precisão (Beam Width)")
            btn = gr.Button("Classificar", variant="primary", size="lg")
            output = gr.Dataframe(headers=["Nível", "Categoria", "Score"], row_count=12, height=560)
    btn.click(classify, inputs=[img_input, url_input, beam_slider], outputs=output)

demo.launch()