Update pipeline.py
Browse files- pipeline.py +3 -3
pipeline.py
CHANGED
|
@@ -6,7 +6,7 @@ import html.parser
|
|
| 6 |
import unicodedata
|
| 7 |
import sys, os, re
|
| 8 |
|
| 9 |
-
class
|
| 10 |
|
| 11 |
def __init__(self, beam_size=5, batch_size=32, **kwargs):
|
| 12 |
self.beam_size = beam_size
|
|
@@ -153,7 +153,7 @@ class ReaccentPipeline(Pipeline):
|
|
| 153 |
def normalise_text(list_sents, batch_size=32, beam_size=5):
|
| 154 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 155 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 156 |
-
normalisation_pipeline =
|
| 157 |
tokenizer=tokeniser,
|
| 158 |
batch_size=batch_size,
|
| 159 |
beam_size=beam_size)
|
|
@@ -163,7 +163,7 @@ def normalise_text(list_sents, batch_size=32, beam_size=5):
|
|
| 163 |
def normalise_from_stdin(batch_size=32, beam_size=5):
|
| 164 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 165 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 166 |
-
normalisation_pipeline =
|
| 167 |
tokenizer=tokeniser,
|
| 168 |
batch_size=batch_size,
|
| 169 |
beam_size=beam_size)
|
|
|
|
| 6 |
import unicodedata
|
| 7 |
import sys, os, re
|
| 8 |
|
| 9 |
+
class NormalisationPipeline(Pipeline):
|
| 10 |
|
| 11 |
def __init__(self, beam_size=5, batch_size=32, **kwargs):
|
| 12 |
self.beam_size = beam_size
|
|
|
|
| 153 |
def normalise_text(list_sents, batch_size=32, beam_size=5):
|
| 154 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 155 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 156 |
+
normalisation_pipeline = NormalisationPipeline(model=model,
|
| 157 |
tokenizer=tokeniser,
|
| 158 |
batch_size=batch_size,
|
| 159 |
beam_size=beam_size)
|
|
|
|
| 163 |
def normalise_from_stdin(batch_size=32, beam_size=5):
|
| 164 |
tokeniser = AutoTokenizer.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 165 |
model = AutoModelForSeq2SeqLM.from_pretrained("rbawden/modern_french_normalisation", use_auth_token=True)
|
| 166 |
+
normalisation_pipeline = NormalisationPipeline(model=model,
|
| 167 |
tokenizer=tokeniser,
|
| 168 |
batch_size=batch_size,
|
| 169 |
beam_size=beam_size)
|