lamossta commited on
Commit
2d31cef
·
1 Parent(s): 9f3aa4a

pipeline classes

Browse files
src/pipelines/fasttext_pipeline.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import fasttext
4
+
5
+ from src.models.fasttext import predict_samples
6
+
7
+
8
+ class FastTextPipeline:
9
+ def __init__(self, model_path: Path, model_name: str):
10
+ self.model = fasttext.load_model(str(model_path))
11
+ self.model_name = model_name
12
+
13
+ def run(
14
+ self,
15
+ samples: list[dict],
16
+ max_len: int = 256,
17
+ batch_size: int = 32,
18
+ deduplicate: bool = False,
19
+ ) -> list[dict]:
20
+ return predict_samples(
21
+ self.model, samples, deduplicate=deduplicate,
22
+ )
src/pipelines/predict_all_pipeline.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import onnxruntime as ort
5
+
6
+ from src.models.inference import build_tokenizer, predict
7
+ from src.pipelines.fasttext_pipeline import FastTextPipeline
8
+
9
+ log = logging.getLogger(__name__)
10
+
11
+
12
+ class PredictAllPipeline:
13
+ def __init__(self):
14
+ self.models: dict[str, dict] = {}
15
+ self.fasttext_models: dict[str, FastTextPipeline] = {}
16
+
17
+ def add_model(
18
+ self,
19
+ name: str,
20
+ onnx_path: Path,
21
+ mode: str,
22
+ ) -> None:
23
+ log.info(f"Adding {name} to PredictAllPipeline")
24
+ self.models[name] = {
25
+ "session": ort.InferenceSession(str(onnx_path)),
26
+ "tokenizer": build_tokenizer(mode),
27
+ "mode": mode,
28
+ }
29
+
30
+ def add_fasttext(self, name: str, model_path: Path) -> None:
31
+ log.info(f"Adding {name} (fastText) to PredictAllPipeline")
32
+ self.fasttext_models[name] = FastTextPipeline(model_path, name)
33
+
34
+ def run(
35
+ self,
36
+ samples: list[dict],
37
+ max_len: int = 256,
38
+ batch_size: int = 32,
39
+ deduplicate: bool = False,
40
+ ) -> dict[str, list[dict]]:
41
+ results: dict[str, list[dict]] = {}
42
+ for name, m in self.models.items():
43
+ results[name] = predict(
44
+ samples=samples,
45
+ session=m["session"],
46
+ tokenizer=m["tokenizer"],
47
+ mode=m["mode"],
48
+ max_len=max_len,
49
+ batch_size=batch_size,
50
+ deduplicate=deduplicate,
51
+ )
52
+ for name, ft in self.fasttext_models.items():
53
+ results[name] = ft.run(
54
+ samples=samples,
55
+ deduplicate=deduplicate,
56
+ )
57
+ return results
src/pipelines/predict_pipeline.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import onnxruntime as ort
4
+
5
+ from src.models.inference import build_tokenizer, predict
6
+
7
+
8
+ class PredictPipeline:
9
+ def __init__(
10
+ self,
11
+ onnx_path: Path,
12
+ mode: str,
13
+ model_name: str,
14
+ ):
15
+ self.session = ort.InferenceSession(str(onnx_path))
16
+ self.tokenizer = build_tokenizer(mode)
17
+ self.mode = mode
18
+ self.model_name = model_name
19
+
20
+ def run(
21
+ self,
22
+ samples: list[dict],
23
+ max_len: int = 256,
24
+ batch_size: int = 32,
25
+ deduplicate: bool = False,
26
+ ) -> list[dict]:
27
+ return predict(
28
+ samples=samples,
29
+ session=self.session,
30
+ tokenizer=self.tokenizer,
31
+ mode=self.mode,
32
+ max_len=max_len,
33
+ batch_size=batch_size,
34
+ deduplicate=deduplicate,
35
+ )