Spaces:
Sleeping
Sleeping
File size: 6,339 Bytes
889dd37 ddf0be2 c02e069 ddf0be2 b6f5cff ddf0be2 b6f5cff ddf0be2 525217d ddf0be2 261bbd2 ddf0be2 261bbd2 ddf0be2 261bbd2 ddf0be2 261bbd2 ddf0be2 6320557 ddf0be2 6320557 ddf0be2 261bbd2 ddf0be2 525217d ddf0be2 525217d ddf0be2 7cd0cdb ddf0be2 7cd0cdb b6f5cff 3c0f5c4 7cd0cdb b6f5cff 3c0f5c4 7cd0cdb b6f5cff 7cd0cdb ddf0be2 7cd0cdb b6f5cff 3c0f5c4 125bff3 cba8c8b b6f5cff 7cd0cdb ddf0be2 b6f5cff 3c0f5c4 ddf0be2 b6f5cff 3c0f5c4 ddf0be2 b6f5cff ddf0be2 cba8c8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import os,requests,ast,torch,re
import gradio as gr
import datetime as dt
import google.generativeai as genai
from io import BytesIO
from dotenv import load_dotenv
from PIL import Image,ImageDraw
from transformers import AutoProcessor,AutoModelForVision2Seq,LlavaForConditionalGeneration
# function for pulling secrets from local repositories
def get_secret(secret_key):
    """Return the value of *secret_key* from the environment.

    Falls back to loading a local ``.env`` file (kept in a ``.gitignore``
    directory two levels above this file) when the variable is not already
    set — e.g. when running outside GitHub Actions.

    Raises:
        ValueError: if the secret cannot be found anywhere.
    """
    if not os.getenv(secret_key):  # usually set in CI (GitHub Actions); otherwise load from disk
        # Join with separate components so the path is portable: the
        # original hard-coded Windows backslashes, which break on POSIX.
        env_path = os.path.normpath(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '.gitignore', '.env')
        )
        load_dotenv(dotenv_path=env_path)
    value = os.getenv(secret_key)
    # Check for None BEFORE masking: len(None) raised TypeError in the
    # original, and its ValueError was constructed but never raised.
    if value is None:
        raise ValueError(f"Secret '{secret_key}' not found.")
    print('*' * len(value))  # confirm the secret's length without revealing it
    return value
# download an image when provided a url
def get_image(url):
    """Download *url* and return it as an RGB PIL image.

    Terminates the process with an error message when the download
    fails (preserving the original contract: callers never see a
    network exception, only a valid image or process exit).
    """
    # 1. Fetch the image and download the image
    try:
        # timeout added so a dead server cannot hang the app forever
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()  # surface HTTP 4xx/5xx as exceptions
        content = response.content
    except requests.exceptions.RequestException as e:
        print(f'Error downloading image: {e}')
        # exit() raised SystemExit(None) -> exit code 0, masking the
        # failure from callers/CI; exit with a non-zero status instead.
        raise SystemExit(1)
    # The original IOError handler guarded a commented-out file write
    # and could never fire, so it has been removed.
    return Image.open(BytesIO(content)).convert("RGB")
def load_model(model_name):
    """Instantiate and return the (processor, model) pair for *model_name*,
    placing the model on GPU when CUDA is available."""
    target = "cuda" if torch.cuda.is_available() else "cpu"
    if 'llava' in model_name:
        # LLaVA checkpoints are loaded in half precision with lazy CPU
        # memory usage to keep the footprint down.
        loaded = LlavaForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
    else:
        loaded = AutoModelForVision2Seq.from_pretrained(model_name)
    model = loaded.to(target)
    print(f"model: {model}")
    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
    print(f"processor: {processor}")
    return processor, model
