Spaces:
Sleeping
Sleeping
| """ | |
| CEFR Sentence Level Assessment Model | |
| Loads and runs inference with the metric proto k3 model | |
| """ | |
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModel
class PrototypeClassifier(torch.nn.Module):
    """Metric-based prototype classifier for CEFR level assessment.

    Sentences are encoded to L2-normalized mean-pooled embeddings; logits are
    temperature-scaled cosine similarities to per-class prototype vectors,
    averaged over the K prototypes of each class.
    """

    def __init__(
        self,
        encoder,
        num_labels: int,
        hidden_size: int,
        prototypes_per_class: int,
        temperature: float = 10.0,
        layer_index: int = -2,
    ):
        super().__init__()
        self.encoder = encoder
        self.num_labels = num_labels
        self.prototypes_per_class = prototypes_per_class
        self.temperature = temperature
        self.layer_index = layer_index
        # [C, K, H] prototype bank; left uninitialized until set_prototypes()
        # (or a checkpoint load) fills it.
        self.prototypes = torch.nn.Parameter(
            torch.empty(num_labels, prototypes_per_class, hidden_size)
        )

    def set_prototypes(self, proto_tensor: torch.Tensor) -> None:
        """Overwrite the prototype bank in place (no gradient recorded)."""
        with torch.no_grad():
            self.prototypes.copy_(proto_tensor)

    def encode(self, input_ids, attention_mask, token_type_ids=None) -> torch.Tensor:
        """Encode input sentences to L2-normalized mean-pooled embeddings."""
        enc_out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        states = enc_out.hidden_states[self.layer_index]
        # Masked mean pooling over the token dimension; clamp guards against
        # division by zero for all-padding rows.
        weights = attention_mask.unsqueeze(-1).float()
        mean_vec = (states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        return torch.nn.functional.normalize(mean_vec, p=2, dim=1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Forward pass; returns {"logits": [B, C]} of scaled cosine scores."""
        emb = self.encode(input_ids, attention_mask, token_type_ids)
        unit_protos = torch.nn.functional.normalize(self.prototypes, p=2, dim=-1)
        # Flatten prototypes to [C*K, H]; cosine sims via matmul -> [B, C*K],
        # then regroup per class and average over the K prototypes.
        flat = unit_protos.reshape(-1, unit_protos.shape[-1])
        sims = emb @ flat.t()
        per_class = sims.view(
            emb.shape[0], self.num_labels, self.prototypes_per_class
        ).mean(dim=-1)
        return {"logits": per_class * self.temperature}

    def predict(self, input_ids, attention_mask, token_type_ids=None) -> torch.Tensor:
        """Predict CEFR level indices (argmax over class logits)."""
        return self.forward(input_ids, attention_mask, token_type_ids)["logits"].argmax(dim=1)
class CEFRModel:
    """Wrapper class for CEFR assessment model.

    Owns the tokenizer, the PrototypeClassifier and the label mapping, and
    exposes batch / single-sentence prediction helpers.
    """
    def __init__(self, model_path: str = None, device: str = None):
        """
        Initialize the CEFR assessment model

        Args:
            model_path: Path to the trained model checkpoint. If None, a list
                of default checkpoint locations is probed; if none exist the
                model keeps randomly initialized weights (a warning is printed).
            device: Device to run inference on ('cuda' or 'cpu'). If None,
                CUDA is used when available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        # CEFR level mapping: class index -> CEFR label (6 levels, A1..C2)
        self.id_to_label = {0: "A1", 1: "A2", 2: "B1", 3: "B2", 4: "C1", 5: "C2"}
        self.label_to_id = {v: k for k, v in self.id_to_label.items()}
        # Model parameters — must match the checkpoint's training configuration
        self.model_name = "KB/bert-base-swedish-cased"
        self.hidden_size = 768
        self.num_labels = 6
        self.prototypes_per_class = 3
        self.temperature = 10.0
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Load model (encoder weights are pretrained; classifier prototypes are
        # uninitialized until the checkpoint below is loaded)
        encoder = AutoModel.from_pretrained(self.model_name)
        self.model = PrototypeClassifier(
            encoder=encoder,
            num_labels=self.num_labels,
            hidden_size=self.hidden_size,
            prototypes_per_class=self.prototypes_per_class,
            temperature=self.temperature,
        )
        # Load trained weights
        if model_path is None:
            # Try to find the model automatically; first existing path wins
            default_paths = [
                "runs/metric-proto-k3/metric_proto.pt",
                "runs/metric-proto/metric_proto.pt",
                "runs/bert-baseline/bert_baseline.pt",
                "../runs/metric-proto-k3/metric_proto.pt",  # Relative to web_app/
            ]
            for path in default_paths:
                if Path(path).exists():
                    model_path = path
                    print(f"Auto-detected model: {model_path}")
                    break
        if model_path:
            # Try different relative anchors: cwd, this module's dir, its parent
            possible_paths = [
                Path(model_path),
                Path(__file__).parent / model_path,
                Path(__file__).parent.parent / model_path,
            ]
            checkpoint = None
            for path in possible_paths:
                if path.exists():
                    print(f"Loading model from {path}")
                    # NOTE(review): weights_only=False unpickles arbitrary
                    # objects — only load checkpoints from trusted sources.
                    checkpoint = torch.load(path, map_location=self.device, weights_only=False)
                    break
            if checkpoint is None:
                print(f"Warning: Model file not found at {model_path}")
                print("Model will be initialized with random weights!")
        else:
            print("Warning: No model path specified. Model will be initialized with random weights!")
            checkpoint = None
        if checkpoint is not None:
            # Load model state dict
            if "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
                # Handle DataParallel state dict
                new_state_dict = {}
                for key, value in state_dict.items():
                    if key.startswith("model."):
                        new_key = key[6:]  # Remove 'model.' prefix
                    else:
                        new_key = key
                    new_state_dict[new_key] = value
                # strict=False tolerates key mismatches between the checkpoint
                # and this architecture instead of raising
                self.model.load_state_dict(new_state_dict, strict=False)
            else:
                self.model.load_state_dict(checkpoint)
            # Load prototypes if available (overrides any prototypes that came
            # in via the state dict)
            if "prototypes" in checkpoint:
                self.model.set_prototypes(checkpoint["prototypes"].to(self.device))
        self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout etc.
    def tokenize(self, texts: List[str], max_length: int = 128) -> Dict[str, torch.Tensor]:
        """Tokenize input texts into padded/truncated PyTorch tensor batches."""
        encoded = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )
        return encoded
    def predict_batch(self, sentences: List[str]) -> List[Tuple[str, float]]:
        """
        Predict CEFR levels for a batch of sentences

        Args:
            sentences: List of sentences to assess

        Returns:
            List of (level, confidence) tuples, one per input sentence.
            Confidence is the softmax probability of the predicted class;
            a NaN probability is replaced by the uniform 1/num_labels.
        """
        if not sentences:
            return []
        # Tokenize
        encoded = self.tokenize(sentences)
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)
        # Predict
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)["logits"]
            probs = torch.softmax(logits, dim=1)
            predictions = torch.argmax(logits, dim=1)
        # Format results
        results = []
        cpu_probs = probs.cpu()
        for i, pred in enumerate(predictions.cpu().numpy()):
            level = self.id_to_label[pred]
            confidence = float(cpu_probs[i][pred].item())
            # Handle NaN values (possible e.g. if prototypes were never loaded
            # and still hold uninitialized torch.empty data)
            if torch.isnan(cpu_probs[i][pred]):
                confidence = 1.0 / self.num_labels
            results.append((level, confidence))
        return results
    def predict_sentence(self, sentence: str) -> Tuple[str, float]:
        """Predict CEFR level for a single sentence"""
        results = self.predict_batch([sentence])
        return results[0]
