Commit e704210 · Vik Paruchuri committed · 1 Parent(s): ec69c20

Add benchmark results
README.md CHANGED

@@ -1,12 +1,12 @@
  # Marker

- Marker converts PDF, EPUB, and MOBI to markdown. It's 10x faster than nougat, more accurate on most documents, and has near-zero hallucination risk.
+ Marker converts PDF, EPUB, and MOBI to markdown. It's 10x faster than nougat, more accurate on most documents, and has low hallucination risk.

  - Support for a range of PDF documents (optimized for books and scientific papers)
  - Removes headers/footers/other artifacts
  - Converts most equations to latex
  - Formats code blocks and tables
- - Support for multiple languages (although most testing is done in English). See `settings.py` for a list of supported languages.
+ - Support for multiple languages (although most testing is done in English). See `settings.py` for a language list.
  - Works on GPU, CPU, or MPS

  ## How it works
@@ -16,23 +16,29 @@ Marker is a pipeline of deep learning models:
  - Extract text, OCR if necessary (heuristics, tesseract)
  - Detect page layout ([layout segmenter](https://huggingface.co/vikp/layout_segmenter), [column detector](https://huggingface.co/vikp/column_detector))
  - Clean and format each block (heuristics, [nougat](https://huggingface.co/facebook/nougat-base))
- - Combine blocks and postprocess complete text (heuristics, [pdf_postprocessor](https://huggingface.co/vikp/pdf_postprocessor))
+ - Combine blocks and postprocess complete text (heuristics, [pdf_postprocessor](https://huggingface.co/vikp/pdf_postprocessor_t5))

- Relying on autoregressive forward passes to generate text is slow and prone to hallucination/repetition. From the nougat paper `We observed [repetition] in 1.5% of pages in the test set, but the frequency increases for out-of-domain documents.` In my anecdotal testing, repetitions happen on 5%+ of out-of-domain (non-arXiv) pages. Nougat is an amazing model that is part of marker, it's just not a general-purpose converter.
+ Relying on autoregressive forward passes to generate text is slow and prone to hallucination/repetition. From the nougat paper: `We observed [repetition] in 1.5% of pages in the test set, but the frequency increases for out-of-domain documents.` In my anecdotal testing, repetitions happen on 5%+ of out-of-domain (non-arXiv) pages.

- Marker is 10x faster and more accurate by only passing equation blocks through an LLM forward pass.
+ Nougat is an amazing model, but I wanted a faster, more general-purpose solution. Marker is 10x faster and has low hallucination risk because it only passes equation blocks through an LLM forward pass.

  ## Examples

- | PDF | Type | Marker | Nougat |
- |-----|------|--------|--------|
- | [Think Python](https://greenteapress.com/thinkpython/thinkpython.pdf) | Textbook | [View file](https://github.com/VikParuchuri/marker/blob/master/examples/marker/thinkpython.md) | [View file](https://github.com/VikParuchuri/marker/blob/master/examples/nougat/thinkpython.md) |
- | [Think OS](https://greenteapress.com/thinkos/thinkos.pdf) | Textbook | [View file](https://github.com/VikParuchuri/marker/blob/master/examples/marker/thinkos.md) | [View file](https://github.com/VikParuchuri/marker/blob/master/examples/nougat/thinkos.md) |
- | [Switch Transformers](https://arxiv.org/pdf/2101.03961.pdf) | arXiv paper | [View file](https://github.com/VikParuchuri/marker/blob/master/examples/marker/switch_transformers.md) | [View](https://github.com/VikParuchuri/marker/blob/master/examples/nougat/switch_transformers.md) |
- | [Multi-column CNN](https://arxiv.org/pdf/1804.07821.pdf) | arXiv paper | [View file](https://github.com/VikParuchuri/marker/blob/master/examples/marker/multicolcnn.md) | [View file](https://github.com/VikParuchuri/marker/blob/master/examples/nougat/multicolcnn.md) |
+ | PDF | Type | Marker | Nougat |
+ |-----|------|--------|--------|
+ | [Think Python](https://greenteapress.com/thinkpython/thinkpython.pdf) | Textbook | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/marker/thinkpython.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/nougat/thinkpython.md) |
+ | [Think OS](https://greenteapress.com/thinkos/thinkos.pdf) | Textbook | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/marker/thinkos.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/nougat/thinkos.md) |
+ | [Switch Transformers](https://arxiv.org/pdf/2101.03961.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/marker/switch_transformers.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/nougat/switch_transformers.md) |
+ | [Multi-column CNN](https://arxiv.org/pdf/1804.07821.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/marker/multicolcnn.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/nougat/multicolcnn.md) |


- See [below](#benchmarks) for speed and accuracy benchmarks.
+ ## Performance
+
+ ![Benchmark overall](data/images/overall.png)
+
+ The above results are with marker and nougat set up so that each takes ~3GB of VRAM on an A6000.
+
+ See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instructions on how to run your own benchmarks.

  # Installation

@@ -71,11 +77,12 @@ First, clone the repo:
  # Usage

- **Configuration**
+ First, some configuration:

  - Set your torch device in the `local.env` file. For example, `TORCH_DEVICE=cuda` or `TORCH_DEVICE=mps`. `cpu` is the default.
  - If using GPU, set `INFERENCE_RAM` to your GPU VRAM (per GPU). For example, if you have 16 GB of VRAM, set `INFERENCE_RAM=16`.
  - Depending on your document types, marker's average memory usage per task can vary slightly. You can configure `VRAM_PER_TASK` to adjust this if you notice tasks failing with GPU out of memory errors.
+ - By default, the final editor model is off. Turn it on with `ENABLE_EDITOR_MODEL`.
  - Inspect the settings in `marker/settings.py`. You can override any settings in the `local.env` file, or by setting environment variables.

  ## Convert a single file
@@ -96,7 +103,7 @@ Make sure the `DEFAULT_LANG` setting is set appropriately for your document.
  Run `convert.py`, like this:

  ```
- python convert.py /path/to/input/folder /path/to/output/folder --workers 4 --max 10 --metadata_file /path/to/metadata.json
+ python convert.py /path/to/input/folder /path/to/output/folder --workers 10 --max 10 --metadata_file /path/to/metadata.json
  ```

  - `--workers` is the number of pdfs to convert at once. This is set to 1 by default, but you can increase it to increase throughput, at the cost of more CPU/GPU usage. Parallelism will not increase beyond `INFERENCE_RAM / VRAM_PER_TASK` if you're using GPU.
@@ -116,7 +123,7 @@ python convert.py /path/to/input/folder /path/to/output/folder --workers 4 --max
  Run `chunk_convert.sh`, like this:

  ```
- METADATA_FILE=../pdf_meta.json NUM_DEVICES=4 NUM_WORKERS=35 bash chunk_convert.sh ../pdf_in ../md_out
+ METADATA_FILE=../pdf_meta.json NUM_DEVICES=4 NUM_WORKERS=15 bash chunk_convert.sh ../pdf_in ../md_out
  ```

  - `METADATA_FILE` is an optional path to a json file with metadata about the pdfs. See above for the format.
@@ -127,28 +134,34 @@ METADATA_FILE=../pdf_meta.json NUM_DEVICES=4 NUM_WORKERS=35 bash chunk_convert.s

  Benchmarking PDF extraction quality is hard. I've created a test set by finding books and scientific papers that have a pdf version and a latex source. I convert the latex to text, and compare the reference to the output of text extraction methods.

- Benchmarks show that marker is 10x faster than nougat, and more accurate outside arXiv (nougat was trained on arXiv data).
+ Benchmarks show that marker is 10x faster than nougat, and more accurate outside arXiv (nougat was trained on arXiv data). We show naive text extraction (pulling text out of the pdf with no processing) for comparison.

  **Speed**

- Method    Average Score    Time per doc
- --------  ---------------  --------------
- naive     0.351585         0.328931
- marker    0.636839         78.1468
- nougat    0.614548         810.756
+ | Method | Average Score | Time per page | Time per document |
+ |--------|---------------|---------------|-------------------|
+ | naive  | 0.350727      | 0.00152378    | 0.326524          |
+ | marker | 0.641062      | 0.360622      | 77.2762           |
+ | nougat | 0.629211      | 3.77259       | 808.413           |

  **Accuracy**

- First 3 are non-arXiv books, last 3 are arXiv papers.
+ First 3 are arXiv papers, last 3 are non-arXiv books.

- Method    thinkos.pdf    thinkdsp.pdf    thinkpython.pdf    switch_trans.pdf    crowd.pdf    multicolcnn.pdf
- --------  -------------  --------------  -----------------  ------------------  -----------  -----------------
- naive     0.366817       0.412014        0.468147           0.244739            0.14489      0.0890217
- marker    0.753291       0.787938        0.779262           0.478387            0.446068     0.533737
- nougat    0.638434       0.632723        0.637626           0.690028            0.540994     0.699539
+ | Method | switch_trans.pdf | crowd.pdf | multicolcnn.pdf | thinkos.pdf | thinkdsp.pdf | thinkpython.pdf |
+ |--------|------------------|-----------|-----------------|-------------|--------------|-----------------|
+ | naive  | 0.244114         | 0.140669  | 0.0868221       | 0.366856    | 0.412521     | 0.468281        |
+ | marker | 0.482091         | 0.466882  | 0.537062        | 0.754347    | 0.78825      | 0.779536        |
+ | nougat | 0.696458         | 0.552337  | 0.735099        | 0.655002    | 0.645704     | 0.650282        |

  Peak GPU memory usage during the benchmark is `3.3GB` for nougat, and `3.1GB` for marker. Benchmarks were run on an A6000.

+ **Throughput**
+
+ Marker takes about 2GB of VRAM on average per task, so you can convert 24 documents in parallel on an A6000.
+
+ ![Benchmark results](data/images/per_doc.png)
+
  ## Running your own benchmarks

  You can benchmark the performance of marker on your machine. First, download the benchmark data [here](https://drive.google.com/file/d/1WiN4K2-jQfwyQMe4wSSurbpz3hxo2fG9/view?usp=drive_link) and unzip.
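A note on the `--workers` flag documented above: on GPU, parallelism is capped at `INFERENCE_RAM / VRAM_PER_TASK`. That cap can be sketched as a small helper; the function name `effective_workers` is illustrative and not part of marker's CLI:

```python
def effective_workers(requested: int, inference_ram_gb: float,
                      vram_per_task_gb: float, using_gpu: bool) -> int:
    """Cap the worker count by available VRAM when running on GPU."""
    if not using_gpu:
        return requested
    # Number of tasks that fit in VRAM, rounded down; always run at least one.
    vram_cap = int(inference_ram_gb // vram_per_task_gb)
    return max(1, min(requested, vram_cap))

# With 16 GB of VRAM and 2 GB per task, 10 requested workers are capped at 8.
print(effective_workers(10, 16, 2, using_gpu=True))
```

With the defaults above (`INFERENCE_RAM=40`, `VRAM_PER_TASK=2`), up to 20 workers fit on a single GPU.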
convert.py CHANGED

@@ -3,15 +3,11 @@ import os
  from typing import Dict

  import ray
- import torch
  from tqdm import tqdm
  import math

  from marker.convert import convert_single_pdf, get_length_of_text
  from marker.models import load_all_models
- from marker.ordering import load_ordering_model
- from marker.segmentation import load_layout_model
- from marker.cleaners.equations import load_nougat_model
  from marker.settings import settings
  from marker.logger import configure_logging
  import traceback
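The convert.py change drops the per-model imports (`load_ordering_model`, `load_layout_model`, `load_nougat_model`) in favor of the single `load_all_models` entry point. A minimal sketch of that consolidation pattern, with stub loaders standing in for the real model-loading code (none of this is marker's actual implementation):

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins for the individual loaders convert.py no longer imports.
def load_layout_model() -> str:
    return "layout"

def load_ordering_model() -> str:
    return "ordering"

def load_nougat_model() -> str:
    return "nougat"

# Registry of loaders; dicts preserve insertion order, so load order is fixed.
LOADERS: Dict[str, Callable[[], str]] = {
    "layout": load_layout_model,
    "ordering": load_ordering_model,
    "nougat": load_nougat_model,
}

def load_all_models() -> List[str]:
    """Load every model once so callers need a single import."""
    return [loader() for loader in LOADERS.values()]
```

Consolidating behind one entry point keeps the caller (here, the conversion script) decoupled from which individual models the pipeline uses.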
marker/cleaners/table.py CHANGED

@@ -80,7 +80,7 @@ def create_new_tables(blocks: List[Page]):
          if max([len("".join(r)) for r in table_rows]) > 300 or len(table_rows[0]) > 8:
              continue

-         new_text = tabulate(table_rows, headers="firstrow", tablefmt="simple")
+         new_text = tabulate(table_rows, headers="firstrow", tablefmt="github")
          new_span = Span(
              bbox=block.bbox,
              span_id=f"{table_idx}_fix_table",
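The change above swaps `tablefmt="simple"` for `tablefmt="github"`, so rebuilt tables come out as pipe-delimited markdown that renders on GitHub. A hand-rolled approximation of that layout (written without the `tabulate` dependency so it runs standalone; `github_table` is an illustrative name, and `tabulate`'s exact padding may differ):

```python
from typing import List

def github_table(rows: List[List[str]]) -> str:
    """Render rows (first row = header) as a GitHub-flavored markdown table."""
    widths = [max(len(row[i]) for row in rows) for i in range(len(rows[0]))]

    def line(row: List[str]) -> str:
        return "| " + " | ".join(cell.ljust(w) for cell, w in zip(row, widths)) + " |"

    sep = "|" + "|".join("-" * (w + 2) for w in widths) + "|"
    return "\n".join([line(rows[0]), sep] + [line(r) for r in rows[1:]])

print(github_table([["Method", "Score"], ["marker", "0.64"], ["nougat", "0.63"]]))
```

The "simple" format relies on whitespace alignment, which markdown renderers collapse; the pipe format survives rendering, which is why it suits markdown output.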
marker/ordering.py CHANGED

@@ -16,9 +16,11 @@ processor = LayoutLMv3Processor.from_pretrained(settings.ORDERER_MODEL_NAME)


  def load_ordering_model():
-     model = LayoutLMv3ForSequenceClassification.from_pretrained(settings.ORDERER_MODEL_NAME).to(settings.TORCH_DEVICE)
-     if settings.CUDA:
-         model = model.to(torch.bfloat16)
+     model = LayoutLMv3ForSequenceClassification.from_pretrained(
+         settings.ORDERER_MODEL_NAME,
+         torch_dtype=settings.MODEL_DTYPE,
+     ).to(settings.TORCH_DEVICE)
+     model.eval()
      return model

marker/postprocessors/editor.py CHANGED

@@ -1,7 +1,6 @@
  from collections import defaultdict, Counter
  from itertools import chain
  from typing import Optional
- import re

  from transformers import AutoTokenizer
  from marker.settings import settings
@@ -17,11 +16,10 @@ def load_editing_model():
          return None

      model = T5ForTokenClassification.from_pretrained(
-         settings.EDITOR_MODEL_NAME
+         settings.EDITOR_MODEL_NAME,
+         torch_dtype=settings.MODEL_DTYPE,
      ).to(settings.TORCH_DEVICE)
-
-     if settings.CUDA:
-         model = model.to(torch.bfloat16)
+     model.eval()

      model.config.label2id = {
          "equal": 0,
@@ -41,15 +39,6 @@ def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_s
      input_ids = tokenized["input_ids"]
      char_token_lengths = tokenized["char_token_lengths"]

-     # Tokenize, and make sure reverse tokenization works
-     model_tokens = [tokenizer.convert_ids_to_tokens(t, skip_special_tokens=True) for t in input_ids]
-     model_tokens_str = [tokenizer.convert_tokens_to_string(t) for t in model_tokens]
-     full_text = "".join(model_tokens_str)
-     assert full_text == text
-
-     # List of characters in the text
-     flat_input_ids = list(chain.from_iterable(input_ids))
-
      # Run model
      token_masks = []
      for i in range(0, len(input_ids), batch_size):
@@ -67,14 +56,17 @@ def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_s
          probs = F.softmax(logits, dim=-1)
          max_prob = torch.max(probs, dim=-1)
          cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
-         labels = logits.argmax(-1).squeeze()
+         labels = logits.argmax(-1)
          labels[cutoff_prob] = model.config.label2id["equal"]
-         labels = labels.tolist()
+         labels = labels.squeeze().tolist()
          if len(labels) == settings.EDITOR_MAX_LENGTH:
              labels = [labels]
          labels = list(chain.from_iterable(labels))
          token_masks.extend(labels)

+     # List of characters in the text
+     flat_input_ids = list(chain.from_iterable(input_ids))
+
      # Strip special tokens 0,1. Keep unknown token, although it should never be used
      assert len(token_masks) == len(flat_input_ids)
      token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]
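The hunk above reorders the label post-processing so that low-confidence predictions are reset to "equal" before the tensor is squeezed and flattened. The underlying thresholding rule (fall back to "equal" whenever the top probability is below `EDITOR_CUTOFF_THRESH`) can be sketched without torch; `threshold_labels` and the label ids here are illustrative stand-ins:

```python
from typing import List

EQUAL_LABEL = 0  # stands in for model.config.label2id["equal"]

def threshold_labels(probs: List[List[float]], cutoff: float) -> List[int]:
    """Argmax label per token, falling back to EQUAL_LABEL when the
    top probability is below the cutoff."""
    labels = []
    for p in probs:
        best = max(range(len(p)), key=p.__getitem__)
        labels.append(best if p[best] >= cutoff else EQUAL_LABEL)
    return labels

# Token 1: confident label 2 is kept. Token 2: best label (1, at 0.45)
# is below the cutoff, so it falls back to "equal".
print(threshold_labels([[0.05, 0.05, 0.9], [0.4, 0.45, 0.15]], cutoff=0.9))
```

Raising the cutoff (this commit moves it from 0.75 to 0.9) makes the editor more conservative: more predictions fall back to "equal", trading recall for fewer false-positive edits.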
marker/segmentation.py CHANGED

@@ -24,9 +24,10 @@ NO_CHUNK_KEYS = ["pixel_values"]


  def load_layout_model():
-     model = LayoutLMv3ForTokenClassification.from_pretrained(settings.LAYOUT_MODEL_NAME).to(settings.TORCH_DEVICE)
-     if settings.CUDA:
-         model = model.to(torch.bfloat16)
+     model = LayoutLMv3ForTokenClassification.from_pretrained(
+         settings.LAYOUT_MODEL_NAME,
+         torch_dtype=settings.MODEL_DTYPE,
+     ).to(settings.TORCH_DEVICE)

      model.config.id2label = {
          0: "Caption",
marker/settings.py CHANGED

@@ -5,13 +5,14 @@ from dotenv import find_dotenv
  from pydantic import computed_field
  from pydantic_settings import BaseSettings
  import fitz as pymupdf
+ import torch


  class Settings(BaseSettings):
      # General
      TORCH_DEVICE: str = "cpu"
      INFERENCE_RAM: int = 40 # How much VRAM each GPU has (in GB).
-     VRAM_PER_TASK: float = 2.5 # How much VRAM to allocate per task (in GB). Peak marker VRAM usage is around 3GB, but avg across workers is lower.
+     VRAM_PER_TASK: float = 2 # How much VRAM to allocate per task (in GB). Peak marker VRAM usage is around 3GB, but avg across workers is lower.
      DEBUG: bool = False # Enable debug logging
      DEFAULT_LANG: str = "English" # Default language we assume files to be in, should be one of the keys in TESSERACT_LANGUAGES

@@ -73,10 +74,10 @@ class Settings(BaseSettings):

      # Final editing model
      EDITOR_BATCH_SIZE: int = 4
-     EDITOR_MAX_LENGTH: int = 2048
+     EDITOR_MAX_LENGTH: int = 1024
      EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
-     ENABLE_EDITOR_MODEL: bool = True # The editor model can create false positives
-     EDITOR_CUTOFF_THRESH: float = 0.75 # Ignore predictions below this probability
+     ENABLE_EDITOR_MODEL: bool = False # The editor model can create false positives
+     EDITOR_CUTOFF_THRESH: float = 0.9 # Ignore predictions below this probability

      # Ray
      RAY_CACHE_PATH: Optional[str] = None # Where to save ray cache
@@ -88,6 +89,11 @@ class Settings(BaseSettings):
      def CUDA(self) -> bool:
          return "cuda" in self.TORCH_DEVICE

+     @computed_field
+     @property
+     def MODEL_DTYPE(self) -> torch.dtype:
+         return torch.bfloat16 if self.CUDA else torch.float32
+
      class Config:
          env_file = find_dotenv("local.env")
          extra = "ignore"
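The new `MODEL_DTYPE` computed field selects bfloat16 on CUDA devices and float32 everywhere else, which is why the per-model `if settings.CUDA: model.to(torch.bfloat16)` blocks could be deleted. The same selection logic, stripped of torch and pydantic so it runs standalone (`model_dtype` returning dtype names as strings is an illustrative stand-in):

```python
def model_dtype(torch_device: str) -> str:
    """Mirror of the CUDA check: bfloat16 on any cuda device, float32 elsewhere."""
    cuda = "cuda" in torch_device  # matches the CUDA computed field
    return "bfloat16" if cuda else "float32"

print(model_dtype("cuda"))
print(model_dtype("cuda:1"))
print(model_dtype("mps"))
```

Centralizing the dtype in one computed field means every loader passes `torch_dtype=settings.MODEL_DTYPE` instead of repeating the device check.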