Spaces:
Running
Running
Refactor and add more tests
Browse files- app.py +5 -4
- src/baseline.py +25 -16
- tests/test_baseline.py +9 -2
- tests/test_integration.py +2 -2
app.py
CHANGED
|
@@ -16,14 +16,15 @@ def root():
|
|
| 16 |
|
| 17 |
@app.route('/baseline/fix-commas/', methods=['POST'])
|
| 18 |
def fix_commas_with_baseline():
|
|
|
|
| 19 |
data = request.get_json()
|
| 20 |
-
if
|
| 21 |
-
return make_response(jsonify({
|
| 22 |
else:
|
| 23 |
-
return make_response("Parameter '
|
| 24 |
|
| 25 |
|
| 26 |
if __name__ == '__main__':
|
| 27 |
logger.info("Loading the baseline model.")
|
| 28 |
app.baseline_pipeline = create_baseline_pipeline()
|
| 29 |
-
app.run(debug=True)
|
|
|
|
| 16 |
|
| 17 |
@app.route('/baseline/fix-commas/', methods=['POST'])
|
| 18 |
def fix_commas_with_baseline():
|
| 19 |
+
json_field_name = 's'
|
| 20 |
data = request.get_json()
|
| 21 |
+
if json_field_name in data:
|
| 22 |
+
return make_response(jsonify({json_field_name: fix_commas(app.baseline_pipeline, data['s'])}), 200)
|
| 23 |
else:
|
| 24 |
+
return make_response(f"Parameter '{json_field_name}' missing", 400)
|
| 25 |
|
| 26 |
|
| 27 |
if __name__ == '__main__':
|
| 28 |
logger.info("Loading the baseline model.")
|
| 29 |
app.baseline_pipeline = create_baseline_pipeline()
|
| 30 |
+
app.run(debug=True) # TODO get this from config or env variable
|
src/baseline.py
CHANGED
|
@@ -1,12 +1,19 @@
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
| 2 |
|
| 3 |
|
| 4 |
-
def create_baseline_pipeline() -> NerPipeline:
|
| 5 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 6 |
-
model = AutoModelForTokenClassification.from_pretrained(
|
| 7 |
return pipeline('ner', model=model, tokenizer=tokenizer)
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def _remove_punctuation(s: str) -> str:
|
| 11 |
to_remove = ".,?-:"
|
| 12 |
for char in to_remove:
|
|
@@ -14,23 +21,25 @@ def _remove_punctuation(s: str) -> str:
|
|
| 14 |
return s
|
| 15 |
|
| 16 |
|
| 17 |
-
def
|
| 18 |
-
|
| 19 |
-
# TODO don't accept tokens with commas inside words
|
| 20 |
-
result = original_s.replace(',', '') # We will fix the commas, but keep everything else intact
|
| 21 |
current_offset = 0
|
|
|
|
| 22 |
for i in range(1, len(pipeline_json)):
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# Only insert commas for the final token of a word
|
| 26 |
-
if pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith('▁'):
|
| 27 |
result = result[:current_offset] + ',' + result[current_offset:]
|
| 28 |
current_offset += 1
|
| 29 |
return result
|
| 30 |
|
| 31 |
|
| 32 |
-
def
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
|
| 2 |
|
| 3 |
|
| 4 |
+
def create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
|
| 5 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 6 |
+
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
| 7 |
return pipeline('ner', model=model, tokenizer=tokenizer)
|
| 8 |
|
| 9 |
|
| 10 |
+
def fix_commas(ner_pipeline: NerPipeline, s: str) -> str:
|
| 11 |
+
return _fix_commas_based_on_pipeline_output(
|
| 12 |
+
ner_pipeline(_remove_punctuation(s)),
|
| 13 |
+
s
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
def _remove_punctuation(s: str) -> str:
|
| 18 |
to_remove = ".,?-:"
|
| 19 |
for char in to_remove:
|
|
|
|
| 21 |
return s
|
| 22 |
|
| 23 |
|
| 24 |
+
def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s: str) -> str:
|
| 25 |
+
result = original_s.replace(',', '') # We will fix the commas, but keep everything else intact
|
|
|
|
|
|
|
| 26 |
current_offset = 0
|
| 27 |
+
|
| 28 |
for i in range(1, len(pipeline_json)):
|
| 29 |
+
current_offset = _find_current_token(current_offset, i, pipeline_json, result)
|
| 30 |
+
if _should_insert_comma(i, pipeline_json):
|
|
|
|
|
|
|
| 31 |
result = result[:current_offset] + ',' + result[current_offset:]
|
| 32 |
current_offset += 1
|
| 33 |
return result
|
| 34 |
|
| 35 |
|
| 36 |
+
def _should_insert_comma(i, pipeline_json, new_word_indicator='▁') -> bool:
|
| 37 |
+
# Only insert commas for the final token of a word
|
| 38 |
+
return pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith(new_word_indicator)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _find_current_token(current_offset, i, pipeline_json, result, new_word_indicator='▁') -> int:
|
| 42 |
+
current_word = pipeline_json[i - 1]['word'].replace(new_word_indicator, '')
|
| 43 |
+
# Find the current word in the result string, starting looking at current offset
|
| 44 |
+
current_offset = result.find(current_word, current_offset) + len(current_word)
|
| 45 |
+
return current_offset
|
tests/test_baseline.py
CHANGED
|
@@ -11,7 +11,8 @@ def baseline_pipeline():
|
|
| 11 |
"test_input",
|
| 12 |
['',
|
| 13 |
'Hello world.',
|
| 14 |
-
'This test string should not have any commas inside it.'
|
|
|
|
| 15 |
)
|
| 16 |
def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_input):
|
| 17 |
result = fix_commas(baseline_pipeline, s=test_input)
|
|
@@ -23,7 +24,13 @@ def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_inp
|
|
| 23 |
[
|
| 24 |
['I, am.', 'I am.'],
|
| 25 |
['A complex clause however it misses a comma something else and a dot...?',
|
| 26 |
-
'A complex clause, however, it misses a comma, something else and a dot...?']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
)
|
| 28 |
def test_fix_commas_fixes_incorrect_commas(baseline_pipeline, test_input, expected):
|
| 29 |
result = fix_commas(baseline_pipeline, s=test_input)
|
|
|
|
| 11 |
"test_input",
|
| 12 |
['',
|
| 13 |
'Hello world.',
|
| 14 |
+
'This test string should not have any commas inside it.',
|
| 15 |
+
'aAaalLL the.. weird?~! punctuation.should also . be kept-as is! Only fixing-commas.']
|
| 16 |
)
|
| 17 |
def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_input):
|
| 18 |
result = fix_commas(baseline_pipeline, s=test_input)
|
|
|
|
| 24 |
[
|
| 25 |
['I, am.', 'I am.'],
|
| 26 |
['A complex clause however it misses a comma something else and a dot...?',
|
| 27 |
+
'A complex clause, however, it misses a comma, something else and a dot...?'],
|
| 28 |
+
['a pen an apple, \tand a pineapple!',
|
| 29 |
+
'a pen, an apple \tand a pineapple!'],
|
| 30 |
+
['Even newlines\ntabs\tand others get preserved.',
|
| 31 |
+
'Even newlines,\ntabs\tand others get preserved.'],
|
| 32 |
+
['I had no Creativity left, therefore, I come here, and write useless examples, for this test.',
|
| 33 |
+
'I had no Creativity left therefore, I come here and write useless examples for this test.']]
|
| 34 |
)
|
| 35 |
def test_fix_commas_fixes_incorrect_commas(baseline_pipeline, test_input, expected):
|
| 36 |
result = fix_commas(baseline_pipeline, s=test_input)
|
tests/test_integration.py
CHANGED
|
@@ -29,7 +29,7 @@ def test_fix_commas_fails_on_wrong_parameters(client):
|
|
| 29 |
'Hello world.',
|
| 30 |
'This test string should not have any commas inside it.']
|
| 31 |
)
|
| 32 |
-
def
|
| 33 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|
| 34 |
|
| 35 |
assert response.status_code == 200
|
|
@@ -40,7 +40,7 @@ def test_fix_commas_plain_string_unchanged(client, test_input: str):
|
|
| 40 |
"test_input, expected",
|
| 41 |
[['I am, here.', 'I am here.'],
|
| 42 |
['books pens and pencils',
|
| 43 |
-
'books, pens and pencils
|
| 44 |
)
|
| 45 |
def test_fix_commas_fixes_wrong_commas(client, test_input: str, expected: str):
|
| 46 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|
|
|
|
| 29 |
'Hello world.',
|
| 30 |
'This test string should not have any commas inside it.']
|
| 31 |
)
|
| 32 |
+
def test_fix_commas_correct_string_unchanged(client, test_input: str):
|
| 33 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|
| 34 |
|
| 35 |
assert response.status_code == 200
|
|
|
|
| 40 |
"test_input, expected",
|
| 41 |
[['I am, here.', 'I am here.'],
|
| 42 |
['books pens and pencils',
|
| 43 |
+
'books, pens and pencils']]
|
| 44 |
)
|
| 45 |
def test_fix_commas_fixes_wrong_commas(client, test_input: str, expected: str):
|
| 46 |
response = client.post('/baseline/fix-commas/', json={'s': test_input})
|