""" CEFR Sentence Level Assessment Model Loads and runs inference with the metric proto k3 model """ import re from pathlib import Path from typing import List, Tuple, Dict import torch from transformers import AutoTokenizer, AutoModel class PrototypeClassifier(torch.nn.Module): """Metric-based prototype classifier for CEFR level assessment""" def __init__( self, encoder, num_labels: int, hidden_size: int, prototypes_per_class: int, temperature: float = 10.0, layer_index: int = -2, ): super().__init__() self.encoder = encoder self.num_labels = num_labels self.prototypes_per_class = prototypes_per_class self.temperature = temperature self.layer_index = layer_index self.prototypes = torch.nn.Parameter( torch.empty(num_labels, prototypes_per_class, hidden_size) ) def set_prototypes(self, proto_tensor: torch.Tensor) -> None: """Set prototype weights""" with torch.no_grad(): self.prototypes.copy_(proto_tensor) def encode(self, input_ids, attention_mask, token_type_ids=None) -> torch.Tensor: """Encode input sentences to normalized embeddings""" outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True, ) hidden = outputs.hidden_states[self.layer_index] # mean pooling mask = attention_mask.unsqueeze(-1).float() summed = torch.sum(hidden * mask, dim=1) counts = torch.clamp(mask.sum(dim=1), min=1e-9) pooled = summed / counts pooled = torch.nn.functional.normalize(pooled, p=2, dim=1) return pooled def forward(self, input_ids, attention_mask, token_type_ids=None): """Forward pass returning logits""" x = self.encode(input_ids, attention_mask, token_type_ids) # cosine similarity with prototypes, average over K for each class protos = torch.nn.functional.normalize(self.prototypes, p=2, dim=-1) # [B, H] x [C,K,H] -> [B,C,K] sim = torch.einsum("bh,ckh->bck", x, protos) sim_mean = sim.mean(dim=2) # average over K logits = sim_mean * self.temperature return {"logits": logits} def predict(self, input_ids, attention_mask, token_type_ids=None) -> torch.Tensor: """Predict CEFR levels""" outputs = self.forward(input_ids, attention_mask, token_type_ids) return torch.argmax(outputs["logits"], dim=1) class CEFRModel: """Wrapper class for CEFR assessment model""" def __init__(self, model_path: str = None, device: str = None): """ Initialize the CEFR assessment model Args: model_path: Path to the trained model checkpoint device: Device to run inference on ('cuda' or 'cpu') """ if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = torch.device(device) # CEFR level mapping self.id_to_label = {0: "A1", 1: "A2", 2: "B1", 3: "B2", 4: "C1", 5: "C2"} self.label_to_id = {v: k for k, v in self.id_to_label.items()} # Model parameters self.model_name = "KB/bert-base-swedish-cased" self.hidden_size = 768 self.num_labels = 6 self.prototypes_per_class = 3 self.temperature = 10.0 # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) # Load model encoder = AutoModel.from_pretrained(self.model_name) self.model = PrototypeClassifier( encoder=encoder, num_labels=self.num_labels, hidden_size=self.hidden_size, prototypes_per_class=self.prototypes_per_class, temperature=self.temperature, ) # Load trained weights if model_path is None: # Try to find the model automatically default_paths = [ "runs/metric-proto-k3/metric_proto.pt", "runs/metric-proto/metric_proto.pt", "runs/bert-baseline/bert_baseline.pt", "../runs/metric-proto-k3/metric_proto.pt", # Relative to web_app/ ] for path in default_paths: if Path(path).exists(): model_path = path print(f"Auto-detected model: {model_path}") break if model_path: # Try different relative paths possible_paths = [ Path(model_path), Path(__file__).parent / model_path, Path(__file__).parent.parent / model_path, ] checkpoint = None for path in possible_paths: if path.exists(): print(f"Loading model from {path}") checkpoint = torch.load(path, map_location=self.device, weights_only=False) break if checkpoint is None: print(f"Warning: Model file not found at {model_path}") print("Model will be initialized with random weights!") else: print("Warning: No model path specified. Model will be initialized with random weights!") checkpoint = None if checkpoint is not None: # Load model state dict if "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] # Handle DataParallel state dict new_state_dict = {} for key, value in state_dict.items(): if key.startswith("model."): new_key = key[6:] # Remove 'model.' prefix else: new_key = key new_state_dict[new_key] = value self.model.load_state_dict(new_state_dict, strict=False) else: self.model.load_state_dict(checkpoint) # Load prototypes if available if "prototypes" in checkpoint: self.model.set_prototypes(checkpoint["prototypes"].to(self.device)) self.model.to(self.device) self.model.eval() def tokenize(self, texts: List[str], max_length: int = 128) -> Dict[str, torch.Tensor]: """Tokenize input texts""" encoded = self.tokenizer( texts, truncation=True, padding=True, max_length=max_length, return_tensors="pt", ) return encoded def predict_batch(self, sentences: List[str]) -> List[Tuple[str, float]]: """ Predict CEFR levels for a batch of sentences Args: sentences: List of sentences to assess Returns: List of (level, confidence) tuples """ if not sentences: return [] # Tokenize encoded = self.tokenize(sentences) input_ids = encoded["input_ids"].to(self.device) attention_mask = encoded["attention_mask"].to(self.device) # Predict with torch.no_grad(): logits = self.model(input_ids, attention_mask)["logits"] probs = torch.softmax(logits, dim=1) predictions = torch.argmax(logits, dim=1) # Format results results = [] cpu_probs = probs.cpu() for i, pred in enumerate(predictions.cpu().numpy()): level = self.id_to_label[pred] confidence = float(cpu_probs[i][pred].item()) # Handle NaN values if torch.isnan(cpu_probs[i][pred]): confidence = 1.0 / self.num_labels results.append((level, confidence)) return results def predict_sentence(self, sentence: str) -> Tuple[str, float]: """Predict CEFR level for a single sentence""" results = self.predict_batch([sentence]) return results[0] def split_into_sentences(text: str) -> List[str]: """ Split text into sentences Args: text: Input text (Swedish) Returns: List of sentences """ # Simple sentence splitting based on punctuation # Swedish sentence endings: . ! ? # Split on punctuation followed by space and uppercase letter, or end of string sentences = re.split(r'([.!?])\s+', text) # Combine punctuation with previous sentence combined = [] for i in range(0, len(sentences) - 1, 2): if i + 1 < len(sentences): combined.append(sentences[i] + sentences[i + 1]) else: combined.append(sentences[i]) # Handle the last sentence if there's no punctuation if len(sentences) % 2 == 1 and sentences[-1].strip(): combined.append(sentences[-1]) # Clean up sentences cleaned = [] for sent in combined: sent = sent.strip() if sent: cleaned.append(sent) return cleaned def assess_text(text: str, model: CEFRModel) -> List[Dict[str, any]]: """ Assess a text and return sentence-level CEFR annotations Args: text: Input text (Swedish) model: CEFR assessment model Returns: List of dictionaries with sentence and level information """ # Split text into sentences sentences = split_into_sentences(text) if not sentences: return [] # Predict CEFR levels predictions = model.predict_batch(sentences) # Format results results = [] for sent, (level, confidence) in zip(sentences, predictions): results.append({ "sentence": sent, "level": level, "confidence": confidence, }) return results