Multi-Head 3W Extraction for Japanese Purchase Reviews

A multi-head BERT model that extracts WHO, WHERE, and WHEN from Japanese e-commerce reviews in a single forward pass.

Model Description

This model performs extractive question answering for 3W (WHO, WHERE, WHEN) extraction from Japanese purchase reviews. Unlike standard QA models that require one forward pass per question, this architecture uses a shared BERT encoder with 3 parallel QA heads, enabling ~3x faster inference.

Architecture

Input: context tokens (no question prefix)
      ↓
[BERT Encoder]  (cl-tohoku/bert-base-japanese-v3, fine-tuned)
      ↓
sequence_output  (batch, seq_len, 768)
      ↓
┌──────────┬──────────────┬──────────────┐
│ WHO head │ WHERE head   │ WHEN head    │
│ Linear→2 │ Linear→2     │ Linear→2     │
└──────────┴──────────────┴──────────────┘
      ↓            ↓              ↓
(start, end)  (start, end)  (start, end)

Each head is a Linear(hidden_size, 2) layer that independently predicts start/end logits for its W element. The total training loss is the average of the per-head cross-entropy losses.
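The loss described above can be sketched as follows. This is a minimal illustration, not the repository's actual training code: the `multihead_qa_loss` helper, the label layout, and the even weighting of start and end cross-entropies are assumptions.

```python
import torch
import torch.nn as nn

def multihead_qa_loss(head_logits, labels):
    """Average of per-head cross-entropy losses (sketch).

    head_logits: {elem: (start_logits, end_logits)}, each (batch, seq_len)
    labels:      {elem: (start_positions, end_positions)}, each (batch,)
    """
    ce = nn.CrossEntropyLoss()
    per_head = []
    for elem, (start_logits, end_logits) in head_logits.items():
        start_pos, end_pos = labels[elem]
        # Each head's loss: mean of its start-position and end-position CE
        head_loss = (ce(start_logits, start_pos) + ce(end_logits, end_pos)) / 2
        per_head.append(head_loss)
    # Total loss: mean over the WHO/WHERE/WHEN heads
    return torch.stack(per_head).mean()
```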

Key Features

  • Single forward pass for all 3 elements (WHO, WHERE, WHEN)
  • ~3x faster inference compared to the standard QA approach (3 separate passes)
  • Context-only input: no question prefix tokens needed
  • Multi-answer support: can extract multiple answers per element
  • Calibrated null thresholds: per-element thresholds for no-answer prediction

Evaluation Results

Test Set (1,333 samples) - Calibrated

| Element | F1     | Exact Match | No-Answer Accuracy | Threshold |
|---------|--------|-------------|--------------------|-----------|
| WHO     | 0.9555 | 0.9415      | 0.9824             | -1.0      |
| WHERE   | 0.9311 | 0.9092      | 0.9863             | 1.0       |
| WHEN    | 0.8385 | 0.7937      | 0.9364             | 2.0       |
| Mean    | 0.9084 | 0.8815      | -                  | -         |

Test Set - Before Calibration

| Element | F1     | Exact Match | No-Answer Accuracy |
|---------|--------|-------------|--------------------|
| WHO     | 0.9533 | 0.9392      | 0.9844             |
| WHERE   | 0.9326 | 0.9070      | 0.9812             |
| WHEN    | 0.8147 | 0.7652      | 0.8691             |
| Mean    | 0.9002 | 0.8705      | -                  |

Data Distribution

| Element | Has Answer | No Answer | Multi-Answer        |
|---------|------------|-----------|---------------------|
| WHO     | 20.3%      | 79.7%     | 12.1% of has-answer |
| WHERE   | 12.2%      | 87.8%     | 10.2% of has-answer |
| WHEN    | 37.2%      | 62.8%     | 16.9% of has-answer |

Training Details

| Parameter | Value |
|-----------|-------|
| Base model | cl-tohoku/bert-base-japanese-v3 |
| Total parameters | 111,211,782 |
| Head parameters | 4,614 (0.004%) |
| Max epochs | 10 (early stopped at epoch 6) |
| Best epoch | 3 (by val mean_f1) |
| Batch size | 64 |
| Learning rate | 3e-5 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Max sequence length | 128 |
| Optimizer | AdamW |
| FP16 | Yes |
| Multi-answer training | Yes (cartesian product expansion) |
| Training samples | 5,009 (after expansion from 4,500 rows) |
| Validation samples | 500 |
| Test samples | 1,333 |
| Final training loss | 1.1352 |
| Seed | 42 |
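The cartesian product expansion used for multi-answer training can be illustrated as below. This is a sketch under stated assumptions: the row schema and the `expand_multi_answer` helper are hypothetical, not the repository's actual preprocessing code.

```python
from itertools import product

W_ELEMENTS = ["who", "where", "when"]

def expand_multi_answer(row):
    """Expand one annotated row into one training sample per combination
    of per-element answers (cartesian product over WHO/WHERE/WHEN).

    row: {"context": str, "answers": {elem: [span, ...]}}
    Elements with no annotated answer contribute a single null span (None).
    """
    per_elem = [row["answers"].get(e) or [None] for e in W_ELEMENTS]
    samples = []
    for combo in product(*per_elem):
        samples.append({"context": row["context"],
                        **{e: span for e, span in zip(W_ELEMENTS, combo)}})
    return samples
```

For example, a row with two WHO answers and one WHEN answer expands into two training samples, which is how 4,500 annotated rows can yield 5,009 samples.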

Training Curve

| Epoch | Val Loss | WHO F1 | WHERE F1 | WHEN F1 | Mean F1 |
|-------|----------|--------|----------|---------|---------|
| 1     | 0.8267   | 0.9402 | 0.9031   | 0.7510  | 0.8648  |
| 2     | 0.5998   | 0.9449 | 0.9183   | 0.8446  | 0.9026  |
| 3     | 0.5931   | 0.9559 | 0.9276   | 0.8409  | 0.9082  |
| 4     | 0.6311   | 0.9552 | 0.9099   | 0.8497  | 0.9049  |
| 5     | 0.6631   | 0.9459 | 0.8994   | 0.8448  | 0.8967  |
| 6     | 0.7264   | 0.9537 | 0.9002   | 0.8507  | 0.9015  |

Usage

Loading the Model

This model uses a custom BertForMultiHeadQA class. You need to define it before loading:

import json
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional
from transformers import AutoTokenizer, AutoModel, AutoConfig

W_ELEMENTS = ["who", "where", "when"]

class BertForMultiHeadQA(nn.Module):
    """Multi-head BERT model for 3W extraction."""

    def __init__(self, model_name_or_path: str, elements: Optional[List[str]] = None, dropout: float = 0.1):
        super().__init__()
        self.elements = elements or W_ELEMENTS
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        hidden_size = self.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.qa_heads = nn.ModuleDict({
            elem: nn.Linear(hidden_size, 2) for elem in self.elements
        })

    def forward(self, input_ids, attention_mask, token_type_ids=None, **kwargs):
        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None and self.config.type_vocab_size > 1:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        sequence_output = self.dropout(outputs.last_hidden_state)

        head_logits = {}
        for elem, head in self.qa_heads.items():
            logits = head(sequence_output)
            head_logits[elem] = (logits[:, :, 0], logits[:, :, 1])
        return {"head_logits": head_logits}

    @classmethod
    def from_pretrained(cls, load_directory: str, **kwargs):
        load_directory = Path(load_directory)
        with open(load_directory / "multihead_config.json") as f:
            meta = json.load(f)
        elements = meta.get("elements", W_ELEMENTS)

        config = AutoConfig.from_pretrained(load_directory)
        model = cls.__new__(cls)
        nn.Module.__init__(model)
        model.elements = elements
        model.config = config
        model.encoder = AutoModel.from_config(config)
        hidden_size = config.hidden_size
        model.dropout = nn.Dropout(kwargs.get("dropout", 0.1))
        model.qa_heads = nn.ModuleDict({
            elem: nn.Linear(hidden_size, 2) for elem in elements
        })

        weights_path = load_directory / "multihead_model.bin"
        if not weights_path.exists():
            weights_path = load_directory / "pytorch_model.bin"
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        return model

Inference

# Load model
model_path = "HALDATA/bert-base-japanese-3w-multihead-qa"  # or local path
model = BertForMultiHeadQA.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Calibrated null thresholds (recommended)
NULL_THRESHOLDS = {"who": -1.0, "where": 1.0, "when": 2.0}

def extract_3w(text: str) -> Dict[str, str]:
    inputs = tokenizer(text, max_length=128, truncation=True, padding=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    results = {}
    for elem in ["who", "where", "when"]:
        s_logits = outputs["head_logits"][elem][0][0].cpu().numpy()
        e_logits = outputs["head_logits"][elem][1][0].cpu().numpy()

        # Find best span
        null_score = float(s_logits[0] + e_logits[0])
        n_best = 20
        start_indices = np.argsort(s_logits)[-n_best:][::-1]
        end_indices = np.argsort(e_logits)[-n_best:][::-1]

        best_score, best_start, best_end = -1e9, 0, 0
        for si in start_indices:
            if si < 1: continue  # skip [CLS]
            for ei in end_indices:
                if ei < si or ei - si + 1 > 30: continue
                score = float(s_logits[si] + e_logits[ei])
                if score > best_score:
                    best_score, best_start, best_end = score, si, ei

        threshold = NULL_THRESHOLDS.get(elem, 0.0)
        if best_score > -1e9 and (best_score - null_score) > threshold:
            answer = tokenizer.decode(
                inputs["input_ids"][0][best_start:best_end+1],
                skip_special_tokens=True
            ).replace(" ", "").replace("##", "")
        else:
            answer = ""
        results[elem] = answer
    return results


# Example ("I bought this as a birthday present for my son.")
text = "息子の誕生日プレゼントとして購入しました。"
result = extract_3w(text)
print(result)
# {'who': '息子', 'where': '', 'when': '誕生日'}

Batch Inference

texts = [
    "母と娘のために購入しました。",                # "Bought for my mother and daughter."
    "リビングと寝室で使用するために購入しました。",  # "Bought for use in the living room and bedroom."
    "子供と主人が学校と会社で使うために買いました。",  # "Bought for my child and husband to use at school and work."
]

for text in texts:
    result = extract_3w(text)
    print(f"Text: {text}")
    for k, v in result.items():
        print(f"  {k.upper()}: {v or '(no answer)'}")
    print()

Calibrated Null Thresholds

The model uses per-element null score diff thresholds to decide when to return "no answer":

NULL_THRESHOLDS = {
    "who": -1.0,    # More permissive (WHO is high-precision)
    "where": 1.0,   # Moderate
    "when": 2.0,    # More strict (WHEN has more false positives)
}

These thresholds were calibrated on the validation set by grid search over [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0].
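The grid search described above can be sketched per element as follows. This is a minimal sketch: the `calibrate_null_threshold` helper is hypothetical, and the selection metric (accuracy of the answer/no-answer decision) is an assumption; the actual calibration may have optimised span F1 instead.

```python
import numpy as np

GRID = (-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0)

def calibrate_null_threshold(score_diffs, has_answer, grid=GRID):
    """Pick the grid value that best separates answerable from unanswerable
    validation examples for one element.

    score_diffs[i] = best_span_score - null_score for example i;
    the model predicts "has answer" when score_diff > threshold.
    """
    diffs = np.asarray(score_diffs, dtype=float)
    gold = np.asarray(has_answer, dtype=bool)
    best_t, best_acc = grid[0], -1.0
    for t in grid:
        pred = diffs > t  # predicted "has answer" at this threshold
        acc = float((pred == gold).mean())
        if acc > best_acc:
            best_t, best_acc = t, acc
    return best_t
```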

Files

| File | Description |
|------|-------------|
| multihead_model.bin | Full model weights (encoder + 3 QA heads) |
| config.json | BERT encoder configuration |
| multihead_config.json | Multi-head metadata (elements, model_type) |
| tokenizer_config.json | Tokenizer configuration |
| vocab.txt | WordPiece vocabulary |
| multihead_null_thresholds.json | Calibrated thresholds with full grid search results |
| multihead_test_metrics.json | Test metrics (before calibration) |
| multihead_test_metrics_calibrated.json | Test metrics (after calibration) |

Limitations

  • Designed specifically for Japanese e-commerce purchase reviews
  • Extracts only WHO, WHERE, WHEN (not WHAT or WHY)
  • Max input length: 128 tokens (longer reviews are truncated)
  • Character-level F1 metric; may not capture semantic equivalence
  • Trained on ~5,000 labeled examples; performance may vary on out-of-domain text

Citation

If you use this model, please cite:

@misc{haldata-3w-multihead-qa-2026,
  title={Multi-Head 3W Extraction for Japanese Purchase Reviews},
  author={Haldata},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/HALDATA/bert-base-japanese-3w-multihead-qa}
}