# NOTE(review): the four lines below are Hugging Face file-page chrome that was
# pasted into the source; commented out so the file parses as Python.
# testing / inference.py
# credent007's picture
# Update inference.py
# 21d1711 verified
# import torch
# from model_loader import model, processor, device
# from processor_utils import load_input
# from prompt import get_prompt
# import json
# def process_document(image):
# # images = load_input(file_path)
# # image = images[0]
# # print("Checking input type and no of pages in pdf")
# # print(type(image))
# # print(type(images))
# # print(len(images))
# messages = [
# {
# "role": "user",
# "content": [
# {"type": "image", "image": image},
# {"type": "text", "text": get_prompt()}
# ]
# }
# ]
# text = processor.apply_chat_template(
# messages,
# tokenize=False, # so that this can return string output
# add_generation_prompt=True # if true it will add extra on start and end
# )
# # print(f"The text of inference is {text}")
# inputs = processor(
# text=[text],
# images=[image],
# return_tensors="pt"
# ).to(device)
# # print(f"The inputs of inference is {inputs}")
# output = model.generate(
# **inputs,
# max_new_tokens=1500,
# do_sample=False, # if it is true there will be extra text with output
# # temperature=0.1 # temp is not required
# )
# # print(f"The output of inference is {output}")
# generated_ids = output[0][inputs.input_ids.shape[-1]:]
# # print(f"The generated_ids of inference is {generated_ids}")
# # response = processor.decode( # past code
# # generated_ids,
# # skip_special_tokens=True
# # )
# # return response.strip()
# response = processor.decode(
# generated_ids,
# skip_special_tokens=True
# ).strip()
# # print(f"The response of inference is {response}")
# # 🔥 FORCE JSON CLEANING
# start = response.find("{")
# end = response.rfind("}") + 1
# if start != -1 and end != -1:
# response = response[start:end]
# print(f"The type of response is before{response}")
# try:
# parsed = json.loads(response)
# except:
# parsed = {
# "error":[
# response
# ]
# # "Invalid JSON",
# # "raw": response
# }
# print(f"The type of response is after{response}")
# return response
# import json
# from model_loader import get_model
# from processor_utils import load_input
# from prompt import get_part_classifier_prompt, get_part_prompt
# def _run_model(image, prompt_text, model, processor, device):
# messages = [
# {
# "role": "user",
# "content": [
# {"type": "image", "image": image},
# {"type": "text", "text": prompt_text}
# ]
# }
# ]
# text = processor.apply_chat_template(
# messages,
# tokenize=False,
# add_generation_prompt=True
# )
# inputs = processor(
# text=[text],
# images=[image],
# return_tensors="pt"
# ).to(device)
# output = model.generate(
# **inputs,
# max_new_tokens=400,
# do_sample=False
# )
# generated_ids = output[0][inputs.input_ids.shape[-1]:]
# response = processor.decode(generated_ids, skip_special_tokens=True).strip()
# return response
# def _extract_json_block(text):
# start = text.find("{")
# end = text.rfind("}") + 1
# if start == -1 or end == 0:
# return None
# return text[start:end]
# def classify_page(image, model, processor, device):
# raw = _run_model(image, get_part_classifier_prompt(), model, processor, device)
# raw = raw.strip().upper()
# valid_parts = {"PART-1", "PART-2", "PART-3", "PART-4", "PART-5", "PART-6"}
# for part in valid_parts:
# if part in raw:
# return part
# return "UNKNOWN"
# def extract_part_json(image, part_name, model, processor, device):
# raw = _run_model(image, get_part_prompt(part_name), model, processor, device)
# json_block = _extract_json_block(raw)
# if not json_block:
# return {
# "status": "error",
# "part": part_name,
# "raw_output": raw,
# "parsed": None
# }
# try:
# parsed = json.loads(json_block)
# return {
# "status": "success",
# "part": part_name,
# "raw_output": raw,
# "parsed": parsed
# }
# except json.JSONDecodeError:
# return {
# "status": "error",
# "part": part_name,
# "raw_output": raw,
# "parsed": None
# }
# def process_document(file_path):
# model, processor, device = get_model()
# pages = load_input(file_path)
# page_results = []
# for idx, image in enumerate(pages, start=1):
# part_name = classify_page(image, model, processor, device)
# if part_name == "UNKNOWN":
# page_results.append({
# "page_number": idx,
# "status": "error",
# "part": "UNKNOWN",
# "raw_output": "",
# "parsed": None
# })
# continue
# result = extract_part_json(image, part_name, model, processor, device)
# result["page_number"] = idx
# page_results.append(result)
# return {
# "total_pages": len(page_results),
# "pages": page_results
# }
import json
from model_loader import get_model
from processor_utils import load_input
# from prompt import get_part_classifier_prompt, get_part_prompt
from prompt1 import get_part_classifier_prompt, get_part_prompt
import time
def _get_max_tokens(part_name):
limits = {
"CLASSIFIER": 20,
"PART-1": 1200,
"PART-2": 700,
"PART-3": 1800,
"PART-4": 500,
"PART-5": 300,
"PART-6": 100
}
return limits.get(part_name, 600)
def _clean_raw_text(text):
text = text.strip()
if text.startswith("```json"):
text = text[len("```json"):].strip()
elif text.startswith("```"):
text = text[len("```"):].strip()
if text.endswith("```"):
text = text[:-3].strip()
return text
def _run_model(image, prompt_text, model, processor, device, max_new_tokens):
    """Run one image+prompt turn through the VLM and return the cleaned reply.

    Builds a single-turn chat message containing the image and *prompt_text*,
    applies the processor's chat template, generates greedily
    (do_sample=False) up to *max_new_tokens*, decodes only the newly
    generated tokens, and strips any markdown code fence from the result.
    """
    conversation = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt_text},
        ],
    }]
    chat_text = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = processor(
        text=[chat_text],
        images=[image],
        return_tensors="pt",
    ).to(device)
    generated = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    # Slice off the prompt tokens so only the model's reply is decoded.
    new_token_ids = generated[0][model_inputs.input_ids.shape[-1]:]
    reply = processor.decode(new_token_ids, skip_special_tokens=True).strip()
    return _clean_raw_text(reply)
