arshadrana commited on
Commit
c5ad565
·
verified ·
1 Parent(s): 1af5b23

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ from datasets import load_dataset
4
+
5
+ # Load the dataset
6
+ dataset = load_dataset('your_dataset_name')
7
+
8
+ # Initialize the model and processor
9
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
10
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
11
+
12
+ # Prepare the dataset for training
13
+ def preprocess_data(example):
14
+ pixel_values = processor(images=example['image'], return_tensors="pt").pixel_values
15
+ labels = processor(text=example['text'], return_tensors="pt").input_ids
16
+ return {'pixel_values': pixel_values, 'labels': labels}
17
+
18
+ train_dataset = dataset['train'].map(preprocess_data)
19
+
20
+ # Fine-tune the model
21
+ training_args = {
22
+ 'per_device_train_batch_size': 8,
23
+ 'num_train_epochs': 3,
24
+ 'logging_steps': 100,
25
+ 'save_steps': 500,
26
+ 'evaluation_strategy': 'steps',
27
+ }
28
+
29
+ trainer = Trainer(
30
+ model=model,
31
+ args=training_args,
32
+ train_dataset=train_dataset,
33
+ )
34
+
35
+ trainer.train()