kirchik47 committed on
Commit
dcc8a5e
·
1 Parent(s): 2c73b5a

Model code modifications

Browse files
custom_got/modeling_GOT.py CHANGED
@@ -558,7 +558,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
558
 
559
  image_tensor_1 = image_processor_high(image)
560
 
561
- input_ids = torch.as_tensor(inputs.input_ids).cuda()
562
 
563
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
564
  keywords = [stop_str]
@@ -566,7 +566,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
566
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
568
  if stream_flag:
569
- with torch.autocast("cuda", dtype=torch.bfloat16):
570
  output_ids = self.generate(
571
  input_ids,
572
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
@@ -578,7 +578,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
578
  stopping_criteria=[stopping_criteria]
579
  )
580
  else:
581
- with torch.autocast("cuda", dtype=torch.bfloat16):
582
  output_ids = self.generate(
583
  input_ids,
584
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
 
558
 
559
  image_tensor_1 = image_processor_high(image)
560
 
561
+ input_ids = torch.as_tensor(inputs.input_ids)
562
 
563
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
564
  keywords = [stop_str]
 
566
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
568
  if stream_flag:
569
+ with torch.autocast("cpu", dtype=torch.bfloat16):
570
  output_ids = self.generate(
571
  input_ids,
572
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
 
578
  stopping_criteria=[stopping_criteria]
579
  )
580
  else:
581
+ with torch.autocast("cpu", dtype=torch.bfloat16):
582
  output_ids = self.generate(
583
  input_ids,
584
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
dataset_creation.py CHANGED
@@ -3,11 +3,11 @@ import json
3
  import os
4
 
5
 
6
- dataset = pd.read_csv('ocr_task/data_80k/data.csv')
7
  labels = dataset['image_file']
8
  text = dataset['text']
9
  json_data = []
10
- images_path = 'drive/MyDrive/data_80k/output_images/'
11
  for i in range(len(labels)):
12
  json_data.append(
13
  {
 
3
  import os
4
 
5
 
6
+ dataset = pd.read_csv('data_80k/data.csv')
7
  labels = dataset['image_file']
8
  text = dataset['text']
9
  json_data = []
10
+ images_path = '/kaggle/input/hindi-english-images/data_80k/output_images/'
11
  for i in range(len(labels)):
12
  json_data.append(
13
  {