def _extract_json_block(text):
start = text.find("{")
end = text.rfind("}") + 1
if start == -1 or end == 0 or end <= start:
return None
return text[start:end]
def classify_page(image, model, processor, device):
    """Classify a page image as one of "PART-1".."PART-6", or "UNKNOWN".

    Runs the classifier prompt with a tiny token budget, upper-cases the
    reply, and returns the first known part label found as a substring.
    """
    reply = _run_model(
        image,
        get_part_classifier_prompt(),
        model,
        processor,
        device,
        max_new_tokens=_get_max_tokens("CLASSIFIER"),
    ).upper()
    for candidate in ("PART-1", "PART-2", "PART-3", "PART-4", "PART-5", "PART-6"):
        if candidate in reply:
            return candidate
    return "UNKNOWN"
def _try_parse_json(raw):
    """Return the dict parsed from *raw*'s outermost {...} block, or None
    when no block is present or it is not valid JSON."""
    block = _extract_json_block(raw)
    if not block:
        return None
    try:
        return json.loads(block)
    except json.JSONDecodeError:
        return None


def extract_part_json(image, part_name, model, processor, device):
    """Extract the structured JSON for *part_name* from a page image.

    Runs the part-specific prompt; if the reply does not contain valid JSON
    (commonly because generation was truncated), retries once with the token
    budget enlarged by 600.

    Returns a dict with keys:
        status      "success" or "error"
        part        the requested part name
        raw_output  the raw model reply from the last attempt made
        parsed      the decoded JSON on success, else None
    """
    base_budget = _get_max_tokens(part_name)
    raw = ""
    # Second attempt gets extra headroom in case the first was truncated.
    for budget in (base_budget, base_budget + 600):
        raw = _run_model(
            image,
            get_part_prompt(part_name),
            model,
            processor,
            device,
            max_new_tokens=budget,
        )
        # "is not None" (not truthiness): an empty-but-valid {} still counts
        # as a successful parse, matching the original behavior.
        parsed = _try_parse_json(raw)
        if parsed is not None:
            return {
                "status": "success",
                "part": part_name,
                "raw_output": raw,
                "parsed": parsed,
            }
    # The loop always ran at least once, so `raw` is the last attempt's
    # output — no need for the old `'retry_raw' in locals()` guard, which
    # was dead code (retry_raw was always bound at that point).
    return {
        "status": "error",
        "part": part_name,
        "raw_output": raw,
        "parsed": None,
    }
# def merge_page_results(page_results):
# final_json = {}
# for item in page_results:
# if item["status"] != "success" or not item["parsed"]:
# continue
# parsed = item["parsed"]
# for key, value in parsed.items():
# final_json[key] = value
# return final_json
# Structured merge: group each page's parsed output under its PART-N key
# (replaces the flat key-by-key merge kept commented out above).
def merge_page_results(page_results):
    """Merge per-page extraction results into one document-level dict.

    Successful pages are grouped under their "PART-N" key; multiple pages of
    the same part are combined via _merge_values (lists concatenate, dicts
    deep-merge, first non-empty scalar wins). Parts with no successful pages
    are dropped from the result.

    Fix: the old code indexed `final_json[part]` directly, which raised
    KeyError for any part label outside the six hard-coded keys; `.get()`
    now merges unexpected labels under their own key instead of crashing.
    """
    merged = {
        "PART-1": {},
        "PART-2": {},
        "PART-3": {},
        "PART-4": {},
        "PART-5": {},
        "PART-6": {},
    }
    for page in page_results:
        if page["status"] != "success" or not page["parsed"]:
            continue
        part = page["part"]
        merged[part] = _merge_values(merged.get(part, {}), page["parsed"])
    return {part: content for part, content in merged.items() if content}
def _merge_values(old_value, new_value):
if old_value is None:
return new_value
if isinstance(old_value, list) and isinstance(new_value, list):
return old_value + new_value
if isinstance(old_value, dict) and isinstance(new_value, dict):
merged = dict(old_value)
for key, value in new_value.items():
if key in merged:
merged[key] = _merge_values(merged[key], value)
else:
merged[key] = value
return merged
if old_value in ("", None, [], {}):
return new_value
return old_value
def process_document(file_path):
    """End-to-end pipeline: load the document's pages, classify each page,
    extract its part-specific JSON, and merge everything into a single
    structured dict keyed by part name."""
    model, processor, device = get_model()
    images = load_input(file_path)
    results = []
    for page_number, page_image in enumerate(images, start=1):
        print("first model has been called for", page_number, "image")
        t0 = time.time()
        part_name = classify_page(page_image, model, processor, device)
        print("total time taken by the first model", time.time() - t0, "sec")
        if part_name == "UNKNOWN":
            # Unclassifiable page: record an error entry and move on.
            results.append({
                "page_number": page_number,
                "status": "error",
                "part": "UNKNOWN",
                "raw_output": "",
                "parsed": None,
            })
            continue
        print("second model has been called for", page_number, "image")
        t1 = time.time()
        page_result = extract_part_json(page_image, part_name, model, processor, device)
        print("total time taken by the second model", time.time() - t1, "sec")
        page_result["page_number"] = page_number
        results.append(page_result)
    return merge_page_results(results)