Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
class ExpertiseClassifier(nn.Module):
    """
    Classify user expertise level with a single linear layer.

    The model maps input features whose last dimension is ``input_dim`` to
    ``num_classes`` unnormalized scores (logits). ``predict_label`` turns a
    single sample's logits into a human-readable label name.

    Args:
        input_dim: Size of the last dimension of the input features.
        num_classes: Number of output classes (logits per sample).
        labels: Optional label names, one per class, in logit order. When
            omitted, the four default expertise levels are used.

    Raises:
        ValueError: If ``labels`` is given and its length differs from
            ``num_classes``.
    """

    # Default label names, indexed by class position in the logits.
    DEFAULT_LABELS = ("beginner", "intermediate", "advanced", "expert")

    def __init__(
        self,
        input_dim: int = 128,
        num_classes: int = 4,
        *,
        labels: "list[str] | None" = None,
    ):
        super().__init__()
        # store values so engine can read them
        self.input_dim = input_dim
        self.num_classes = num_classes
        # classifier layer
        self.classifier = nn.Linear(input_dim, num_classes)
        # label names; the default keeps the original four-name list so
        # existing callers see exactly the same `labels` attribute
        if labels is None:
            self.labels = list(self.DEFAULT_LABELS)
        else:
            self.labels = list(labels)
            # explicit labels must name every class exactly once, otherwise
            # predictions would silently map to the wrong name or crash later
            if len(self.labels) != num_classes:
                raise ValueError(
                    f"expected {num_classes} labels, got {len(self.labels)}"
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute class logits.

        Args:
            x: Input features; the last dimension must equal ``input_dim``.

        Returns:
            Logits with the same leading dimensions as ``x`` and a last
            dimension of size ``num_classes``.
        """
        return self.classifier(x)

    def predict_label(self, logits: torch.Tensor) -> str:
        """
        Convert one sample's logits to its label name.

        Args:
            logits: Scores for a single sample, e.g. shape ``(num_classes,)``
                or ``(1, num_classes)``; the highest score along the last
                dimension wins.

        Returns:
            The label name of the highest-scoring class.

        Raises:
            RuntimeError: If ``logits`` holds more than one sample, because
                ``Tensor.item`` requires a single-element tensor.
            IndexError: If the winning class has no label name. This can
                happen when ``num_classes`` exceeds the number of default
                labels and no ``labels`` were passed to the constructor.
        """
        idx = int(torch.argmax(logits, dim=-1).item())
        # Guard explicitly so the failure explains how to fix it instead of
        # surfacing a bare "list index out of range".
        if idx >= len(self.labels):
            raise IndexError(
                f"class index {idx} has no label name; only "
                f"{len(self.labels)} labels are defined for "
                f"num_classes={self.num_classes}. Pass `labels=` to the "
                "constructor."
            )
        return self.labels[idx]