File size: 6,339 Bytes
889dd37
ddf0be2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c02e069
ddf0be2
 
 
 
 
 
 
 
b6f5cff
ddf0be2
b6f5cff
ddf0be2
 
525217d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddf0be2
 
 
261bbd2
ddf0be2
261bbd2
ddf0be2
261bbd2
ddf0be2
261bbd2
ddf0be2
 
6320557
ddf0be2
 
6320557
ddf0be2
 
261bbd2
ddf0be2
525217d
 
ddf0be2
 
525217d
ddf0be2
7cd0cdb
ddf0be2
7cd0cdb
 
 
 
 
 
 
 
 
b6f5cff
3c0f5c4
7cd0cdb
b6f5cff
3c0f5c4
7cd0cdb
b6f5cff
7cd0cdb
 
ddf0be2
 
 
 
 
 
 
 
 
 
7cd0cdb
 
 
b6f5cff
3c0f5c4
125bff3
 
 
cba8c8b
 
 
 
b6f5cff
7cd0cdb
ddf0be2
 
b6f5cff
3c0f5c4
ddf0be2
b6f5cff
3c0f5c4
ddf0be2
b6f5cff
ddf0be2
 
cba8c8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os,requests,ast,torch,re
import gradio as gr
import datetime as dt
import google.generativeai as genai
from io import BytesIO
from dotenv import load_dotenv
from PIL import Image,ImageDraw
from transformers import AutoProcessor,AutoModelForVision2Seq,LlavaForConditionalGeneration

# function for pulling secrets from local repositories
def get_secret(secret_key):
    """Return the secret stored in the environment variable ``secret_key``.

    When the variable is unset or empty, a ``.env`` file located two
    directories above this file, inside a ``.gitignore`` folder, is loaded
    (if it exists) and the lookup is retried.

    Args:
        secret_key: Name of the environment variable to read.

    Returns:
        The value of the environment variable.

    Raises:
        ValueError: If the variable is still unset after the ``.env`` fallback.
    """
    if not os.getenv(secret_key):  # usually used in other repos when github actions is utilized
        # Join the components one by one so the path is built with the
        # platform's own separator instead of hard-coded backslashes.
        env_path = os.path.normpath(os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            '..', '..', '.gitignore', '.env',
        ))
        if os.path.isfile(env_path):
            load_dotenv(dotenv_path=env_path)

    value = os.getenv(secret_key)
    if value is None:
        raise ValueError(f"Secret '{secret_key}' not found.")

    # Print a mask with the same length as the secret, never the secret itself.
    print('*' * len(value))
    return value

# download an image when provided a url
def get_image(url, timeout=30):
    """Download the image at ``url`` and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely.

    Returns:
        The downloaded image opened with PIL and converted to "RGB".

    Raises:
        SystemExit: With exit status 1 when the download fails. The error is
            printed first.
    """
    # 1. Fetch the image and download the image
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        content = response.content
    except requests.exceptions.RequestException as e:
        print(f'Error downloading image: {e}')
        # Non-zero status so the failure is visible to the calling process;
        # SystemExit is raised directly instead of relying on site's exit().
        raise SystemExit(1) from e
    except IOError as e:
        print(f'Error saving image file: {e}')
        raise SystemExit(1) from e
    return Image.open(BytesIO(content)).convert("RGB")

