bcvilnrotter commited on
Commit
188f56d
·
verified ·
1 Parent(s): 44880b4

Upload basic_functions.py

Browse files
Files changed (1) hide show
  1. utils/basic_functions.py +108 -0
utils/basic_functions.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,requests,ast,torch
2
+ import gradio as gr
3
+ import datetime as dt
4
+ import google.generativeai as genai
5
+ from io import BytesIO
6
+ from dotenv import load_dotenv
7
+ from PIL import Image,ImageDraw
8
+ from transformers import AutoProcessor,AutoModelForVision2Seq,LlavaForConditionalGeneration
9
+
10
# function for pulling secrets from local repositories
def get_secret(secret_key):
    """Return the value of *secret_key* from the environment.

    Falls back to loading a local ``.gitignore/.env`` file (two directory
    levels above this module) when the variable is not already set — the
    variable is usually pre-set when GitHub Actions is used.

    Args:
        secret_key: name of the environment variable to look up.

    Returns:
        The secret value as a string.

    Raises:
        ValueError: if the secret cannot be found anywhere.
    """
    if not os.getenv(secret_key):  # usually used in other repos when github actions is utilized
        # Build the fallback path portably; the original hard-coded Windows
        # '\' separators ('..\..' / '.gitignore\.env'), which break on
        # POSIX hosts and rely on deprecated string escape behaviour.
        env_path = os.path.normpath(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         '..', '..', '.gitignore', '.env'))
        load_dotenv(dotenv_path=env_path)

    value = os.getenv(secret_key)
    if value is None:
        # Original code constructed the ValueError but never raised it,
        # and printed len(value) first — a TypeError on a missing secret.
        raise ValueError(f"Secret '{secret_key}' not found.")
    # Echo a mask of the same length so the secret never lands in logs.
    print('*' * len(value))
    return value
22
+
23
# download an image when provided a url
def get_image(url, timeout=30):
    """Download *url* and return it as an RGB PIL image.

    Args:
        url: direct link to the image file.
        timeout: seconds to wait for the server before giving up
            (the original request had no timeout and could hang forever).

    Returns:
        A ``PIL.Image.Image`` converted to RGB.

    Raises:
        requests.exceptions.RequestException: if the download fails.
    """
    try:
        response = requests.get(url, stream=True, timeout=timeout)
        response.raise_for_status()
        content = response.content
    except requests.exceptions.RequestException as e:
        # Re-raise instead of exit(): callers already wrap this helper in
        # `except Exception`, but exit() raises SystemExit, which bypasses
        # those handlers and kills the whole process (e.g. a gradio app).
        print(f'Error downloading image: {e}')
        raise
    # NOTE: the former IOError branch guarded a file write that was
    # commented out, so it was dead code and has been removed.
    return Image.open(BytesIO(content)).convert("RGB")
40
+
41
def load_model(model_name):
    """Load a vision-language model and its processor from the HF hub.

    Args:
        model_name: Hugging Face hub id; names containing 'llava' use the
            dedicated Llava class, everything else the generic
            Vision2Seq auto class.

    Returns:
        (processor, model) tuple, with the model on GPU when available.
    """
    # Pick the device once so both branches behave the same; the original
    # llava branch hard-coded .to(0), which crashes on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if 'llava' in model_name:
        model = LlavaForConditionalGeneration.from_pretrained(
            model_name,
            # fp16 halves GPU memory; fall back to fp32 on CPU, where
            # half precision is poorly supported.
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device)
    else:
        model = AutoModelForVision2Seq.from_pretrained(model_name).to(device)
    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
    return processor, model
52
+
53
def gemini_identify_id(url, system_prompt):
    """Ask Gemini to locate the ID-number bounding box in a document image.

    Downloads the image, sends it to gemini-2.0-flash with *system_prompt*
    (which must request a '[x0,y0,x1,y1]' literal as the reply), draws the
    returned box, and saves an annotated copy under ``<repo>/download/``.

    Args:
        url: image URL, fetched via get_image().
        system_prompt: instruction prompt sent alongside the image.

    Returns:
        The annotated PIL image on success, or an
        ("Error processing image: ...", None) tuple on failure, preserving
        the original error contract.
    """
    try:
        image = get_image(url)

        genai.configure(api_key=get_secret('GEMINI_API'))

        model = genai.GenerativeModel("gemini-2.0-flash")
        response = model.generate_content([system_prompt, image])
        response_text = response.text
        if not response_text:
            # Report through the normal error path instead of exit():
            # SystemExit would kill the whole process (e.g. a gradio app).
            print('Could not find an ID number')
            return "Error processing image: empty model response", None
        print(response_text)

        # The model reply is free text; parsing can fail, so keep it inside
        # the try block instead of letting it raise uncaught (the original
        # ran literal_eval after the except clause).
        bbox = ast.literal_eval(response_text)

        draw = ImageDraw.Draw(image)
        draw.rectangle(bbox, outline='yellow', width=5)

        # Build the output path portably (the original hard-coded '\\'
        # separators, which only work on Windows) and make sure the
        # download directory exists before saving.
        out_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'download')
        os.makedirs(out_dir, exist_ok=True)
        image.save(os.path.join(
            out_dir, f'{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg'))
        # Return the annotated image for consistency with
        # huggingface_detect_id_box (the original returned None here).
        return image
    except Exception as e:
        return f"Error processing image: {str(e)}", None
74
+
75
# Huggingface repo usage
def huggingface_detect_id_box(model_name, url):
    """Locate the ID-number bounding box in a document image with a local HF model.

    Args:
        model_name: Hugging Face hub id, resolved through load_model().
        url: image URL, downloaded via get_image().

    Returns:
        The annotated PIL image on success, or None when the model reply
        cannot be parsed or any other step fails.
    """
    try:
        image = get_image(url)

        # The prompt asks for the box as a literal "[x0,y0,x1,y1]" string so
        # the reply can be parsed with ast.literal_eval below; the actual
        # image dimensions are interpolated to keep coordinates in range.
        system_prompt = f"""
        You are an AI document processing assistant. Analyze the provided image. Identify the ID number in the document.
        This is usually identified in a location outside of the main content on the document, and usually on the bottom
        right or left of the document. The rotation of the number may differ based on images. Furthermore the ID number
        is usually a string of numbers, around 9 number characters in length. Could possibly have alphabetic characters
        as well but that looks to be rare. The output should only be a string in the format [x0,y0,x1,y1], and the
        values should fit into the image size which is {image.size}.
        """

        processor,model=load_model(model_name)
        inputs = processor(image,text=system_prompt,return_tensors="pt").to(model.device)
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            output = model.generate(**inputs)

        response_text = processor.batch_decode(output,skip_special_tokens=True)[0]
        print(response_text)
        try:
            # Model output is free text; literal_eval only accepts a valid
            # Python literal such as "[x0, y0, x1, y1]".
            bbox = ast.literal_eval(response_text)
        except Exception as e:
            print(f"Error parsing bounding box response: {str(e)}")
            return None

        draw = ImageDraw.Draw(image)
        draw.rectangle(bbox,outline="red",width=5)
        #image.save(f'{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}\\download\\{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.jpg')
        return image
    except Exception as e:
        # Broad catch: load_model / processor / generate can each fail in
        # many ways; the app reports None rather than crashing.
        print(f"Error loading model or processing image: {str(e)}")
        return None