UPI matching model

Binary classifier for product variant matching (cross-store UPI).

Data Sources

Train & validation

Pairs from BigQuery (preset consideration_50k_202602):

  • sdp-prd-ml-taxonomy.cross_shop_clustering.matching_datasets_20260206_consideration_100k
  • sdp-prd-ml-taxonomy.cross_shop_clustering.matching_datasets_20260212_consideration_50k
  • sdp-prd-ml-taxonomy.cross_shop_clustering.matching_datasets_20260215_consideration_50k
  • sdp-prd-ml-taxonomy.cross_shop_clustering.matching_datasets_20260217_consideration_50k
  • sdp-prd-ml-taxonomy.cross_shop_clustering.matching_datasets_20260219_consideration_50k

Test

Holdout pairs from BigQuery (preset consideration_50k_202602):

  • sdp-prd-ml-taxonomy.cross_shop_clustering.matching_datasets_20260214_consideration_50k

Metadata

Product title, vendor, category, URL, and optional fields from sdp-prd-ml-taxonomy.intermediate.product_summary. Title uses COALESCE(simplified_title, product_title).

Input text format

Each product is one text string. Fields are key-value pairs joined by |; only non-empty fields are included.

Standard fields (per product):

  • Title – product title (from title_1 / title_2; often COALESCE(simplified_title, product_title) in the data).
  • Vendor – vendor name (vendor_1 / vendor_2).
  • Category – predicted product category (predicted_category_1 / predicted_category_2), or the taxonomy category if the run used use_taxonomy_product_category.
  • URL – product URL (url_1 / url_2); included unless the run used exclude_url.
  • Shop – shop name (shop_name_1 / shop_name_2), when present in the dataset.

Example (this run: Title | Vendor | Category | URL | Shop):

Product 1: Title: Blue Cotton Shirt | Vendor: Acme | Category: Apparel > Tops | URL: https://... | Shop: My Store
Product 2: Title: Blue Cotton Shirt | Vendor: Acme | Category: Apparel > Tops | URL: https://... | Shop: Other Store
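
The per-product string above can be assembled with a small helper (a sketch; build_product_text and its signature are illustrative, not from the training code — only the "Key: value" pairs joined by " | " with empty fields dropped come from this card):

```python
def build_product_text(title=None, vendor=None, category=None, url=None, shop=None):
    """Join non-empty fields as 'Key: value' pairs separated by ' | '."""
    fields = [("Title", title), ("Vendor", vendor), ("Category", category),
              ("URL", url), ("Shop", shop)]
    # Only non-empty fields are included, matching the input text format above.
    return " | ".join(f"{key}: {value}" for key, value in fields if value)

text = build_product_text(title="Blue Cotton Shirt", vendor="Acme",
                          category="Apparel > Tops", shop="My Store")
# 'Title: Blue Cotton Shirt | Vendor: Acme | Category: Apparel > Tops | Shop: My Store'
```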

Tokenization: Cross-encoder input is [CLS] tokens_product_1 [SEP] tokens_product_2 [SEP], with the same tokenizer and max_length (e.g. 512) as training. Use the same field order and separators for inference.
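
The pair layout can be illustrated with a toy whitespace tokenizer (illustrative only — the real model uses the saved WordPiece tokenizer; the segment-id split shown is the standard BERT pair encoding and is an assumption about this model):

```python
def pair_input(tokens_1, tokens_2, max_length=512):
    """Build the [CLS] tokens_1 [SEP] tokens_2 [SEP] layout with segment ids:
    0 for product 1 (incl. [CLS] and first [SEP]), 1 for product 2 (incl. final [SEP])."""
    tokens = ["[CLS]"] + tokens_1 + ["[SEP]"] + tokens_2 + ["[SEP]"]
    token_type_ids = [0] * (len(tokens_1) + 2) + [1] * (len(tokens_2) + 1)
    return tokens[:max_length], token_type_ids[:max_length]

tokens, segments = pair_input("blue shirt".split(), "blue tee".split())
# tokens:   ['[CLS]', 'blue', 'shirt', '[SEP]', 'blue', 'tee', '[SEP]']
# segments: [0, 0, 0, 0, 1, 1, 1]
```

Passing two texts to the Hugging Face tokenizer, as in the Example section, produces this layout automatically.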

Model type

  • encoder_type: cross
  • num_labels: 2

Config

{
  "add_cross_attention": false,  // Cross-attention in encoder.
  "architectures": ["BERTCrossEncoderClassifier"],  // Model class name(s).
  "attention_probs_dropout_prob": 0.1,  // Attention dropout.
  "bert_projection_dim": null,  // Projection dim after BERT (if set).
  "bos_token_id": null,  // BOS token id.
  "catboost_dropout": 0.05,  // CatBoost MLP dropout.
  "catboost_hidden_layers": null,  // CatBoost MLP layer sizes.
  "catboost_hidden_size": 256,  // CatBoost MLP hidden size.
  "classifier_dropout": null,  // Classifier head dropout.
  "dtype": "float32",  // Model dtype.
  "encoder_type": "cross",  // Architecture: cross (BERT cross-encoder), embedding_input_only, two_stage_*, etc.
  "eos_token_id": null,  // EOS token id.
  "exclude_url_in_text": false,  // Omit URL from BERT input text.
  "gradient_checkpointing": false,  // Activate gradient checkpointing.
  "hidden_act": "gelu",  // Activation (e.g. gelu).
  "hidden_dropout_prob": 0.1,  // Hidden layer dropout.
  "hidden_size": 1024,  // BERT hidden size.
  "include_avg_price_in_text": false,  // Append average price to text.
  "initializer_range": 0.02,  // Stddev for weight init.
  "intermediate_size": 4096,  // BERT FFN intermediate dimension.
  "is_decoder": false,  // Decoder flag.
  "layer_norm_eps": 1e-12,  // LayerNorm epsilon.
  "max_position_embeddings": 512,  // Max sequence length.
  "model_type": "bert",  // Underlying transformer (e.g. bert).
  "num_attention_heads": 16,  // Number of attention heads.
  "num_catboost_features": 43,  // Length of CatBoost feature vector.
  "num_hidden_layers": 24,  // Number of BERT layers.
  "pad_token_id": 0,  // Tokenizer pad token id.
  "position_embedding_type": "absolute",  // Position embedding kind (e.g. absolute).
  "preprocess_url_in_text": false,  // Normalize URL before adding to text.
  "tie_word_embeddings": true,  // Tie input/output embeddings.
  "transformers_version": "5.2.0",  // Transformers version at save.
  "type_vocab_size": 2,  // Token type vocabulary size.
  "use_batch_norm": false,  // Use batch norm in head.
  "use_bert_layer_norm": false,  // Use LayerNorm in BERT projection.
  "use_cache": false,  // Whether to use KV cache.
  "use_catboost_features": false,  // Use CatBoost feature vector as extra input.
  "use_faiss_distance": false,  // Use FAISS distance as extra input.
  "use_standardized_description_in_text": false,  // Append standardized description (truncated).
  "use_standardized_title_in_text": false,  // Append standardized title when present.
  "use_taxonomy_product_category_in_text": false,  // Use taxonomy category instead of predicted.
  "use_variant_attributes_in_text": false,  // Append variant attributes (truncated).
  "vocab_size": 30522  // Tokenizer vocabulary size.
}
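
Note that the block above is annotated with // comments for readability; the on-disk config.json is plain JSON. A quick sanity check of the fields that drive the architecture (a sketch using a plain-JSON subset of the values above):

```python
import json

# Plain-JSON subset of the config above (the annotated block is not valid JSON).
config = json.loads("""{
  "encoder_type": "cross",
  "model_type": "bert",
  "hidden_size": 1024,
  "num_hidden_layers": 24,
  "num_attention_heads": 16,
  "max_position_embeddings": 512
}""")

assert config["encoder_type"] == "cross"  # BERT cross-encoder, not a bi-encoder
assert config["hidden_size"] % config["num_attention_heads"] == 0  # 1024 / 16 = 64-dim heads
```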

Precision threshold analysis

Recall at target precision levels (test set). For deployment, pick a row and use its Threshold Score.

Precision Target | Recall | Threshold Score
50.0%            | 84.4%  | 0.218
60.0%            | 79.1%  | 0.289
70.0%            | 72.1%  | 0.371
80.0%            | 62.2%  | 0.487
85.0%            | 54.7%  | 0.568
87.0%            | 51.6%  | 0.598
89.0%            | 47.7%  | 0.637
90.0%            | 46.1%  | 0.653
91.0%            | 43.4%  | 0.673
92.0%            | 41.1%  | 0.695
93.0%            | 35.5%  | 0.735
94.0%            | 31.5%  | 0.763
95.0%            | 28.9%  | 0.782
96.0%            | 24.6%  | 0.806
97.0%            | 21.7%  | 0.823
98.0%            |  8.2%  | 0.879
99.0%            |  6.4%  | 0.887
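
A deployment helper can select the operating point from the table above (a sketch; the (precision, recall, threshold) tuples are copied from selected rows of the table, and the helper name is illustrative):

```python
# (precision_target, recall, threshold) rows from the table above (subset).
PRECISION_TABLE = [
    (0.50, 0.844, 0.218), (0.60, 0.791, 0.289), (0.70, 0.721, 0.371),
    (0.80, 0.622, 0.487), (0.90, 0.461, 0.653), (0.95, 0.289, 0.782),
    (0.97, 0.217, 0.823), (0.99, 0.064, 0.887),
]

def threshold_for_precision(target):
    """Return the smallest tabulated threshold whose precision target meets `target`."""
    for precision, recall, threshold in PRECISION_TABLE:
        if precision >= target:
            return threshold
    raise ValueError(f"no tabulated operating point reaches precision {target:.0%}")

threshold_for_precision(0.97)  # 0.823
```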

Recall threshold analysis

Precision at target recall levels (test set).

Recall Target | Precision | Threshold
1.0%          | 99.1%     | 0.017
2.0%          | 98.2%     | 0.027
2.5%          | 97.7%     | 0.032
3.0%          | 97.2%     | 0.037
4.0%          | 96.3%     | 0.047
5.0%          | 95.4%     | 0.057
6.0%          | 94.5%     | 0.067
7.0%          | 93.5%     | 0.077
7.5%          | 93.1%     | 0.081
8.0%          | 92.6%     | 0.086
9.0%          | 91.7%     | 0.096
10.0%         | 90.8%     | 0.106
12.0%         | 88.9%     | 0.126
15.0%         | 86.2%     | 0.156
20.0%         | 81.5%     | 0.206
25.0%         | 76.9%     | 0.255
30.0%         | 72.3%     | 0.305
40.0%         | 63.1%     | 0.404
50.0%         | 53.8%     | 0.503
60.0%         | 44.6%     | 0.603
70.0%         | 35.4%     | 0.702
80.0%         | 26.1%     | 0.801
90.0%         | 16.9%     | 0.901
91.0%         | 16.0%     | 0.911
92.0%         | 15.1%     | 0.921
93.0%         | 14.1%     | 0.930
94.0%         | 13.2%     | 0.940
95.0%         | 12.3%     | 0.950
96.0%         | 11.4%     | 0.960
99.0%         |  8.6%     | 0.990
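
Both tables can be regenerated from raw model scores and labels. A minimal pure-Python sketch of precision and recall at one threshold (toy data shown; the real tables were computed on the held-out test pairs):

```python
def precision_recall_at(scores, labels, threshold):
    """Precision and recall when predicting 'match' for score >= threshold."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

scores = [0.9, 0.8, 0.6, 0.4, 0.2]
labels = [1, 1, 0, 1, 0]
precision_recall_at(scores, labels, 0.5)  # precision 2/3, recall 2/3
```

Sweeping `threshold` over a grid and recording (precision, recall) pairs yields tables in the same shape as the two above.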

Example

Load the model and run inference on one product pair. Format product text as in Data Sources (Title | Vendor | Category | URL | Shop).

The model directory must contain config.json (with model_type, e.g. bert, so AutoTokenizer.from_pretrained(model_path) works), weights as pytorch_model.bin or model.safetensors, and tokenizer files (e.g. tokenizer_config.json, tokenizer.json, vocab.txt).

Download from GCS: gsutil -m cp -r gs://bucket/path/to/model/* ./model_dir/ (use a real path from your run, e.g. {GCS_WORKSPACE}/models/{dataset_hash}/{run_hash}/bert_only). On macOS, if you see multiprocessing warnings, use: gsutil -o 'GSUtil:parallel_process_count=1' cp -r gs://bucket/path/to/model/* ./model_dir/.

from pathlib import Path
import torch
from transformers import AutoConfig, AutoTokenizer
from model_distillation.models.bert_classifier import BERTCrossEncoderClassifier

# Local path, or from Hugging Face Hub:
# from huggingface_hub import snapshot_download
# model_path = Path(snapshot_download(repo_id='org/repo-name', repo_type='model'))
model_path = Path("/path/to/model")  # dir: config.json, pytorch_model.bin or model.safetensors, tokenizer files

config = AutoConfig.from_pretrained(str(model_path))
model = BERTCrossEncoderClassifier(config)
if (model_path / "pytorch_model.bin").exists():
    state = torch.load(model_path / "pytorch_model.bin", map_location="cpu", weights_only=True)
else:
    from safetensors.torch import load_file
    state = load_file(model_path / "model.safetensors", device="cpu")
model.load_state_dict(state, strict=True)
tokenizer = AutoTokenizer.from_pretrained(str(model_path))
model.eval()

# One pair (same format as in Data Sources)
text1 = "Title: Blue Cotton Shirt | Vendor: Acme | Category: Apparel > Tops | URL: https://... | Shop: My Store"
text2 = "Title: Blue Cotton Shirt | Vendor: Acme | Category: Apparel > Tops | URL: https://... | Shop: Other Store"
inputs = tokenizer(text1, text2, max_length=512, truncation=True, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
prob_match = torch.softmax(outputs.logits, dim=-1)[0, 1].item()  # P(match)

# Apply a threshold from the Precision threshold analysis table (e.g. 0.823 for 97% precision)
is_match = prob_match >= 0.823
Model size: 0.3B params (F32, Safetensors).