Update app.py
Browse files
app.py
CHANGED
|
@@ -35,16 +35,16 @@ dataset = project.version(1).download("folder")
|
|
| 35 |
subprocess.run(['wget', '--no-check-certificate', 'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo', '-O', 'filetxt'])
|
| 36 |
subprocess.run(['unzip', 'filetxt'])
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
|
| 45 |
-
|
| 46 |
|
| 47 |
-
|
| 48 |
|
| 49 |
def download_and_unzip(url, save_path):
|
| 50 |
print(f"Downloading and extracting assets....", end="")
|
|
@@ -164,7 +164,7 @@ valid_dataset = CustomOCRDataset(
|
|
| 164 |
)
|
| 165 |
|
| 166 |
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
|
| 167 |
-
|
| 168 |
print(model)
|
| 169 |
# Total parameters and trainable parameters.
|
| 170 |
total_params = sum(p.numel() for p in model.parameters())
|
|
@@ -237,8 +237,7 @@ trainer = Seq2SeqTrainer(
|
|
| 237 |
res = trainer.train()
|
| 238 |
|
| 239 |
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
|
| 240 |
-
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step))
|
| 241 |
-
#.to(device)
|
| 242 |
|
| 243 |
def read_and_show(image_path):
|
| 244 |
"""
|
|
@@ -262,8 +261,7 @@ def ocr(image, processor, model):
|
|
| 262 |
generated_text: the OCR'd text string.
|
| 263 |
"""
|
| 264 |
# We can directly perform OCR on cropped images.
|
| 265 |
-
pixel_values = processor(image, return_tensors='pt').pixel_values
|
| 266 |
-
#.to(device)
|
| 267 |
generated_ids = model.generate(pixel_values)
|
| 268 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 269 |
return generated_text
|
|
|
|
| 35 |
subprocess.run(['wget', '--no-check-certificate', 'https://docs.google.com/uc?export=download&id=12reT7rxiRqTERYqeKYx7WGz5deMXjnEo', '-O', 'filetxt'])
|
| 36 |
subprocess.run(['unzip', 'filetxt'])
|
| 37 |
|
| 38 |
+
def seed_everything(seed_value):
|
| 39 |
+
np.random.seed(seed_value)
|
| 40 |
+
torch.manual_seed(seed_value)
|
| 41 |
+
torch.cuda.manual_seed_all(seed_value)
|
| 42 |
+
torch.backends.cudnn.deterministic = True
|
| 43 |
+
torch.backends.cudnn.benchmark = False
|
| 44 |
|
| 45 |
+
seed_everything(42)
|
| 46 |
|
| 47 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 48 |
|
| 49 |
def download_and_unzip(url, save_path):
|
| 50 |
print(f"Downloading and extracting assets....", end="")
|
|
|
|
| 164 |
)
|
| 165 |
|
| 166 |
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
|
| 167 |
+
model.to(device)
|
| 168 |
print(model)
|
| 169 |
# Total parameters and trainable parameters.
|
| 170 |
total_params = sum(p.numel() for p in model.parameters())
|
|
|
|
| 237 |
res = trainer.train()
|
| 238 |
|
| 239 |
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
|
| 240 |
+
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step)).to(device)
|
|
|
|
| 241 |
|
| 242 |
def read_and_show(image_path):
|
| 243 |
"""
|
|
|
|
| 261 |
generated_text: the OCR'd text string.
|
| 262 |
"""
|
| 263 |
# We can directly perform OCR on cropped images.
|
| 264 |
+
pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
|
|
|
|
| 265 |
generated_ids = model.generate(pixel_values)
|
| 266 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 267 |
return generated_text
|