def request_manager(model_name, url):
    """Route an ID-detection request to the matching model backend.

    Downloads the image at *url*, builds the shared system prompt, and
    dispatches to the Gemini or LLaVA handler based on *model_name*.

    Returns:
        list: ``[image, result]`` where result is the model's answer
        (bounding-box text/list) or a human-readable error message.
    """
    image = get_image(url)
    print(f"image: {image}")
    system_prompt = f"""
    You are an AI document processing assistant. Analyze the provided image. Identify the ID number in the document.
    This is usually identified in a location outside of the main content on the document, and usually on the bottom
    right or left of the document. The rotation of the number may differ based on images. Furthermore the ID number
    is usually a string of numbers, around 9 number characters in length. Could possibly have alphabetic characters
    as well but that looks to be rare. The output should only be a string in the format [x0,y0,x1,y1], and the
    values should fit into the image size which is {image.size}.
    """
    print(f"system_prompt: {system_prompt}")
    if 'gemini' in model_name:
        return_packet = gemini_identify_id(model_name, image, system_prompt)
    elif 'llava' in model_name:
        return_packet = huggingface_llava_15_7b_hf(model_name, image, system_prompt)
    else:
        # Previously an unrecognised model name left return_packet
        # unbound and crashed with NameError; fail with a clear message.
        return_packet = [image, f"unsupported model: {model_name}"]
    return return_packet
def gemini_identify_id(model_name, image, system_prompt):
    """Ask Gemini for the document-ID bounding box and draw it on *image*.

    Returns:
        list: ``[image, response_text]`` with the box drawn in yellow on
        success, or ``[image, <error message>]`` on any failure.
    """
    # 2. Function to process image with Gemini Pro Vision
    try:
        genai.configure(api_key=get_secret('GEMINI_API'))
        print(f"genai: {genai}")
        # NOTE(review): model_name is ignored and the model is
        # hard-coded — confirm whether model_name should be used here.
        model = genai.GenerativeModel("gemini-2.0-flash")
        print(f"model: {model}")
        response = model.generate_content([system_prompt, image])
        print(f"response: {response}")
        response_text = response.text
        print(f"response: {response_text}")
        if not response_text:
            print('Could not find an ID number')
            return [image, 'no response was received']
        # Parse inside the try: a malformed reply previously escaped the
        # handler and crashed unhandled in ast.literal_eval below.
        bbox = ast.literal_eval(response_text.strip())
    except Exception as e:
        return [image, f"Error processing image: {str(e)}"]
    draw = ImageDraw.Draw(image)
    print(f"draw: {draw}")
    draw.rectangle(bbox, outline='yellow', width=5)
    return [image, response_text]
# Huggingface repo usage
def huggingface_llava_15_7b_hf(model_name, image, system_prompt):
    """Run a local LLaVA checkpoint to locate the document ID in *image*.

    Returns:
        list: ``[image, bbox]`` with the box drawn in red on success,
        or ``[image, <error message>]`` on any failure.
    """
    try:
        processor, model = load_model(model_name)
        # Single-turn chat: the instruction text plus one image slot.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": system_prompt},
                    {"type": "image"},
                ],
            },
        ]
        print(f"conversation: {conversation}")
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        print(f"prompt: {prompt}")
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
        print(f"inputs: {inputs}")
        output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
        print(f"output: {output}")
        # Decode everything after the first two tokens, mirroring the
        # original's offset (presumably skips BOS/leading token — TODO confirm).
        response_string = processor.decode(output[0][2:], skip_special_tokens=True)
        print(f"response_string: {response_string}")
        match = re.search(r"ASSISTANT: \[(.*?)\]", response_string)
        if not match:
            return [image, "no match found"]
        # Parse the comma-separated coordinates, then scale each one by
        # the image dimensions (the original multiplied a list by a list
        # and passed a list into ast.literal_eval — both raise at runtime).
        # Assumes the model answers with normalized [x0,y0,x1,y1] — TODO confirm.
        coords = ast.literal_eval(f"[{match.group(1)}]")
        width, height = image.size
        bbox = [value * scale for value, scale in zip(coords, (width, height, width, height))]
        print(f"bbox: {bbox}")
        draw = ImageDraw.Draw(image)
        print(f"draw: {draw}")
        draw.rectangle(bbox, outline="red", width=5)
        print(f"image: {image}")
        return [image, bbox]
    except Exception as e:
        print(f"Error loading model or processing image: {str(e)}")
        return [image, "an error occurred processing request"]