bcvilnrotter committed on
Commit
b6f5cff
·
verified ·
1 Parent(s): 3c0f5c4

Update utils/basic_functions.py

Browse files
Files changed (1) hide show
  1. utils/basic_functions.py +12 -11
utils/basic_functions.py CHANGED
@@ -48,7 +48,9 @@ def load_model(model_name):
48
  ).to(device)
49
  else:
50
  model = AutoModelForVision2Seq.from_pretrained(model_name).to(device)
 
51
  processor = AutoProcessor.from_pretrained(model_name,use_fast=True)
 
52
  return processor,model
53
 
54
  def gemini_identify_id(url,system_prompt):
@@ -78,7 +80,7 @@ def huggingface_detect_id_box(model_name,url):
78
  try:
79
  #image = get_image(url)
80
  image = Image.open(requests.get(url,stream=True).raw)
81
- print(image)
82
 
83
  system_prompt = f"""
84
  You are an AI document processing assistant. Analyze the provided image. Identify the ID number in the document.
@@ -88,10 +90,9 @@ def huggingface_detect_id_box(model_name,url):
88
  as well but that looks to be rare. The output should only be a string in the format [x0,y0,x1,y1], and the
89
  values should fit into the image size which is {image.size}.
90
  """
91
- print(system_prompt)
92
 
93
  processor,model=load_model(model_name)
94
- print(processor,model)
95
 
96
  conversation = [
97
  {
@@ -102,13 +103,13 @@ def huggingface_detect_id_box(model_name,url):
102
  ],
103
  },
104
  ]
105
- print(conversation)
106
 
107
  prompt = processor.apply_chat_template(conversation,add_generation_prompt=True)
108
- print(prompt)
109
 
110
  inputs = processor(images=image,text=prompt,return_tensors="pt").to(model.device)
111
- print(inputs)
112
 
113
  """
114
  with torch.no_grad():
@@ -124,20 +125,20 @@ def huggingface_detect_id_box(model_name,url):
124
  """
125
 
126
  output = model.generate(**inputs,max_new_tokens=200,do_sample=False)
127
- print(output)
128
 
129
  bbox = processor.decode(output[0][2:],skip_special_tokens=True)
130
- print(bbox)
131
 
132
 
133
  draw = ImageDraw.Draw(image)
134
- print(draw)
135
 
136
  draw.rectangle(bbox,outline="red",width=5)
137
- print(image)
138
 
139
  #image.save(f'{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}\\download\\{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg')
140
- return image,bbox
141
  except Exception as e:
142
  print(f"Error loading model or processing image: {str(e)}")
143
  return None
 
48
  ).to(device)
49
  else:
50
  model = AutoModelForVision2Seq.from_pretrained(model_name).to(device)
51
+ print(f"model: {model}")
52
  processor = AutoProcessor.from_pretrained(model_name,use_fast=True)
53
+ print(f"processor: {processor}")
54
  return processor,model
55
 
56
  def gemini_identify_id(url,system_prompt):
 
80
  try:
81
  #image = get_image(url)
82
  image = Image.open(requests.get(url,stream=True).raw)
83
+ print(f"image: {image}")
84
 
85
  system_prompt = f"""
86
  You are an AI document processing assistant. Analyze the provided image. Identify the ID number in the document.
 
90
  as well but that looks to be rare. The output should only be a string in the format [x0,y0,x1,y1], and the
91
  values should fit into the image size which is {image.size}.
92
  """
93
+ print(f"system_prompt: {system_prompt}")
94
 
95
  processor,model=load_model(model_name)
 
96
 
97
  conversation = [
98
  {
 
103
  ],
104
  },
105
  ]
106
+ print(f"conversation: {conversation}")
107
 
108
  prompt = processor.apply_chat_template(conversation,add_generation_prompt=True)
109
+ print(f"prompt: {prompt}")
110
 
111
  inputs = processor(images=image,text=prompt,return_tensors="pt").to(model.device)
112
+ print(f"inputs: {inputs}")
113
 
114
  """
115
  with torch.no_grad():
 
125
  """
126
 
127
  output = model.generate(**inputs,max_new_tokens=200,do_sample=False)
128
+ print(f"output: {output}")
129
 
130
  bbox = processor.decode(output[0][2:],skip_special_tokens=True)
131
+ print(f"bbox: {bbox}")
132
 
133
 
134
  draw = ImageDraw.Draw(image)
135
+ print(f"draw: {draw}")
136
 
137
  draw.rectangle(bbox,outline="red",width=5)
138
+ print(f"image: {image}")
139
 
140
  #image.save(f'{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}\\download\\{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg')
141
+ return [image,bbox]
142
  except Exception as e:
143
  print(f"Error loading model or processing image: {str(e)}")
144
  return None