Spaces:
Runtime error
Runtime error
| import time | |
| from dataclasses import dataclass | |
| from base import JobInput | |
| from db import get_db_cursor | |
| from ml import ( | |
| DefaultUrlProcessor, | |
| HfTransformersSummarizer, | |
| HfTransformersTagger, | |
| MlRegistry, | |
| RawTextProcessor, | |
| ) | |
# Seconds the worker loop sleeps before polling again when no pending jobs are found.
SLEEP_INTERVAL = 5
def check_pending_jobs() -> list[JobInput]:
    """Return one JobInput per job whose status is 'pending'.

    Each pending row of `jobs` is joined with its row in `entries` (on
    `jobs.entry_id = entries.id`) to pick up the entry's author and source;
    the source column is passed on as the JobInput's `content`.
    """
    # Fetch pending jobs, joining author and source from the entries table.
    query = """
        SELECT j.entry_id, e.author, e.source
        FROM jobs j
        JOIN entries e
        ON j.entry_id = e.id
        WHERE j.status = 'pending'
    """
    with get_db_cursor() as cursor:
        # Materialise the rows while the cursor is still open.
        rows = list(cursor.execute(query))

    return [
        JobInput(id=entry_id, author=author, content=source)
        for entry_id, author, source in rows
    ]
@dataclass
class JobOutput:
    """Result of running the processing pipeline on a single job.

    Instances are built with keyword arguments by `_process_job`, which
    requires the `@dataclass`-generated `__init__`; without the decorator the
    class only carries annotations and cannot be constructed with fields.
    """

    # Value returned by the summarizer for the processed job.
    summary: str
    # Values returned by the tagger for the processed job.
    tags: list[str]
    # Names (as reported by each component's get_name()) of the components
    # that produced this output.
    processor_name: str
    summarizer_name: str
    tagger_name: str
def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
    """Run the processor, tagger and summarizer for `job`.

    The processor is looked up from `registry` based on the job; its result
    is fed to both the registry's tagger and its summarizer. The name of each
    component is recorded alongside the results.
    """
    # Step 1: turn the raw job into the processed form the models consume.
    job_processor = registry.get_processor(job)
    processor_name = job_processor.get_name()
    document = job_processor(job)

    # Step 2: tag the processed job.
    job_tagger = registry.get_tagger()
    tagger_name = job_tagger.get_name()
    tags = job_tagger(document)

    # Step 3: summarize the processed job.
    job_summarizer = registry.get_summarizer()
    summarizer_name = job_summarizer.get_name()
    summary = job_summarizer(document)

    return JobOutput(
        summary=summary,
        tags=tags,
        processor_name=processor_name,
        summarizer_name=summarizer_name,
        tagger_name=tagger_name,
    )
def store(job: JobInput, output: JobOutput) -> None:
    """Persist the summary and tags produced for `job`.

    Inserts one row into `summaries` and one row per tag into `tags`, both
    keyed by the job's id and issued through a single cursor.
    """
    insert_summary = (
        "INSERT INTO summaries (entry_id, summary, summarizer_name)"
        " VALUES (?, ?, ?)"
    )
    insert_tag = "INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)"

    with get_db_cursor() as cursor:
        # Write to the summaries table first, then to the tags table.
        summary_row = (job.id, output.summary, output.summarizer_name)
        cursor.execute(insert_summary, summary_row)

        tag_rows = [(job.id, tag, output.tagger_name) for tag in output.tags]
        cursor.executemany(insert_tag, tag_rows)
def process_job(job: JobInput, registry: MlRegistry) -> None:
    """Process one job end to end and record the outcome in the `jobs` table.

    On success the results are stored and the job's status is set to 'done'.
    If anything in that sequence raises an `Exception`, the status is set to
    'failed' and the error is printed. The elapsed-time line is printed in
    both cases. An error raised while writing the 'failed' status propagates
    to the caller.
    """
    started = time.perf_counter()
    print(f"Processing job for (id={job.id[:8]})")

    # NOTE: acquire the cursor (which leads to locking) as late as possible.
    # The processing step runs without a cursor held, because we don't want
    # to block other workers during that time.
    try:
        result = _process_job(job, registry)
        store(job, result)
        # Mark the job as done.
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'done' WHERE entry_id = ?", (job.id,)
            )
    except Exception as e:
        # Mark the job as failed, then report why.
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'failed' WHERE entry_id = ?", (job.id,)
            )
        print(f"Failed to process job for (id={job.id[:8]}): {e}")

    finished = time.perf_counter()
    print(
        f"Finished processing job (id={job.id[:8]}) in {finished - started:0.3f} seconds"
    )
def load_mlregistry(model_name: str) -> MlRegistry:
    """Build an MlRegistry backed by the pretrained model `model_name`.

    One seq2seq model and one tokenizer are loaded and shared by both the
    summarizer and the tagger; each gets its own generation config. The
    registry is also given the URL and raw-text processors.
    """
    # transformers is imported inside the function, so it is only loaded when
    # a registry is actually built.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

    def _generation_config(**overrides):
        # Start from the model's own generation config, then apply overrides
        # in the order given.
        config = GenerationConfig.from_pretrained(model_name)
        for option, value in overrides.items():
            setattr(config, option, value)
        return config

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # NOTE(review): in transformers, top_k and temperature only influence
    # generation when sampling is enabled (do_sample=True) -- confirm that the
    # summarizer/tagger wrappers or the model's config turn it on.
    config_summarizer = _generation_config(
        max_new_tokens=200,
        min_new_tokens=100,
        top_k=5,
        repetition_penalty=1.5,
    )
    config_tagger = _generation_config(
        max_new_tokens=50,
        min_new_tokens=25,
        # A higher temperature is meant to make the tagger more creative.
        temperature=1.5,
    )

    summarizer = HfTransformersSummarizer(
        model_name, model, tokenizer, config_summarizer
    )
    tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)

    registry = MlRegistry()
    for processor in (DefaultUrlProcessor(), RawTextProcessor()):
        registry.register_processor(processor)
    registry.register_summarizer(summarizer)
    registry.register_tagger(tagger)
    return registry
def main() -> None:
    """Run the worker loop forever.

    Loads the registry once, then repeatedly polls for pending jobs: each
    batch found is processed job by job, and when none are found the loop
    sleeps for SLEEP_INTERVAL seconds before polling again.
    """
    model_name = "google/flan-t5-large"
    registry = load_mlregistry(model_name)

    while True:
        jobs = check_pending_jobs()
        if jobs:
            print(f"Found {len(jobs)} pending job(s), processing...")
            for job in jobs:
                process_job(job, registry)
        else:
            print("No pending jobs found, sleeping...")
            time.sleep(SLEEP_INTERVAL)
if __name__ == "__main__":
    # Script entry point: run the worker loop until interrupted, and treat
    # Ctrl-C as a clean shutdown with exit status 0.
    try:
        main()
    except KeyboardInterrupt:
        print("Shutting down...")
        # The bare `exit` helper is injected by the `site` module for
        # interactive use and may be missing (e.g. `python -S`); raising
        # SystemExit directly is always available.
        raise SystemExit(0) from None