Update TorchCRF.py
Browse files- TorchCRF.py +6 -2
TorchCRF.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
class CRF(_BaseCRF):
|
| 4 |
def __init__(self, num_tags, batch_first=True, **kwargs):
|
| 5 |
-
# torchcrf.CRF doesn't accept batch_first, so we drop it
|
| 6 |
super().__init__(num_tags, **kwargs)
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from torchcrf import CRF as _BaseCRF
|
| 3 |
+
except ImportError:
|
| 4 |
+
raise ImportError(
|
| 5 |
+
"torchcrf library not found. Make sure 'torchcrf' is listed in requirements.txt."
|
| 6 |
+
)
|
| 7 |
|
| 8 |
class CRF(_BaseCRF):
|
| 9 |
def __init__(self, num_tags, batch_first=True, **kwargs):
|
|
|
|
| 10 |
super().__init__(num_tags, **kwargs)
|