def split_into_sentences(text: str) -> List[str]:
    """
    Split text into sentences.

    A sentence boundary is any run of whitespace preceded by '.', '!' or '?';
    the lookbehind split keeps the terminating punctuation attached to its
    sentence, and trailing text without punctuation becomes a final sentence.

    Note: this is a heuristic — abbreviations such as "t.ex." are split
    incorrectly; use a proper sentence tokenizer if that matters.

    Args:
        text: Input text (Swedish)

    Returns:
        List of non-empty, stripped sentences (empty list for empty input)
    """
    # The lookbehind split replaces the previous capture-group split plus
    # manual re-join, whose trailing `else` branch was unreachable
    # (range(0, len - 1, 2) guarantees i + 1 < len).
    parts = re.split(r'(?<=[.!?])\s+', text)
    stripped = (part.strip() for part in parts)
    return [sent for sent in stripped if sent]
def assess_text(text: str, model: CEFRModel) -> List[Dict[str, Any]]:
    """
    Assess a text and return sentence-level CEFR annotations.

    Args:
        text: Input text (Swedish)
        model: CEFR assessment model

    Returns:
        List of dictionaries, one per sentence, with keys "sentence" (str),
        "level" (CEFR label, e.g. "B1") and "confidence" (float probability).
        Empty list if the text contains no sentences.
    """
    # Split text into sentences
    sentences = split_into_sentences(text)
    if not sentences:
        return []
    # Predict CEFR levels for all sentences in one batch
    predictions = model.predict_batch(sentences)
    # Pair each sentence with its predicted (level, confidence)
    # (fixes the previous annotation `Dict[str, any]`, which referenced the
    # builtin function `any` instead of `typing.Any`)
    return [
        {"sentence": sent, "level": level, "confidence": confidence}
        for sent, (level, confidence) in zip(sentences, predictions)
    ]