def load_model(model_name):
    """Load the vision-language model and its processor for ``model_name``.

    Names containing 'llava' are loaded through
    ``LlavaForConditionalGeneration`` in float16 with ``low_cpu_mem_usage``;
    every other name goes through ``AutoModelForVision2Seq``. The model is
    moved to CUDA when available, otherwise to the CPU.

    Args:
        model_name: Identifier handed to ``from_pretrained``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    if 'llava' not in model_name:
        loaded = AutoModelForVision2Seq.from_pretrained(model_name)
    else:
        llava_options = {
            "torch_dtype": torch.float16,
            "low_cpu_mem_usage": True,
        }
        loaded = LlavaForConditionalGeneration.from_pretrained(model_name, **llava_options)
    model = loaded.to(target_device)
    print(f"model: {model}")

    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
    print(f"processor: {processor}")

    return processor, model

def request_manager(model_name,url):
    """Download the image at ``url`` and route it to the matching model backend.

    Args:
        model_name: Name used to pick the backend. Names containing 'gemini'
            go to ``gemini_identify_id``; names containing 'llava' go to
            ``huggingface_llava_15_7b_hf``.
        url: Address of the document image to analyse.

    Returns:
        The two-element list returned by the selected backend, or
        ``[image, message]`` when ``model_name`` matches no backend.
    """
    image = get_image(url)
    print(f"image: {image}")

    system_prompt = f"""
    You are an AI document processing assistant. Analyze the provided image. Identify the ID number in the document.
    This is usually identified in a location outside of the main content on the document, and usually on the bottom
    right or left of the document. The rotation of the number may differ based on images. Furthermore the ID number
    is usually a string of numbers, around 9 number characters in length. Could possibly have alphabetic characters
    as well but that looks to be rare. The output should only be a string in the format [x0,y0,x1,y1], and the
    values should fit into the image size which is {image.size}.
    """
    print(f"system_prompt: {system_prompt}")

    if 'gemini' in model_name:
        return gemini_identify_id(model_name,image,system_prompt)
    if 'llava' in model_name:
        return huggingface_llava_15_7b_hf(model_name,image,system_prompt)

    # Without this fallback an unknown name left the result variable unbound
    # and the function failed with UnboundLocalError.
    return [image,f"unsupported model: {model_name}"]

def _parse_bbox(response_text):
    """Extract the bounding-box list from the model's text reply.

    The whole reply is tried as a Python literal first. If that fails, the
    first ``[...]`` group found in the text is parsed instead, so a reply
    that surrounds the list with extra text can still be used.

    Raises:
        ValueError: If no bracketed list is present or it cannot be parsed.
        SyntaxError: If the bracketed text is not a valid literal.
    """
    try:
        return ast.literal_eval(response_text.strip())
    except (ValueError, SyntaxError):
        pass

    match = re.search(r"\[[^\[\]]*\]", response_text)
    if match is None:
        raise ValueError("no bounding box found in response")
    return ast.literal_eval(match.group(0))


def gemini_identify_id(model_name,image,system_prompt):
    """Ask Gemini for the ID-number bounding box and draw it on ``image``.

    Args:
        model_name: Model name chosen by the caller.
            NOTE(review): not used here — the request always goes to
            "gemini-2.0-flash"; confirm whether that is intended.
        image: Image sent to the model; the rectangle is drawn onto it in place.
        system_prompt: Instruction text sent together with the image.

    Returns:
        ``[image, response_text]`` on success, otherwise ``[image, message]``
        where ``message`` describes why no rectangle was drawn.
    """
    # 2. Function to process image with Gemini Pro Vision
    try:
        genai.configure(api_key=get_secret('GEMINI_API'))
        print(f"genai: {genai}")
        model = genai.GenerativeModel("gemini-2.0-flash")
        print(f"model: {model}")
        response = model.generate_content([system_prompt, image])
        print(f"response: {response}")
        response_text = response.text
        print(f"response: {response_text}")
        if not response_text:
            print('Could not find an ID number')
            return [image,'no response was received']

    except Exception as e:
        return [image,f"Error processing image: {str(e)}"]

    # Parsing and drawing can both reject malformed model output; report that
    # to the caller instead of letting the exception escape.
    try:
        bbox = _parse_bbox(response_text)
        draw = ImageDraw.Draw(image)
        print(f"draw: {draw}")
        draw.rectangle(bbox,outline='yellow',width=5)
    except (ValueError, SyntaxError, TypeError) as e:
        print(f"Error parsing bounding box response: {str(e)}")
        return [image,f"Error processing image: {str(e)}"]

    return [image,response_text]

# Huggingface repo usage
def _llava_bbox_to_pixels(raw_values, image_size):
    """Turn the text captured inside ``[...]`` into four pixel coordinates.

    Args:
        raw_values: Comma-separated numbers as captured from the model reply.
        image_size: ``(width, height)`` of the image.

    Returns:
        ``[x0, y0, x1, y1]`` as floats.

    Raises:
        ValueError: If the text does not hold exactly four numbers.
    """
    values = [float(v) for v in ast.literal_eval(f"[{raw_values}]")]
    if len(values) != 4:
        raise ValueError(f"expected 4 bounding-box values, got {len(values)}")

    # NOTE(review): values that all lie in [0, 1] are treated as fractions of
    # the image size and scaled; anything else is assumed to already be in
    # pixels, as the prompt requests — confirm against real model output.
    if all(0.0 <= v <= 1.0 for v in values):
        width, height = image_size
        values = [v * dim for v, dim in zip(values, (width, height, width, height))]
    return values


def huggingface_llava_15_7b_hf(model_name,image,system_prompt):
    """Ask a LLaVA model for the ID-number bounding box and draw it on ``image``.

    Args:
        model_name: Name passed to ``load_model``.
        image: Image given to the processor; the rectangle is drawn onto it
            in place.
        system_prompt: Instruction text placed in the user turn of the chat.

    Returns:
        ``[image, bbox]`` with ``bbox`` as ``[x0, y0, x1, y1]`` on success,
        ``[image, "no match found"]`` when the reply holds no bracketed
        answer, or ``[image, "an error occurred processing request"]`` when
        any step raises.
    """
    try:
        processor,model=load_model(model_name)
        conversation = [
            {
                "role":"user",
                "content":[
                    {"type":"text","text":system_prompt},
                    {"type":"image"},
                ],
            },
        ]
        print(f"conversation: {conversation}")

        prompt = processor.apply_chat_template(conversation,add_generation_prompt=True)
        print(f"prompt: {prompt}")

        inputs = processor(images=image,text=prompt,return_tensors="pt").to(model.device)
        print(f"inputs: {inputs}")

        output = model.generate(**inputs,max_new_tokens=200,do_sample=False)
        print(f"output: {output}")

        response_string = processor.decode(output[0][2:],skip_special_tokens=True)
        print(f"response_string: {response_string}")

        match = re.search(r"ASSISTANT: \[(.*?)\]",response_string)
        if not match:
            return [image,"no match found"]

        # Element-wise conversion; the previous list-times-list expression
        # always raised, so no rectangle was ever drawn.
        bbox = _llava_bbox_to_pixels(match.group(1),image.size)
        print(f"bbox: {bbox}")

        draw = ImageDraw.Draw(image)
        print(f"draw: {draw}")

        draw.rectangle(bbox,outline="red",width=5)
        print(f"image: {image}")

        return [image,bbox]
    except Exception as e:
        print(f"Error loading model or processing image: {str(e)}")
        return [image,"an error occurred processing request"]