# Multi-Head 3W Extraction for Japanese Purchase Reviews
A multi-head BERT model that extracts WHO, WHERE, and WHEN from Japanese e-commerce reviews in a single forward pass.
## Model Description
This model performs extractive question answering for 3W (WHO, WHERE, WHEN) extraction from Japanese purchase reviews. Unlike standard QA models that require one forward pass per question, this architecture uses a shared BERT encoder with 3 parallel QA heads, enabling ~3x faster inference.
## Architecture

```
Input: context tokens (no question prefix)
                 ↓
[BERT Encoder] (cl-tohoku/bert-base-japanese-v3, fine-tuned)
                 ↓
sequence_output (batch, seq_len, 768)
                 ↓
 ┌─────────────┬─────────────┬─────────────┐
 │  WHO head   │  WHERE head │  WHEN head  │
 │  Linear→2   │  Linear→2   │  Linear→2   │
 └─────────────┴─────────────┴─────────────┘
       ↓             ↓             ↓
 (start, end)  (start, end)  (start, end)
```
Each head is a `Linear(hidden_size, 2)` that independently predicts start/end logits for its respective W element. The loss is the average of the per-head cross-entropy losses.
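For reference, the loss described above can be sketched in a few lines of NumPy. The function names here are illustrative only (the actual training code is not part of this card), and the per-head loss is assumed to average the start- and end-position cross-entropies, as in standard BERT QA:

```python
import numpy as np

def cross_entropy(logits, target_index):
    """Cross-entropy of one gold token position against a logit vector."""
    logits = logits - logits.max()  # shift for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())
    return -log_probs[target_index]

def multihead_qa_loss(head_logits, targets):
    """Average the per-head losses; each head's loss is the mean of its
    start- and end-position cross-entropies (hypothetical helper)."""
    per_head = []
    for elem, (start_logits, end_logits) in head_logits.items():
        start_pos, end_pos = targets[elem]
        per_head.append(0.5 * (cross_entropy(start_logits, start_pos)
                               + cross_entropy(end_logits, end_pos)))
    return float(np.mean(per_head))
```

Because every head sees the same `sequence_output`, a single backward pass updates the shared encoder with gradients from all three elements at once.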
## Key Features

- Single forward pass for all 3 elements (WHO, WHERE, WHEN)
- ~3x faster inference than the standard QA approach (3 separate passes)
- Context-only input: no question prefix tokens needed
- Multi-answer support: can extract multiple answers per element
- Calibrated null thresholds: per-element thresholds for no-answer prediction
## Evaluation Results

### Test Set (1,333 samples) – Calibrated
| Element | F1 | Exact Match | No-Answer Accuracy | Threshold |
|---|---|---|---|---|
| WHO | 0.9555 | 0.9415 | 0.9824 | -1.0 |
| WHERE | 0.9311 | 0.9092 | 0.9863 | 1.0 |
| WHEN | 0.8385 | 0.7937 | 0.9364 | 2.0 |
| Mean | 0.9084 | 0.8815 | – | – |
### Test Set – Before Calibration
| Element | F1 | Exact Match | No-Answer Accuracy |
|---|---|---|---|
| WHO | 0.9533 | 0.9392 | 0.9844 |
| WHERE | 0.9326 | 0.9070 | 0.9812 |
| WHEN | 0.8147 | 0.7652 | 0.8691 |
| Mean | 0.9002 | 0.8705 | – |
## Data Distribution
| Element | Has Answer | No Answer | Multi-Answer |
|---|---|---|---|
| WHO | 20.3% | 79.7% | 12.1% of has-answer |
| WHERE | 12.2% | 87.8% | 10.2% of has-answer |
| WHEN | 37.2% | 62.8% | 16.9% of has-answer |
## Training Details
| Parameter | Value |
|---|---|
| Base model | cl-tohoku/bert-base-japanese-v3 |
| Total parameters | 111,211,782 |
| Head parameters | 4,614 (0.004%) |
| Max epochs | 10 (early stopped at epoch 6) |
| Best epoch | 3 (by val mean_f1) |
| Batch size | 64 |
| Learning rate | 3e-5 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Max sequence length | 128 |
| Optimizer | AdamW |
| FP16 | Yes |
| Multi-answer training | Yes (cartesian product expansion) |
| Training samples | 5,009 (after expansion from 4,500 rows) |
| Validation samples | 500 |
| Test samples | 1,333 |
| Final training loss | 1.1352 |
| Seed | 42 |
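The "cartesian product expansion" used for multi-answer training (5,009 samples from 4,500 rows) can be sketched as follows. The helper name `expand_multi_answer` and the row format are hypothetical; the card does not include the preprocessing code:

```python
from itertools import product

def expand_multi_answer(row):
    """Expand one labeled review into training samples via the cartesian
    product of per-element answer candidates. An empty candidate list is
    treated as a single no-answer label (None)."""
    cands = {elem: (row.get(elem) or [None]) for elem in ("who", "where", "when")}
    return [
        {"text": row["text"], "who": w, "where": wh, "when": wn}
        for w, wh, wn in product(cands["who"], cands["where"], cands["when"])
    ]

# A row with two WHO spans and no WHERE/WHEN answers yields 2 samples:
row = {"text": "母と娘のために購入しました。", "who": ["母", "娘"], "where": [], "when": []}
samples = expand_multi_answer(row)
```

Under this scheme a row only multiplies when more than one element has multiple spans, which is consistent with the modest growth from 4,500 rows to 5,009 samples.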
### Training Curve
| Epoch | Val Loss | WHO F1 | WHERE F1 | WHEN F1 | Mean F1 |
|---|---|---|---|---|---|
| 1 | 0.8267 | 0.9402 | 0.9031 | 0.7510 | 0.8648 |
| 2 | 0.5998 | 0.9449 | 0.9183 | 0.8446 | 0.9026 |
| 3 | 0.5931 | 0.9559 | 0.9276 | 0.8409 | 0.9082 |
| 4 | 0.6311 | 0.9552 | 0.9099 | 0.8497 | 0.9049 |
| 5 | 0.6631 | 0.9459 | 0.8994 | 0.8448 | 0.8967 |
| 6 | 0.7264 | 0.9537 | 0.9002 | 0.8507 | 0.9015 |
## Usage

### Loading the Model

This model uses a custom `BertForMultiHeadQA` class. You need to define it before loading:
```python
import json
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from typing import Dict, List, Any, Optional
from transformers import AutoTokenizer, AutoModel, AutoConfig

W_ELEMENTS = ["who", "where", "when"]


class BertForMultiHeadQA(nn.Module):
    """Multi-head BERT model for 3W extraction."""

    def __init__(self, model_name_or_path: str, elements: Optional[List[str]] = None, dropout: float = 0.1):
        super().__init__()
        self.elements = elements or W_ELEMENTS
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        hidden_size = self.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.qa_heads = nn.ModuleDict({
            elem: nn.Linear(hidden_size, 2) for elem in self.elements
        })

    def forward(self, input_ids, attention_mask, token_type_ids=None, **kwargs):
        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None and self.config.type_vocab_size > 1:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        sequence_output = self.dropout(outputs.last_hidden_state)
        head_logits = {}
        for elem, head in self.qa_heads.items():
            logits = head(sequence_output)
            head_logits[elem] = (logits[:, :, 0], logits[:, :, 1])
        return {"head_logits": head_logits}

    @classmethod
    def from_pretrained(cls, load_directory: str, **kwargs):
        load_directory = Path(load_directory)
        with open(load_directory / "multihead_config.json") as f:
            meta = json.load(f)
        elements = meta.get("elements", W_ELEMENTS)
        config = AutoConfig.from_pretrained(load_directory)
        # Build the module skeleton directly (bypassing __init__) so the
        # encoder is created from config without downloading base weights
        model = cls.__new__(cls)
        nn.Module.__init__(model)
        model.elements = elements
        model.config = config
        model.encoder = AutoModel.from_config(config)
        hidden_size = config.hidden_size
        model.dropout = nn.Dropout(kwargs.get("dropout", 0.1))
        model.qa_heads = nn.ModuleDict({
            elem: nn.Linear(hidden_size, 2) for elem in elements
        })
        weights_path = load_directory / "multihead_model.bin"
        if not weights_path.exists():
            weights_path = load_directory / "pytorch_model.bin"
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        return model
```
### Inference
```python
# Load model
model_path = "HALDATA/bert-base-japanese-3w-multihead-qa"  # or local path
model = BertForMultiHeadQA.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# Calibrated null thresholds (recommended)
NULL_THRESHOLDS = {"who": -1.0, "where": 1.0, "when": 2.0}


def extract_3w(text: str) -> Dict[str, str]:
    inputs = tokenizer(text, max_length=128, truncation=True, padding=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    results = {}
    for elem in ["who", "where", "when"]:
        s_logits = outputs["head_logits"][elem][0][0].cpu().numpy()
        e_logits = outputs["head_logits"][elem][1][0].cpu().numpy()
        # Find best span
        null_score = float(s_logits[0] + e_logits[0])
        n_best = 20
        start_indices = np.argsort(s_logits)[-n_best:][::-1]
        end_indices = np.argsort(e_logits)[-n_best:][::-1]
        best_score, best_start, best_end = -1e9, 0, 0
        for si in start_indices:
            if si < 1:
                continue  # skip [CLS]
            for ei in end_indices:
                if ei < si or ei - si + 1 > 30:
                    continue
                score = float(s_logits[si] + e_logits[ei])
                if score > best_score:
                    best_score, best_start, best_end = score, si, ei
        threshold = NULL_THRESHOLDS.get(elem, 0.0)
        if best_score > -1e9 and (best_score - null_score) > threshold:
            answer = tokenizer.decode(
                inputs["input_ids"][0][best_start:best_end + 1],
                skip_special_tokens=True
            ).replace(" ", "").replace("##", "")
        else:
            answer = ""
        results[elem] = answer
    return results


# Example
text = "息子の誕生日プレゼントとして購入しました。"
result = extract_3w(text)
print(result)
# {'who': '息子', 'where': '', 'when': '誕生日'}
```
### Batch Inference

```python
texts = [
    "母と娘のために購入しました。",
    "リビングと寝室で使用するために購入しました。",
    "子供と主人が学校と会社で使うために買いました。",
]
for text in texts:
    result = extract_3w(text)
    print(f"Text: {text}")
    for k, v in result.items():
        print(f"  {k.upper()}: {v or '(no answer)'}")
    print()
```
## Calibrated Null Thresholds

The model uses per-element thresholds on the null score difference to decide when to return "no answer":

```python
NULL_THRESHOLDS = {
    "who": -1.0,   # More permissive (WHO is high-precision)
    "where": 1.0,  # Moderate
    "when": 2.0,   # More strict (WHEN has more false positives)
}
```

These thresholds were calibrated on the validation set by grid search over [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0].
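A sketch of that calibration loop. The function name is assumed, and since the card does not state the selection criterion, this version maximizes answer/no-answer classification accuracy on validation samples:

```python
import numpy as np

GRID = [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]

def calibrate_null_threshold(score_diffs, has_answer, grid=GRID):
    """Pick the threshold on (best_span_score - null_score) that best
    separates has-answer from no-answer validation samples; the model
    predicts 'answer' when the diff exceeds the threshold."""
    diffs = np.asarray(score_diffs, dtype=float)
    gold = np.asarray(has_answer, dtype=bool)
    best_t, best_acc = grid[0], -1.0
    for t in grid:
        acc = float(((diffs > t) == gold).mean())
        if acc > best_acc:
            best_t, best_acc = t, acc
    return best_t, best_acc
```

Each element is calibrated independently, which is why WHO ends up permissive (-1.0) while WHEN, with more false positives, ends up strict (2.0).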
## Files

| File | Description |
|---|---|
| `multihead_model.bin` | Full model weights (encoder + 3 QA heads) |
| `config.json` | BERT encoder configuration |
| `multihead_config.json` | Multi-head metadata (elements, model_type) |
| `tokenizer_config.json` | Tokenizer configuration |
| `vocab.txt` | WordPiece vocabulary |
| `multihead_null_thresholds.json` | Calibrated thresholds with full grid search results |
| `multihead_test_metrics.json` | Test metrics (before calibration) |
| `multihead_test_metrics_calibrated.json` | Test metrics (after calibration) |
## Limitations
- Designed specifically for Japanese e-commerce purchase reviews
- Extracts only WHO, WHERE, WHEN (not WHAT or WHY)
- Max input length: 128 tokens (longer reviews are truncated)
- Character-level F1 metric; may not capture semantic equivalence
- Trained on ~5,000 labeled examples; performance may vary on out-of-domain text
## Citation

If you use this model, please cite:

```bibtex
@misc{haldata-3w-multihead-qa-2026,
  title={Multi-Head 3W Extraction for Japanese Purchase Reviews},
  author={Haldata},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/HALDATA/bert-base-japanese-3w-multihead-qa}
}
```