arshadrana committed on
Commit
211d8d1
·
verified ·
1 Parent(s): c5ad565

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -1,35 +1,44 @@
1
  import torch
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from datasets import load_dataset
4
 
5
- # Load the dataset
6
- dataset = load_dataset('your_dataset_name')
7
 
8
- # Initialize the model and processor
9
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
10
  model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
11
 
12
- # Prepare the dataset for training
13
  def preprocess_data(example):
 
14
  pixel_values = processor(images=example['image'], return_tensors="pt").pixel_values
15
  labels = processor(text=example['text'], return_tensors="pt").input_ids
16
  return {'pixel_values': pixel_values, 'labels': labels}
17
 
 
18
  train_dataset = dataset['train'].map(preprocess_data)
19
 
20
- # Fine-tune the model
21
- training_args = {
22
- 'per_device_train_batch_size': 8,
23
- 'num_train_epochs': 3,
24
- 'logging_steps': 100,
25
- 'save_steps': 500,
26
- 'evaluation_strategy': 'steps',
27
- }
 
28
 
 
29
  trainer = Trainer(
30
  model=model,
31
  args=training_args,
32
  train_dataset=train_dataset,
33
  )
34
 
 
35
  trainer.train()
 
 
 
 
 
"""Fine-tune the TrOCR handwritten checkpoint on an image/text dataset.

Running this script loads the dataset, preprocesses its ``train`` split,
trains with the Hugging Face ``Trainer`` and finally saves the model and the
processor. Each dataset row is read through its ``image`` and ``text``
columns.
"""

import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments
from datasets import load_dataset

# Every label sequence is padded/truncated to this many tokens so that the
# default collator can stack examples into a rectangular batch.
MAX_TARGET_LENGTH = 128

# Load your dataset
dataset = load_dataset('your_dataset_name')  # Replace with your dataset name

# Initialize the processor and model
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# When `labels` are passed, the model builds the decoder inputs by shifting
# them right, which requires both ids to be set. Only fill them in when the
# checkpoint config leaves them empty so existing values are never overridden.
if model.config.decoder_start_token_id is None:
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
if model.config.pad_token_id is None:
    model.config.pad_token_id = processor.tokenizer.pad_token_id


# Preprocess the data
def preprocess_data(example):
    """Convert one dataset row into model inputs.

    Args:
        example: A single (non-batched) row with an ``image`` entry and a
            ``text`` entry.

    Returns:
        A dict with ``pixel_values`` (the processed image, without a batch
        dimension) and ``labels`` (token ids padded to ``MAX_TARGET_LENGTH``,
        with padding positions set to -100).
    """
    # The processor always returns a batch; take element 0 so that the
    # collator adds the only batch dimension later on.
    pixel_values = processor(images=example['image'], return_tensors="pt").pixel_values[0]

    token_ids = processor.tokenizer(
        example['text'],
        padding="max_length",
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    ).input_ids
    # -100 is the ignore index of the cross-entropy loss, so padding does not
    # contribute to the training signal.
    pad_id = processor.tokenizer.pad_token_id
    labels = [tok if tok != pad_id else -100 for tok in token_ids]

    return {'pixel_values': pixel_values, 'labels': labels}


# Map preprocessing to the train dataset; the raw columns are dropped because
# only the processed features are fed to the model.
train_dataset = dataset['train'].map(
    preprocess_data,
    remove_columns=dataset['train'].column_names,
)
train_dataset.set_format(type="torch")

# Training arguments
# NOTE: step-wise evaluation is intentionally not enabled here. The Trainer
# below receives no eval_dataset, and requesting evaluation without one raises
# a ValueError. To evaluate, pass an eval_dataset to the Trainer and set the
# evaluation strategy argument of your transformers version.
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=8,
    num_train_epochs=3,
    logging_steps=100,
    save_steps=500,
)

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Train the model
trainer.train()

# Save the model and processor after training
model.save_pretrained('./your_model_name')
processor.save_pretrained('./your_model_name')