Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import importlib.metadata | |
| from packaging.version import Version | |
# Fail fast at import time when an unsupported transformers release is
# installed; the rest of this module is written against the 4.x API.
_tv = Version(importlib.metadata.version("transformers"))
# Check the major component instead of comparing against Version("5.0.0"):
# under PEP 440 ordering, pre-releases and dev builds of 5.0.0 (for example
# 5.0.0rc1 or 5.0.0.dev0) sort *before* 5.0.0 and would slip past a plain
# ">=" comparison even though they already ship the 5.x code base.
if _tv.major >= 5:
    raise RuntimeError(
        f"transformers {_tv} is installed but this project requires <5.0.0. "
        "Run: pip install transformers==4.57.6"
    )
| from data.schema import Argument, Debate | |
# Mapping from argument-type label to the integer class id.
LABEL2ID = {
    "claim": 0,
    "counter_claim": 1,
    "premise": 2,
    "unknown": 3,
}
# Inverse mapping (class id -> label), derived from LABEL2ID so the two
# tables cannot drift apart.
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))

# Module-level cache populated by _load(); all four stay None until the
# first checkpoint has been loaded.
_model = None
_tokenizer = None
_loaded_checkpoint: str | None = None  # directory the cached model came from
_device: str | None = None  # "cuda" or "cpu", chosen when the model is loaded
def _load(checkpoint_dir: str) -> None:
    """Load the tokenizer and model from *checkpoint_dir* into the module cache.

    Does nothing when ``checkpoint_dir`` is the checkpoint that is already
    cached. Otherwise a ``RobertaTokenizerFast`` and a
    ``RobertaForSequenceClassification`` are read from ``checkpoint_dir``;
    the model is put in eval mode and moved to CUDA when available, CPU
    otherwise.

    The cache globals are rebound only after every step has succeeded, so a
    failed load leaves the previously cached tokenizer/model pair untouched
    instead of pairing a new tokenizer with an old model.

    Args:
        checkpoint_dir: Path passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Raises:
        Any exception raised by the ``from_pretrained`` calls or by moving
        the model to the device is propagated unchanged.
    """
    global _model, _tokenizer, _loaded_checkpoint, _device
    if _loaded_checkpoint == checkpoint_dir:
        return

    # torch/transformers are imported here rather than at module top level,
    # so they are only pulled in once a checkpoint is actually requested.
    import torch
    from transformers import (
        RobertaForSequenceClassification,
        RobertaTokenizerFast,
    )

    # Build everything in locals first; nothing global changes if any of
    # these steps raises.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = RobertaTokenizerFast.from_pretrained(checkpoint_dir)
    model = RobertaForSequenceClassification.from_pretrained(checkpoint_dir)
    model.eval()
    model.to(device)

    # Commit the fully initialised state in one go.
    _device, _tokenizer, _model = device, tokenizer, model
    _loaded_checkpoint = checkpoint_dir
def predict(
    text: str,
    parent_text: str = "",
    checkpoint_dir: str = "models/best",
) -> str:
    """Classify a single argument text.

    Args:
        text: The argument to label.
        parent_text: Text of the comment being replied to. When non-empty it
            is encoded as the first segment and ``text`` as the second;
            when empty, ``text`` is encoded on its own.
        checkpoint_dir: Checkpoint directory handed to ``_load``.

    Returns:
        One of: 'claim', 'counter_claim', 'premise', 'unknown'.
    """
    import torch

    _load(checkpoint_dir)

    # One tokenizer call covers both the reply and the stand-alone case; only
    # the positional segments differ.
    segments = (parent_text, text) if parent_text else (text,)
    batch = _tokenizer(
        *segments,
        return_tensors="pt",
        truncation=True,
        max_length=256,
        padding="max_length",
    )
    inputs = {name: tensor.to(_device) for name, tensor in batch.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = _model(**inputs)

    label_id = output.logits.argmax(dim=-1).item()
    return ID2LABEL[label_id]
def predict_debate(
    debate: Debate,
    checkpoint_dir: str = "models/best",
) -> Debate:
    """Return a copy of *debate* whose arguments carry predicted types.

    Main entry point for Person 3 (eval) and Person 4 (failure analysis).
    Every argument is classified with ``predict``; when its parent can be
    found in the same debate, the parent's text is supplied as context.
    Ids, text, parent links, author, score and metadata are copied over
    unchanged; only ``arg_type`` is replaced.

    Args:
        debate: The debate whose arguments should be labelled.
        checkpoint_dir: Checkpoint directory used for every prediction.

    Returns:
        A new ``Debate`` with the same id, title, source and metadata and a
        freshly built list of ``Argument`` objects.
    """
    # Load up front so the checkpoint is loaded even for an empty debate.
    _load(checkpoint_dir)

    by_id = {item.id: item for item in debate.arguments}

    def _context_for(item: Argument) -> str:
        """Return the parent's text for *item*, or "" when there is none."""
        # A falsy parent_id is treated as "no parent".
        if not item.parent_id:
            return ""
        parent = by_id.get(item.parent_id)
        return parent.text if parent else ""

    relabeled = [
        Argument(
            id=item.id,
            text=item.text,
            arg_type=predict(item.text, _context_for(item), checkpoint_dir),
            parent_id=item.parent_id,
            author=item.author,
            score=item.score,
            metadata=item.metadata,
        )
        for item in debate.arguments
    ]

    return Debate(
        id=debate.id,
        title=debate.title,
        source=debate.source,
        arguments=relabeled,
        metadata=debate.metadata,
    )
# Names exported by ``from <module> import *``: the two prediction entry
# points plus the label/id lookup tables.
__all__ = [
    "predict",
    "predict_debate",
    "LABEL2ID",
    "ID2LABEL",
]