"""
DomainClassifier: Classifies documents into 7 science domains.

Uses a simple linear classifier on top of mean-pooled hidden states, with a
keyword-based fallback for raw text.
"""

import re
from typing import List, Tuple, Optional

import torch
import torch.nn as nn


class DomainClassifier(nn.Module):
    """
    Classifies documents into 7 science domains:
    0: Physics
    1: Mathematics
    2: Chemistry
    3: Biology
    4: Earth Science
    5: Space Science
    6: Zoology
    """

    # Domain keywords for rule-based fallback
    DOMAIN_KEYWORDS = {
        0: ['physics', 'quantum', 'relativity', 'mechanics', 'thermodynamics', 'electromagnetism'],
        1: ['mathematics', 'algebra', 'calculus', 'geometry', 'topology', 'proof', 'theorem'],
        2: ['chemistry', 'molecular', 'reaction', 'compound', 'element', 'organic'],
        3: ['biology', 'cell', 'gene', 'protein', 'organism', 'evolution'],
        4: ['earth', 'geology', 'climate', 'ocean', 'atmosphere', 'meteorology'],
        5: ['space', 'astronomy', 'planet', 'star', 'galaxy', 'cosmology'],
        6: ['zoology', 'animal', 'species', 'vertebrate', 'invertebrate', 'ecology'],
    }

    def __init__(self, d_model: int, num_domains: int = 7):
        """
        Initialize domain classifier.

        Args:
            d_model: Input embedding dimension
            num_domains: Number of domains (7)
        """
        super().__init__()
        self.d_model = d_model
        self.num_domains = num_domains

        # Simple linear classifier
        self.classifier = nn.Linear(d_model, num_domains)

        # Initialize weights: small normal weights, zero bias
        nn.init.normal_(self.classifier.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Classify domain from hidden states.

        The sequence dimension is mean-pooled (ignoring positions where the
        mask is 0) and the pooled vector is passed through the linear layer.

        Args:
            hidden_states: (batch, seq_len, d_model)
            attention_mask: (batch, seq_len); non-zero marks positions to keep.
                Any numeric or bool dtype is accepted.

        Returns:
            Domain logits (batch, num_domains)
        """
        if attention_mask is not None:
            # Cast the mask to the activation dtype so the product is not
            # type-promoted (e.g. a float64 mask with float32 activations),
            # which would make the pooled tensor mismatch the linear weights.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
            summed = (hidden_states * mask).sum(dim=1)
            # clamp avoids division by zero for fully masked rows
            counts = mask.sum(dim=1).clamp(min=1)
            pooled = summed / counts
        else:
            pooled = hidden_states.mean(dim=1)

        return self.classifier(pooled)

    def classify_text(
        self,
        text: str,
    ) -> Tuple[int, float]:
        """
        Rule-based fallback classification from raw text.

        Each domain scores one point per keyword found in the lower-cased
        text. A keyword must begin at a word boundary, so plurals and derived
        forms still match ("cells", "genetic") while mid-word hits do not
        ("excellent" is not a match for "cell").

        Args:
            text: Input text string

        Returns:
            (domain_id, confidence), where confidence is the winning domain's
            share of all keyword hits. Ties go to the domain listed first in
            DOMAIN_KEYWORDS. Returns (0, 0.0) when no keyword matches.
        """
        text_lower = text.lower()

        # Count keyword matches per domain, keyed by the real domain id.
        # The re module caches compiled patterns, so rebuilding them is cheap.
        scores = {
            domain_id: sum(
                1
                for kw in keywords
                if re.search(r"\b" + re.escape(kw), text_lower)
            )
            for domain_id, keywords in self.DOMAIN_KEYWORDS.items()
        }

        total = sum(scores.values())
        if total == 0:
            return 0, 0.0  # Unknown -> default to physics

        # max() returns the first maximal key in insertion order
        best_domain = max(scores, key=scores.get)
        return best_domain, scores[best_domain] / total

    def compute_loss(
        self,
        logits: torch.Tensor,
        domain_labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute classification loss.

        Args:
            logits: (batch, num_domains)
            domain_labels: (batch,) with domain IDs

        Returns:
            Cross-entropy loss
        """
        return nn.functional.cross_entropy(logits, domain_labels)


def test_domain_classifier():
    """Test DomainClassifier."""
    d_model = 512
    batch_size = 4
    seq_len = 128

    classifier = DomainClassifier(d_model)

    # Test with random hidden states
    hidden = torch.randn(batch_size, seq_len, d_model)
    logits = classifier(hidden)
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == (batch_size, 7)

    # Test masked pooling: an all-ones mask must equal the unmasked mean
    full_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    masked_logits = classifier(hidden, full_mask)
    assert masked_logits.shape == (batch_size, 7)
    assert torch.allclose(masked_logits, logits, atol=1e-5)

    # Test with text
    texts = [
        "The quantum mechanics of particles...",
        "Solving differential equations...",
        "Chemical reactions produce compounds...",
        "Cells contain DNA and proteins...",
    ]

    for text in texts:
        domain, conf = classifier.classify_text(text)
        print(f"Text: {text[:30]}... -> Domain {domain}, conf {conf:.2f}")

    # Test loss
    labels = torch.tensor([0, 1, 2, 3])
    loss = classifier.compute_loss(logits, labels)
    print(f"Loss: {loss.item():.4f}")

    print("DomainClassifier test passed!")


if __name__ == "__main__":
    test_domain_classifier()