bcvilnrotter committed on
Commit
ddf0be2
·
verified ·
1 Parent(s): 6532b24

Update utils/basic_functions.py

Browse files
Files changed (1) hide show
  1. utils/basic_functions.py +108 -107
utils/basic_functions.py CHANGED
@@ -1,108 +1,109 @@
1
- import os,requests,ast,torch
2
- import gradio as gr
3
- import datetime as dt
4
- import google.generativeai as genai
5
- from io import BytesIO
6
- from dotenv import load_dotenv
7
- from PIL import Image,ImageDraw
8
- from transformers import AutoProcessor,AutoModelForVision2Seq,LlavaForConditionalGeneration
9
-
10
- # function for pulling secrets from local repositories
11
- def get_secret(secret_key):
12
- if not os.getenv(secret_key): # usually used in other repos when github actions is utilized
13
- env_path = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)),'..\..','.gitignore\.env'))
14
- load_dotenv(dotenv_path=env_path)
15
-
16
- value = os.getenv(secret_key)
17
- print(''.join(['*']*len(value)))
18
- if value is None:
19
- ValueError(f"Secret '{secret_key}' not found.")
20
-
21
- return value
22
-
23
- # download an image when when provided a url
24
- def get_image(url):
25
- # 1. Fetch the image and download the image
26
- try:
27
- response = requests.get(url,stream=True)
28
- response.raise_for_status()
29
- content = response.content
30
-
31
- #with open(f'{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}\\download\\{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg', 'wb') as f:
32
- # f.write(content)
33
- except requests.exceptions.RequestException as e:
34
- print(f'Error downloading image: {e}')
35
- exit()
36
- except IOError as e:
37
- print(f'Error saving image file: {e}')
38
- exit()
39
- return Image.open(BytesIO(content)).convert("RGB")
40
-
41
- def load_model(model_name):
42
- if 'llava' in model_name:
43
- model = LlavaForConditionalGeneration.from_pretrained(
44
- model_name,
45
- torch_dtype=torch.float16,
46
- low_cpu_mem_usage=True,
47
- ).to(0)
48
- else:
49
- model = AutoModelForVision2Seq.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
50
- processor = AutoProcessor.from_pretrained(model_name,use_fast=True)
51
- return processor,model
52
-
53
- def gemini_identify_id(url,system_prompt):
54
- # 2. Function to process image with Gemini Pro Vision
55
- try:
56
- image = get_image(url)
57
-
58
- genai.configure(api_key=get_secret('GEMINI_API'))
59
-
60
- model = genai.GenerativeModel("gemini-2.0-flash")
61
- response = model.generate_content([system_prompt, image])
62
- response_text = response.text
63
- if not response_text:
64
- print('Could not find an ID number')
65
- exit()
66
- print(response_text)
67
-
68
- except Exception as e:
69
- return f"Error processing image: {str(e)}",None
70
-
71
- draw = ImageDraw.Draw(image)
72
- draw.rectangle(ast.literal_eval(response_text),outline='yellow',width=5)
73
- image.save(f'{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}\\download\\{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg')
74
-
75
- # Huggingface repo usage
76
- def huggingface_detect_id_box(model_name,url):
77
- try:
78
- image = get_image(url)
79
-
80
- system_prompt = f"""
81
- You are an AI document processing assistant. Analyze the provided image. Identify the ID number in the document.
82
- This is usually identified in a location outside of the main content on the document, and usually on the bottom
83
- right or left of the document. The rotation of the number may differ based on images. Furthermore the ID number
84
- is usually a string of numbers, around 9 number characters in length. Could possibly have alphabetic characters
85
- as well but that looks to be rare. The output should only be a string in the format [x0,y0,x1,y1], and the
86
- values should fit into the image size which is {image.size}.
87
- """
88
-
89
- processor,model=load_model(model_name)
90
- inputs = processor(image,text=system_prompt,return_tensors="pt").to(model.device)
91
- with torch.no_grad():
92
- output = model.generate(**inputs)
93
-
94
- response_text = processor.batch_decode(output,skip_special_tokens=True)[0]
95
- print(response_text)
96
- try:
97
- bbox = ast.literal_eval(response_text)
98
- except Exception as e:
99
- print(f"Error parsing bounding box response: {str(e)}")
100
- return None
101
-
102
- draw = ImageDraw.Draw(image)
103
- draw.rectangle(bbox,outline="red",width=5)
104
- #image.save(f'{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}\\download\\{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg')
105
- return image
106
- except Exception as e:
107
- print(f"Error loading model or processing image: {str(e)}")
 
108
  return None
 
1
+ import os,requests,ast,torch
2
+ import gradio as gr
3
+ import datetime as dt
4
+ import google.generativeai as genai
5
+ from io import BytesIO
6
+ from dotenv import load_dotenv
7
+ from PIL import Image,ImageDraw
8
+ from transformers import AutoProcessor,AutoModelForVision2Seq,LlavaForConditionalGeneration
9
+
10
# function for pulling secrets from local repositories
def get_secret(secret_key):
    """Return the value of the environment secret *secret_key*.

    When the variable is not already present (it usually is under GitHub
    Actions), a local ``.env`` file kept under the repo's ``.gitignore``
    directory is loaded first and the environment is re-read.

    Args:
        secret_key: name of the environment variable holding the secret.

    Returns:
        The secret's string value.

    Raises:
        ValueError: if the secret cannot be found even after loading ``.env``.
    """
    if not os.getenv(secret_key):  # usually set directly when GitHub Actions is utilized
        # Build the path from separate components so it also resolves on
        # POSIX systems (the original embedded Windows backslashes).
        env_path = os.path.normpath(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         '..', '..', '.gitignore', '.env')
        )
        load_dotenv(dotenv_path=env_path)

    value = os.getenv(secret_key)
    if value is None:
        # BUG FIX: the original constructed the ValueError but never raised
        # it, then crashed earlier anyway with a TypeError in len(None).
        raise ValueError(f"Secret '{secret_key}' not found.")
    # Log a masked placeholder so the run confirms the secret was found
    # without leaking it; done only after the None check above.
    print('*' * len(value))

    return value
