Spaces:
Running
Running
Refactor, introduce CommaFixerInterface and remove duplication
Browse files
commafixer/routers/baseline.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
from fastapi import APIRouter
|
| 2 |
import logging
|
| 3 |
|
| 4 |
from commafixer.src.baseline import BaselineCommaFixer
|
| 5 |
-
|
| 6 |
|
| 7 |
logger = logging.Logger(__name__)
|
| 8 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -16,10 +16,4 @@ router.model = BaselineCommaFixer()
|
|
| 16 |
@router.post('/fix-commas/')
|
| 17 |
async def fix_commas_with_baseline(data: dict):
|
| 18 |
json_field_name = 's'
|
| 19 |
-
|
| 20 |
-
logger.debug('Fixing commas.')
|
| 21 |
-
return {json_field_name: router.model.fix_commas(data['s'])}
|
| 22 |
-
else:
|
| 23 |
-
msg = f"Text '{json_field_name}' missing"
|
| 24 |
-
logger.debug(msg)
|
| 25 |
-
raise HTTPException(status_code=400, detail=msg)
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
import logging
|
| 3 |
|
| 4 |
from commafixer.src.baseline import BaselineCommaFixer
|
| 5 |
+
from common import fix_commas_request_handler
|
| 6 |
|
| 7 |
logger = logging.Logger(__name__)
|
| 8 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 16 |
@router.post('/fix-commas/')
|
| 17 |
async def fix_commas_with_baseline(data: dict):
|
| 18 |
json_field_name = 's'
|
| 19 |
+
return fix_commas_request_handler(json_field_name, data, logger, router.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
commafixer/routers/common.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import HTTPException
|
| 2 |
+
from logging import Logger
|
| 3 |
+
|
| 4 |
+
from comma_fixer_interface import CommaFixerInterface
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def fix_commas_request_handler(
|
| 8 |
+
json_field_name: str,
|
| 9 |
+
data: dict[str, str],
|
| 10 |
+
logger: Logger,
|
| 11 |
+
model: CommaFixerInterface
|
| 12 |
+
) -> dict[str, str]:
|
| 13 |
+
if json_field_name in data:
|
| 14 |
+
logger.debug('Fixing commas.')
|
| 15 |
+
return {json_field_name: model.fix_commas(data['s'])}
|
| 16 |
+
else:
|
| 17 |
+
msg = f"Text '{json_field_name}' missing"
|
| 18 |
+
logger.debug(msg)
|
| 19 |
+
raise HTTPException(status_code=400, detail=msg)
|
commafixer/routers/fixer.py
CHANGED
|
@@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException
|
|
| 2 |
import logging
|
| 3 |
|
| 4 |
from commafixer.src.fixer import CommaFixer
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
logger = logging.Logger(__name__)
|
|
@@ -16,10 +17,4 @@ router.model = CommaFixer()
|
|
| 16 |
@router.post('/')
|
| 17 |
async def fix_commas(data: dict):
|
| 18 |
json_field_name = 's'
|
| 19 |
-
|
| 20 |
-
logger.debug('Fixing commas.')
|
| 21 |
-
return {json_field_name: router.model.fix_commas(data['s'])}
|
| 22 |
-
else:
|
| 23 |
-
msg = f"Text '{json_field_name}' missing"
|
| 24 |
-
logger.debug(msg)
|
| 25 |
-
raise HTTPException(status_code=400, detail=msg)
|
|
|
|
| 2 |
import logging
|
| 3 |
|
| 4 |
from commafixer.src.fixer import CommaFixer
|
| 5 |
+
from commafixer.routers.common import fix_commas_request_handler
|
| 6 |
|
| 7 |
|
| 8 |
logger = logging.Logger(__name__)
|
|
|
|
| 17 |
@router.post('/')
|
| 18 |
async def fix_commas(data: dict):
|
| 19 |
json_field_name = 's'
|
| 20 |
+
return fix_commas_request_handler(json_field_name, data, logger, router.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
commafixer/src/baseline.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
| 2 |
import re
|
| 3 |
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
| 6 |
"""
|
| 7 |
A wrapper class for the oliverguhr/fullstop-punctuation-multilang-large baseline punctuation restoration model.
|
| 8 |
It adapts the model to perform comma fixing instead of full punctuation restoration, that is, removes the
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
| 2 |
import re
|
| 3 |
|
| 4 |
+
from commafixer.src.comma_fixer_interface import CommaFixerInterface
|
| 5 |
|
| 6 |
+
|
| 7 |
+
class BaselineCommaFixer(CommaFixerInterface):
|
| 8 |
"""
|
| 9 |
A wrapper class for the oliverguhr/fullstop-punctuation-multilang-large baseline punctuation restoration model.
|
| 10 |
It adapts the model to perform comma fixing instead of full punctuation restoration, that is, removes the
|
commafixer/src/comma_fixer_interface.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CommaFixerInterface(ABC):
|
| 5 |
+
@abstractmethod
|
| 6 |
+
def fix_commas(self, s: str) -> str:
|
| 7 |
+
pass
|
commafixer/src/fixer.py
CHANGED
|
@@ -3,8 +3,10 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipelin
|
|
| 3 |
import nltk
|
| 4 |
import re
|
| 5 |
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
| 8 |
"""
|
| 9 |
A wrapper class for the fine-tuned comma fixer model.
|
| 10 |
"""
|
|
@@ -84,7 +86,7 @@ def _fix_commas_based_on_labels_and_offsets(
|
|
| 84 |
|
| 85 |
def _should_insert_comma(label, result, current_offset) -> bool:
|
| 86 |
# Only insert commas for the final token of a word, that is, if next word starts with a space.
|
| 87 |
-
# TODO
|
| 88 |
return label == 'B-COMMA' and result[current_offset].isspace()
|
| 89 |
|
| 90 |
|
|
|
|
| 3 |
import nltk
|
| 4 |
import re
|
| 5 |
|
| 6 |
+
from commafixer.src.comma_fixer_interface import CommaFixerInterface
|
| 7 |
|
| 8 |
+
|
| 9 |
+
class CommaFixer(CommaFixerInterface):
|
| 10 |
"""
|
| 11 |
A wrapper class for the fine-tuned comma fixer model.
|
| 12 |
"""
|
|
|
|
| 86 |
|
| 87 |
def _should_insert_comma(label, result, current_offset) -> bool:
|
| 88 |
# Only insert commas for the final token of a word, that is, if next word starts with a space.
|
| 89 |
+
# TODO perhaps for low confidence tokens, we should use the original decision of the user in the input?
|
| 90 |
return label == 'B-COMMA' and result[current_offset].isspace()
|
| 91 |
|
| 92 |
|