"""
DomainClassifier: Classifies documents into 7 science domains.
Uses a simple linear classifier on top of mean-pooled hidden-state features,
with a keyword-matching fallback for raw text.
"""
import re
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
class DomainClassifier(nn.Module):
    """
    Classifies documents into 7 science domains:
        0: Physics
        1: Mathematics
        2: Chemistry
        3: Biology
        4: Earth Science
        5: Space Science
        6: Zoology

    Two classification paths are provided:
      * ``forward`` mean-pools hidden states (honouring an optional attention
        mask) and applies a single linear layer to produce logits.
      * ``classify_text`` is a rule-based fallback that counts keyword hits
        in raw text.
    """

    # Domain keywords for the rule-based fallback, keyed by domain id.
    # ``classify_text`` uses case-insensitive *substring* matching, so a
    # keyword can also hit inside a longer word (e.g. 'gene' in 'general').
    DOMAIN_KEYWORDS = {
        0: ['physics', 'quantum', 'relativity', 'mechanics', 'thermodynamics', 'electromagnetism'],
        1: ['mathematics', 'algebra', 'calculus', 'geometry', 'topology', 'proof', 'theorem'],
        2: ['chemistry', 'molecular', 'reaction', 'compound', 'element', 'organic'],
        3: ['biology', 'cell', 'gene', 'protein', 'organism', 'evolution'],
        4: ['earth', 'geology', 'climate', 'ocean', 'atmosphere', 'meteorology'],
        5: ['space', 'astronomy', 'planet', 'star', 'galaxy', 'cosmology'],
        6: ['zoology', 'animal', 'species', 'vertebrate', 'invertebrate', 'ecology'],
    }

    def __init__(self, d_model: int, num_domains: int = 7):
        """
        Initialize domain classifier.

        Args:
            d_model: Input embedding dimension (last dim of ``hidden_states``).
            num_domains: Number of output logits (default 7, matching the
                domains listed in ``DOMAIN_KEYWORDS``).
        """
        super().__init__()
        self.d_model = d_model
        self.num_domains = num_domains
        # Single linear head over the pooled representation.
        self.classifier = nn.Linear(d_model, num_domains)
        # Small-std normal weights and zero bias keep the initial logits
        # close to zero.
        nn.init.normal_(self.classifier.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Classify domain from hidden states.

        Args:
            hidden_states: (batch, seq_len, d_model)
            attention_mask: optional (batch, seq_len); positions with 0 are
                excluded from the mean pooling.

        Returns:
            Domain logits (batch, num_domains)
        """
        if attention_mask is not None:
            # (batch, seq_len) -> (batch, seq_len, 1) so it broadcasts over d_model.
            mask = attention_mask.unsqueeze(-1)
            summed = (hidden_states * mask).sum(dim=1)
            counts = mask.sum(dim=1)
            # clamp guards fully-masked rows against division by zero
            # (their pooled vector is then all zeros).
            pooled = summed / counts.clamp(min=1)
        else:
            pooled = hidden_states.mean(dim=1)
        return self.classifier(pooled)

    def classify_text(
        self,
        text: str,
    ) -> Tuple[int, float]:
        """
        Rule-based fallback classification from raw text.

        Each keyword contributes at most one hit per domain. The winning
        domain is the one with the most hits. Ties go to the domain listed
        first in ``DOMAIN_KEYWORDS``.

        Args:
            text: Input text string

        Returns:
            (domain_id, confidence), where confidence is the winner's share of
            all keyword hits. Returns (0, 0.0) when no keyword matches
            (unknown -> default to physics).
        """
        text_lower = text.lower()
        # Hit count per domain id. The dict key is kept so the returned id
        # never depends on iteration position.
        scores = {
            domain_id: sum(1 for kw in keywords if kw in text_lower)
            for domain_id, keywords in self.DOMAIN_KEYWORDS.items()
        }
        total = sum(scores.values())
        if total == 0:
            return 0, 0.0  # Unknown -> default to physics
        # max() returns the first maximal item, preserving first-wins ties.
        best_domain, best_score = max(scores.items(), key=lambda item: item[1])
        return best_domain, best_score / total

    def compute_loss(
        self,
        logits: torch.Tensor,
        domain_labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute classification loss.

        Args:
            logits: (batch, num_domains)
            domain_labels: (batch,) with domain IDs

        Returns:
            Mean cross-entropy loss (scalar tensor).
        """
        return nn.functional.cross_entropy(logits, domain_labels)
def test_domain_classifier():
    """Smoke-test DomainClassifier.

    Covers forward shapes, masked pooling, the keyword fallback and the loss.
    """
    d_model = 512
    batch_size = 4
    seq_len = 128
    classifier = DomainClassifier(d_model)

    # Unmasked forward on random hidden states.
    hidden = torch.randn(batch_size, seq_len, d_model)
    logits = classifier(hidden)
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == (batch_size, 7)

    # Masked forward: masking out the second half must match mean-pooling
    # over the first half only.
    half = seq_len // 2
    mask = torch.zeros(batch_size, seq_len)
    mask[:, :half] = 1
    masked_logits = classifier(hidden, mask)
    assert masked_logits.shape == (batch_size, 7)
    assert torch.allclose(masked_logits, classifier(hidden[:, :half]), atol=1e-5)

    # Keyword fallback on raw text, each paired with its expected domain id.
    cases = [
        ("The quantum mechanics of particles...", 0),
        ("Proving a theorem with calculus...", 1),
        ("Chemical reactions produce compounds...", 2),
        ("Cells contain DNA and proteins...", 3),
    ]
    for text, expected in cases:
        domain, conf = classifier.classify_text(text)
        print(f"Text: {text[:30]}... -> Domain {domain}, conf {conf:.2f}")
        assert domain == expected, (text, domain)
        assert 0.0 < conf <= 1.0, (text, conf)

    # Loss on the unmasked logits.
    labels = torch.tensor([0, 1, 2, 3])
    loss = classifier.compute_loss(logits, labels)
    print(f"Loss: {loss.item():.4f}")
    assert torch.isfinite(loss)
    print("DomainClassifier test passed!")
if __name__ == "__main__":
    # Run the smoke test only when this file is executed directly.
    test_domain_classifier()