Spaces:
Runtime error
Runtime error
Benjamin Bossan
committed on
Commit
·
126a4c6
1
Parent(s):
a240da9
Refactor ml model handling
Browse files- src/db.py +4 -1
- src/ml.py +116 -67
- src/webservice.py +6 -0
- src/worker.py +79 -38
src/db.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import logging
|
|
|
|
| 2 |
import sqlite3
|
| 3 |
from contextlib import contextmanager
|
| 4 |
from typing import Generator
|
|
@@ -6,6 +7,8 @@ from typing import Generator
|
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
logger.setLevel(logging.DEBUG)
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
schema_entries = """
|
| 11 |
CREATE TABLE entries
|
|
@@ -67,7 +70,7 @@ def _get_db_connection() -> sqlite3.Connection:
|
|
| 67 |
global TABLES_CREATED
|
| 68 |
|
| 69 |
# sqlite cannot deal with concurrent access, so we set a big timeout
|
| 70 |
-
conn = sqlite3.connect(
|
| 71 |
if TABLES_CREATED:
|
| 72 |
return conn
|
| 73 |
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import os
|
| 3 |
import sqlite3
|
| 4 |
from contextlib import contextmanager
|
| 5 |
from typing import Generator
|
|
|
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
logger.setLevel(logging.DEBUG)
|
| 9 |
|
| 10 |
+
db_file = os.getenv("DB_FILE_NAME", "sqlite-data.db")
|
| 11 |
+
|
| 12 |
|
| 13 |
schema_entries = """
|
| 14 |
CREATE TABLE entries
|
|
|
|
| 70 |
global TABLES_CREATED
|
| 71 |
|
| 72 |
# sqlite cannot deal with concurrent access, so we set a big timeout
|
| 73 |
+
conn = sqlite3.connect(db_file, timeout=30)
|
| 74 |
if TABLES_CREATED:
|
| 75 |
return conn
|
| 76 |
|
src/ml.py
CHANGED
|
@@ -1,52 +1,126 @@
|
|
| 1 |
import abc
|
|
|
|
| 2 |
import logging
|
| 3 |
import re
|
| 4 |
|
| 5 |
import httpx
|
| 6 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
|
| 7 |
|
| 8 |
from base import JobInput
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
logger.setLevel(logging.DEBUG)
|
| 12 |
|
| 13 |
-
MODEL_NAME = "google/flan-t5-large"
|
| 14 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
| 15 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def __init__(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
self.template = "Summarize the text below in two sentences:\n\n{}"
|
| 21 |
-
self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
|
| 22 |
-
self.generation_config.max_new_tokens = 200
|
| 23 |
-
self.generation_config.min_new_tokens = 100
|
| 24 |
-
self.generation_config.top_k = 5
|
| 25 |
-
self.generation_config.repetition_penalty = 1.5
|
| 26 |
|
| 27 |
def __call__(self, x: str) -> str:
|
| 28 |
text = self.template.format(x)
|
| 29 |
-
inputs = tokenizer(text, return_tensors="pt")
|
| 30 |
-
outputs = model.generate(**inputs, generation_config=self.generation_config)
|
| 31 |
-
output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 32 |
assert isinstance(output, str)
|
| 33 |
return output
|
| 34 |
|
| 35 |
def get_name(self) -> str:
|
| 36 |
-
return f"
|
| 37 |
|
| 38 |
|
| 39 |
-
class Tagger:
|
| 40 |
-
def __init__(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
self.template = (
|
| 42 |
"Create a list of tags for the text below. The tags should be high level "
|
| 43 |
"and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
|
| 44 |
)
|
| 45 |
-
self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
|
| 46 |
-
self.generation_config.max_new_tokens = 50
|
| 47 |
-
self.generation_config.min_new_tokens = 25
|
| 48 |
-
# increase the temperature to make the model more creative
|
| 49 |
-
self.generation_config.temperature = 1.5
|
| 50 |
|
| 51 |
def _extract_tags(self, text: str) -> list[str]:
|
| 52 |
tags = set()
|
|
@@ -57,46 +131,25 @@ class Tagger:
|
|
| 57 |
|
| 58 |
def __call__(self, x: str) -> list[str]:
|
| 59 |
text = self.template.format(x)
|
| 60 |
-
inputs = tokenizer(text, return_tensors="pt")
|
| 61 |
-
outputs = model.generate(**inputs, generation_config=self.generation_config)
|
| 62 |
-
output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 63 |
tags = self._extract_tags(output)
|
| 64 |
return tags
|
| 65 |
|
| 66 |
def get_name(self) -> str:
|
| 67 |
-
return f"
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
class Processor(abc.ABC):
|
| 71 |
-
def __call__(self, job: JobInput) -> str:
|
| 72 |
-
_id = job.id
|
| 73 |
-
logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
|
| 74 |
-
result = self.process(job)
|
| 75 |
-
logger.info(f"Finished processing input (id={_id[:8]})")
|
| 76 |
-
return result
|
| 77 |
-
|
| 78 |
-
def process(self, input: JobInput) -> str:
|
| 79 |
-
raise NotImplementedError
|
| 80 |
-
|
| 81 |
-
def match(self, input: JobInput) -> bool:
|
| 82 |
-
raise NotImplementedError
|
| 83 |
-
|
| 84 |
-
def get_name(self) -> str:
|
| 85 |
-
raise NotImplementedError
|
| 86 |
|
| 87 |
|
| 88 |
-
class
|
| 89 |
def match(self, input: JobInput) -> bool:
|
| 90 |
return True
|
| 91 |
|
| 92 |
def process(self, input: JobInput) -> str:
|
| 93 |
return input.content
|
| 94 |
|
| 95 |
-
def get_name(self) -> str:
|
| 96 |
-
return self.__class__.__name__
|
| 97 |
-
|
| 98 |
|
| 99 |
-
class
|
| 100 |
def __init__(self) -> None:
|
| 101 |
self.client = httpx.Client()
|
| 102 |
self.regex = re.compile(r"(https?://[^\s]+)")
|
|
@@ -118,26 +171,22 @@ class PlainUrlProcessor(Processor):
|
|
| 118 |
text = self.template.format(url=self.url, content=text)
|
| 119 |
return text
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
| 123 |
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
self.registry: list[Processor] = []
|
| 128 |
-
self.default_registry: list[Processor] = []
|
| 129 |
-
self.set_default_processors()
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
self.registry.append(processor)
|
| 136 |
-
|
| 137 |
-
def dispatch(self, input: JobInput) -> Processor:
|
| 138 |
-
for processor in self.registry + self.default_registry:
|
| 139 |
-
if processor.match(input):
|
| 140 |
-
return processor
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
|
|
|
| 1 |
import abc
|
| 2 |
+
from typing import Any
|
| 3 |
import logging
|
| 4 |
import re
|
| 5 |
|
| 6 |
import httpx
|
|
|
|
| 7 |
|
| 8 |
from base import JobInput
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
logger.setLevel(logging.DEBUG)
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
class Processor(abc.ABC):
    """Abstract base class for job processors.

    A processor decides via :meth:`match` whether it applies to a job and
    turns the job into a string via :meth:`process`. Calling the instance
    runs :meth:`process` with logging around it.
    """

    def get_name(self) -> str:
        """Return the name of the concrete class."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Process ``job`` and return the result, logging start and finish.

        Only the first 8 characters of ``job.id`` are written to the log.
        """
        _id = job.id
        # The original message interpolated the ``input`` builtin here, which
        # logged "<built-in function input>" instead of anything job-related.
        logger.info(f"Processing input with {self.__class__.__name__} (id={_id[:8]})")
        result = self.process(job)
        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Turn ``input`` into a string; must be implemented by subclasses."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return whether this processor applies to ``input``."""
        raise NotImplementedError
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Summarizer(abc.ABC):
    """Interface for callables that map a text to a summary string."""

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        # Only documents the expected constructor signature: concrete classes
        # must provide their own __init__ and must not delegate to this one.
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a name identifying this summarizer.

        Defaults to the concrete class name so that a subclass which does not
        override it still yields a usable label instead of raising after the
        expensive work has already been done.
        """
        return self.__class__.__name__

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        """Return the summary of ``x``; must be implemented by subclasses."""
        raise NotImplementedError
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Tagger(abc.ABC):
    """Interface for callables that map a text to a list of tag strings."""

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        # Only documents the expected constructor signature: concrete classes
        # must provide their own __init__ and must not delegate to this one.
        raise NotImplementedError

    def get_name(self) -> str:
        """Return a name identifying this tagger.

        Defaults to the concrete class name so that a subclass which does not
        override it still yields a usable label instead of raising.
        """
        return self.__class__.__name__

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        """Return the tags for ``x``; must be implemented by subclasses."""
        raise NotImplementedError
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MlRegistry:
    """Container for the processors, summarizer and tagger used on jobs.

    Processors are kept in registration order; :meth:`get_processor` returns
    the first one whose ``match`` accepts the job. At most one summarizer and
    one tagger are held; registering again replaces the previous one.
    """

    def __init__(self) -> None:
        self.processors: list[Processor] = []
        # NOTE: the attribute name is misspelled ("summerizer"); it is kept
        # unchanged so that any existing reference to it keeps working.
        self.summerizer: Summarizer | None = None
        self.tagger: Tagger | None = None
        self.model = None
        self.tokenizer = None

    def register_processor(self, processor: Processor) -> None:
        """Append ``processor``; earlier registrations take precedence."""
        self.processors.append(processor)

    def register_summarizer(self, summarizer: Summarizer) -> None:
        """Set (or replace) the summarizer."""
        self.summerizer = summarizer

    def register_tagger(self, tagger: Tagger) -> None:
        """Set (or replace) the tagger."""
        self.tagger = tagger

    def get_processor(self, input: JobInput) -> Processor:
        """Return the first registered processor that matches ``input``.

        Raises:
            RuntimeError: if no processor has been registered at all.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``-O``.
        if not self.processors:
            raise RuntimeError("no processor has been registered")
        for processor in self.processors:
            if processor.match(input):
                return processor

        # Nothing matched: fall back to passing the raw content through.
        return RawTextProcessor()

    def get_summarizer(self) -> Summarizer:
        """Return the registered summarizer.

        Raises:
            RuntimeError: if no summarizer has been registered.
        """
        if self.summerizer is None:
            raise RuntimeError("no summarizer has been registered")
        return self.summerizer

    def get_tagger(self) -> Tagger:
        """Return the registered tagger.

        Raises:
            RuntimeError: if no tagger has been registered.
        """
        if self.tagger is None:
            raise RuntimeError("no tagger has been registered")
        return self.tagger
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class HfTransformersSummarizer(Summarizer):
    """Summarizer that prompts an injected model/tokenizer pair.

    The model must offer ``generate`` and the tokenizer must be callable and
    offer ``batch_decode``; both are supplied by the caller rather than
    loaded here.
    """

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config

        # Prompt template; the text to summarize is substituted for ``{}``.
        self.template = "Summarize the text below in two sentences:\n\n{}"

    def __call__(self, x: str) -> str:
        """Return the model's summary of ``x``."""
        text = self.template.format(x)
        inputs = self.tokenizer(text, return_tensors="pt")
        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
        # A single prompt was encoded, so only the first decoded entry is used.
        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        assert isinstance(output, str)
        return output

    def get_name(self) -> str:
        """Return ``"<class name>(<model name>)"``."""
        return f"{self.__class__.__name__}({self.model_name})"
|
| 111 |
|
| 112 |
|
| 113 |
+
class HfTransformersTagger(Tagger):
|
| 114 |
+
def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
|
| 115 |
+
self.model_name = model_name
|
| 116 |
+
self.model = model
|
| 117 |
+
self.tokenizer = tokenizer
|
| 118 |
+
self.generation_config = generation_config
|
| 119 |
+
|
| 120 |
self.template = (
|
| 121 |
"Create a list of tags for the text below. The tags should be high level "
|
| 122 |
"and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
|
| 123 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
def _extract_tags(self, text: str) -> list[str]:
|
| 126 |
tags = set()
|
|
|
|
| 131 |
|
| 132 |
def __call__(self, x: str) -> list[str]:
|
| 133 |
text = self.template.format(x)
|
| 134 |
+
inputs = self.tokenizer(text, return_tensors="pt")
|
| 135 |
+
outputs = self.model.generate(**inputs, generation_config=self.generation_config)
|
| 136 |
+
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 137 |
tags = self._extract_tags(output)
|
| 138 |
return tags
|
| 139 |
|
| 140 |
def get_name(self) -> str:
|
| 141 |
+
return f"{self.__class__.__name__}({self.model_name})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
|
| 144 |
+
class RawTextProcessor(Processor):
    """Processor that accepts every job and returns its content unchanged."""

    def match(self, input: JobInput) -> bool:
        """Always match, regardless of the job."""
        return True

    def process(self, input: JobInput) -> str:
        """Return ``input.content`` as is."""
        return input.content
|
| 150 |
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
+
class DefaultUrlProcessor(Processor):
|
| 153 |
def __init__(self) -> None:
|
| 154 |
self.client = httpx.Client()
|
| 155 |
self.regex = re.compile(r"(https?://[^\s]+)")
|
|
|
|
| 171 |
text = self.template.format(url=self.url, content=text)
|
| 172 |
return text
|
| 173 |
|
| 174 |
+
# class ProcessorRegistry:
|
| 175 |
+
# def __init__(self) -> None:
|
| 176 |
+
# self.registry: list[Processor] = []
|
| 177 |
+
# self.default_registry: list[Processor] = []
|
| 178 |
+
# self.set_default_processors()
|
| 179 |
|
| 180 |
+
# def set_default_processors(self) -> None:
|
| 181 |
+
# self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])
|
| 182 |
|
| 183 |
+
# def register(self, processor: Processor) -> None:
|
| 184 |
+
# self.registry.append(processor)
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
+
# def dispatch(self, input: JobInput) -> Processor:
|
| 187 |
+
# for processor in self.registry + self.default_registry:
|
| 188 |
+
# if processor.match(input):
|
| 189 |
+
# return processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
+
# # should never be required, but eh
|
| 192 |
+
# return RawProcessor()
|
src/webservice.py
CHANGED
|
@@ -14,6 +14,12 @@ logger.setLevel(logging.DEBUG)
|
|
| 14 |
app = FastAPI()
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
@app.post("/submit/")
|
| 18 |
def submit_job(input: RequestInput) -> str:
|
| 19 |
# submit a new job, poor man's job queue
|
|
|
|
| 14 |
app = FastAPI()
|
| 15 |
|
| 16 |
|
| 17 |
+
# status endpoint
@app.get("/status/")
def status() -> str:
    """Answer GET /status/ with the constant string "OK"."""
    return "OK"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
@app.post("/submit/")
|
| 24 |
def submit_job(input: RequestInput) -> str:
|
| 25 |
# submit a new job, poor man's job queue
|
src/worker.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
| 1 |
import time
|
|
|
|
| 2 |
|
| 3 |
from base import JobInput
|
| 4 |
from db import get_db_cursor
|
| 5 |
-
from ml import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
SLEEP_INTERVAL = 5
|
| 8 |
|
| 9 |
|
| 10 |
-
processor_registry = ProcessorRegistry()
|
| 11 |
-
summarizer = Summarizer()
|
| 12 |
-
tagger = Tagger()
|
| 13 |
-
print("loaded ML models")
|
| 14 |
-
|
| 15 |
-
|
| 16 |
def check_pending_jobs() -> list[JobInput]:
|
| 17 |
"""Check DB for pending jobs"""
|
| 18 |
with get_db_cursor() as cursor:
|
|
@@ -30,15 +31,38 @@ def check_pending_jobs() -> list[JobInput]:
|
|
| 30 |
]
|
| 31 |
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
with get_db_cursor() as cursor:
|
| 43 |
# write to entries, summary, tags tables
|
| 44 |
cursor.execute(
|
|
@@ -46,39 +70,23 @@ def store(
|
|
| 46 |
"INSERT INTO summaries (entry_id, summary, summarizer_name)"
|
| 47 |
" VALUES (?, ?, ?)"
|
| 48 |
),
|
| 49 |
-
(job.id, summary, summarizer_name),
|
| 50 |
)
|
| 51 |
cursor.executemany(
|
| 52 |
"INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
|
| 53 |
-
[(job.id, tag, tagger_name) for tag in tags],
|
| 54 |
)
|
| 55 |
|
| 56 |
|
| 57 |
-
def process_job(job: JobInput) -> None:
|
| 58 |
tic = time.perf_counter()
|
| 59 |
print(f"Processing job for (id={job.id[:8]})")
|
| 60 |
|
| 61 |
# care: acquire cursor (which leads to locking) as late as possible, since
|
| 62 |
# the processing takes a while and we don't want to block other workers during that time
|
| 63 |
try:
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
processed = processor(job)
|
| 67 |
-
|
| 68 |
-
tagger_name = tagger.get_name()
|
| 69 |
-
tags = tagger(processed)
|
| 70 |
-
|
| 71 |
-
summarizer_name = summarizer.get_name()
|
| 72 |
-
summary = summarizer(processed)
|
| 73 |
-
|
| 74 |
-
store(
|
| 75 |
-
job,
|
| 76 |
-
summary=summary,
|
| 77 |
-
tags=tags,
|
| 78 |
-
processor_name=processor_name,
|
| 79 |
-
summarizer_name=summarizer_name,
|
| 80 |
-
tagger_name=tagger_name,
|
| 81 |
-
)
|
| 82 |
# update job status to done
|
| 83 |
with get_db_cursor() as cursor:
|
| 84 |
cursor.execute(
|
|
@@ -96,7 +104,40 @@ def process_job(job: JobInput) -> None:
|
|
| 96 |
print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
|
| 97 |
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def main() -> None:
|
|
|
|
|
|
|
|
|
|
| 100 |
while True:
|
| 101 |
jobs = check_pending_jobs()
|
| 102 |
if not jobs:
|
|
@@ -106,7 +147,7 @@ def main() -> None:
|
|
| 106 |
|
| 107 |
print(f"Found {len(jobs)} pending job(s), processing...")
|
| 108 |
for job in jobs:
|
| 109 |
-
process_job(job)
|
| 110 |
|
| 111 |
|
| 112 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import time
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
|
| 4 |
from base import JobInput
|
| 5 |
from db import get_db_cursor
|
| 6 |
+
from ml import (
|
| 7 |
+
DefaultUrlProcessor,
|
| 8 |
+
HfTransformersSummarizer,
|
| 9 |
+
HfTransformersTagger,
|
| 10 |
+
MlRegistry,
|
| 11 |
+
RawTextProcessor,
|
| 12 |
+
)
|
| 13 |
|
| 14 |
SLEEP_INTERVAL = 5
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def check_pending_jobs() -> list[JobInput]:
|
| 18 |
"""Check DB for pending jobs"""
|
| 19 |
with get_db_cursor() as cursor:
|
|
|
|
| 31 |
]
|
| 32 |
|
| 33 |
|
| 34 |
+
@dataclass
class JobOutput:
    """Result of running one job through processor, tagger and summarizer."""

    # Summary text produced by the summarizer.
    summary: str
    # Tags produced by the tagger.
    tags: list[str]
    # Names of the components that produced the result.
    processor_name: str
    summarizer_name: str
    tagger_name: str
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
    """Run ``job`` through the components held by ``registry``.

    The processor chosen by the registry turns the job into text; that text
    is then handed to the tagger and to the summarizer. Nothing is written
    to the database here.

    Returns:
        A ``JobOutput`` with the summary, the tags and the names of the three
        components that were used.
    """
    processor = registry.get_processor(job)
    processor_name = processor.get_name()
    processed = processor(job)

    tagger = registry.get_tagger()
    tagger_name = tagger.get_name()
    tags = tagger(processed)

    summarizer = registry.get_summarizer()
    summarizer_name = summarizer.get_name()
    summary = summarizer(processed)

    return JobOutput(
        summary=summary,
        tags=tags,
        processor_name=processor_name,
        summarizer_name=summarizer_name,
        tagger_name=tagger_name,
    )
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def store(job: JobInput, output: JobOutput) -> None:
|
| 66 |
with get_db_cursor() as cursor:
|
| 67 |
# write to entries, summary, tags tables
|
| 68 |
cursor.execute(
|
|
|
|
| 70 |
"INSERT INTO summaries (entry_id, summary, summarizer_name)"
|
| 71 |
" VALUES (?, ?, ?)"
|
| 72 |
),
|
| 73 |
+
(job.id, output.summary, output.summarizer_name),
|
| 74 |
)
|
| 75 |
cursor.executemany(
|
| 76 |
"INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
|
| 77 |
+
[(job.id, tag, output.tagger_name) for tag in output.tags],
|
| 78 |
)
|
| 79 |
|
| 80 |
|
| 81 |
+
def process_job(job: JobInput, registry: MlRegistry) -> None:
|
| 82 |
tic = time.perf_counter()
|
| 83 |
print(f"Processing job for (id={job.id[:8]})")
|
| 84 |
|
| 85 |
# care: acquire cursor (which leads to locking) as late as possible, since
|
| 86 |
# the processing takes a while and we don't want to block other workers during that time
|
| 87 |
try:
|
| 88 |
+
output = _process_job(job, registry)
|
| 89 |
+
store(job, output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
# update job status to done
|
| 91 |
with get_db_cursor() as cursor:
|
| 92 |
cursor.execute(
|
|
|
|
| 104 |
print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
|
| 105 |
|
| 106 |
|
| 107 |
+
def load_mlregistry(model_name: str) -> MlRegistry:
    """Load ``model_name`` and return a registry populated with defaults.

    One model and one tokenizer are loaded and shared by the summarizer and
    the tagger; each gets its own generation config. The URL processor is
    registered before the raw-text processor, so it is consulted first.

    Args:
        model_name: identifier passed to the ``from_pretrained`` loaders.

    Returns:
        An ``MlRegistry`` with two processors, a summarizer and a tagger.
    """
    # Kept local to the function, so transformers is only imported when a
    # registry is actually loaded.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    config_summarizer = GenerationConfig.from_pretrained(model_name)
    config_summarizer.max_new_tokens = 200
    config_summarizer.min_new_tokens = 100
    # NOTE(review): top_k presumably only matters when sampling is enabled;
    # confirm that the loaded config sets do_sample, otherwise this is a no-op.
    config_summarizer.top_k = 5
    config_summarizer.repetition_penalty = 1.5

    config_tagger = GenerationConfig.from_pretrained(model_name)
    config_tagger.max_new_tokens = 50
    config_tagger.min_new_tokens = 25
    # increase the temperature to make the model more creative
    # NOTE(review): temperature presumably has no effect unless sampling is
    # enabled (do_sample) - confirm against the loaded generation config.
    config_tagger.temperature = 1.5

    summarizer = HfTransformersSummarizer(model_name, model, tokenizer, config_summarizer)
    tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)

    registry = MlRegistry()
    registry.register_processor(DefaultUrlProcessor())
    registry.register_processor(RawTextProcessor())
    registry.register_summarizer(summarizer)
    registry.register_tagger(tagger)

    return registry
|
| 135 |
+
|
| 136 |
+
|
| 137 |
def main() -> None:
|
| 138 |
+
model_name = "google/flan-t5-large"
|
| 139 |
+
registry = load_mlregistry(model_name)
|
| 140 |
+
|
| 141 |
while True:
|
| 142 |
jobs = check_pending_jobs()
|
| 143 |
if not jobs:
|
|
|
|
| 147 |
|
| 148 |
print(f"Found {len(jobs)} pending job(s), processing...")
|
| 149 |
for job in jobs:
|
| 150 |
+
process_job(job, registry)
|
| 151 |
|
| 152 |
|
| 153 |
if __name__ == "__main__":
|