22
+
23
# download an image when provided a url
def get_image(url):
    """Download the file at *url* and return it as an RGB PIL Image.

    On a download failure the error is printed and the process is
    terminated via exit(), matching the module's fail-fast convention;
    callers never see a requests exception.
    """
    try:
        # 1. Fetch the image and download the image
        resp = requests.get(url,stream=True)
        resp.raise_for_status()
        payload = resp.content
    except requests.exceptions.RequestException as e:
        print(f'Error downloading image: {e}')
        exit()
    except IOError as e:
        # Kept for parity with the original flow; only relevant when the
        # downloaded bytes are also written to disk.
        print(f'Error saving image file: {e}')
        exit()
    return Image.open(BytesIO(payload)).convert("RGB")
40
+
41
def load_model(model_name):
    """Load a vision-language model and its processor from Hugging Face.

    LLaVA checkpoints use the dedicated LlavaForConditionalGeneration class
    in fp16; any other checkpoint falls back to AutoModelForVision2Seq.

    Args:
        model_name: Hugging Face hub id (or local path) of the checkpoint.

    Returns:
        A ``(processor, model)`` tuple with the model moved to CUDA when
        available, otherwise the CPU.
    """
    # BUG FIX: the original read `"cude" if torch.cudea.is_available()` —
    # `torch.cudea` raises AttributeError on every call, and the misspelled
    # device string would have broken .to() even if it hadn't.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if 'llava' in model_name:
        model = LlavaForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # halve memory; adequate for inference
            low_cpu_mem_usage=True,
        ).to(device)
    else:
        model = AutoModelForVision2Seq.from_pretrained(model_name).to(device)
    processor = AutoProcessor.from_pretrained(model_name,use_fast=True)
    return processor,model
53
+
54
def gemini_identify_id(url,system_prompt):
    """Ask Gemini to locate the ID number in the image at *url*, then save
    a copy of the image with the returned bounding box drawn in yellow.

    Args:
        url: location of the document image to analyze.
        system_prompt: instructions for Gemini; expected to make the model
            reply with a literal "[x0,y0,x1,y1]" box string.

    NOTE(review): on failure this returns a (message, None) tuple, but on
    success it implicitly returns None after saving — callers get no
    consistent return shape. Confirm the intended contract.
    """
    # 2. Function to process image with Gemini Pro Vision
    try:
        image = get_image(url)

        # API key comes from the environment / local .env via get_secret.
        genai.configure(api_key=get_secret('GEMINI_API'))

        model = genai.GenerativeModel("gemini-2.0-flash")
        response = model.generate_content([system_prompt, image])
        response_text = response.text
        if not response_text:
            print('Could not find an ID number')
            # exit() raises SystemExit, which is NOT caught by the
            # `except Exception` below — this is a hard stop.
            exit()
        print(response_text)

    except Exception as e:
        return f"Error processing image: {str(e)}",None

    # Only reached on success (the except branch returns above), so
    # image/response_text are guaranteed to be bound here.
    draw = ImageDraw.Draw(image)
    # assumes the model answered with exactly a "[x0,y0,x1,y1]" list string
    # that literal_eval can parse — TODO confirm; a chatty reply raises here
    draw.rectangle(ast.literal_eval(response_text),outline='yellow',width=5)
    # Windows-style path: saves under <repo>\download\<timestamp>.jpg
    image.save(f'{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}\\download\\{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg')
75
+
76
# Huggingface repo usage
def huggingface_detect_id_box(model_name,url):
    """Run a local Hugging Face vision model to find the document's ID
    number and return the image with the detected box drawn in red.

    Args:
        model_name: checkpoint id passed through to load_model.
        url: location of the document image to analyze.

    Returns:
        The annotated PIL image on success, or None when the model's reply
        cannot be parsed as a bounding box or any other step fails.
    """
    try:
        image = get_image(url)

        system_prompt = f"""
        You are an AI document processing assistant. Analyze the provided image. Identify the ID number in the document.
        This is usually identified in a location outside of the main content on the document, and usually on the bottom
        right or left of the document. The rotation of the number may differ based on images. Furthermore the ID number
        is usually a string of numbers, around 9 number characters in length. Could possibly have alphabetic characters
        as well but that looks to be rare. The output should only be a string in the format [x0,y0,x1,y1], and the
        values should fit into the image size which is {image.size}.
        """

        processor,model = load_model(model_name)
        inputs = processor(image,text=system_prompt,return_tensors="pt").to(model.device)
        with torch.no_grad():
            generated = model.generate(**inputs)

        raw = processor.batch_decode(generated,skip_special_tokens=True)[0]
        print(raw)
        try:
            box = ast.literal_eval(raw)
        except Exception as e:
            print(f"Error parsing bounding box response: {str(e)}")
            return None

        ImageDraw.Draw(image).rectangle(box,outline="red",width=5)
        return image
    except Exception as e:
        print(f"Error loading model or processing image: {str(e)}")
        return None