# import torch # from model_loader import model, processor, device # from processor_utils import load_input # from prompt import get_prompt # import json # def process_document(image): # # images = load_input(file_path) # # image = images[0] # # print("Checking input type and no of pages in pdf") # # print(type(image)) # # print(type(images)) # # print(len(images)) # messages = [ # { # "role": "user", # "content": [ # {"type": "image", "image": image}, # {"type": "text", "text": get_prompt()} # ] # } # ] # text = processor.apply_chat_template( # messages, # tokenize=False, # so that this can return string output # add_generation_prompt=True # if true it will add extra on start and end # ) # # print(f"The text of inference is {text}") # inputs = processor( # text=[text], # images=[image], # return_tensors="pt" # ).to(device) # # print(f"The inputs of inference is {inputs}") # output = model.generate( # **inputs, # max_new_tokens=1500, # do_sample=False, # if it is true there will be extra text with output # # temperature=0.1 # temp is not required # ) # # print(f"The output of inference is {output}") # generated_ids = output[0][inputs.input_ids.shape[-1]:] # # print(f"The generated_ids of inference is {generated_ids}") # # response = processor.decode( # past code # # generated_ids, # # skip_special_tokens=True # # ) # # return response.strip() # response = processor.decode( # generated_ids, # skip_special_tokens=True # ).strip() # # print(f"The response of inference is {response}") # # 🔥 FORCE JSON CLEANING # start = response.find("{") # end = response.rfind("}") + 1 # if start != -1 and end != -1: # response = response[start:end] # print(f"The type of response is before{response}") # try: # parsed = json.loads(response) # except: # parsed = { # "error":[ # response # ] # # "Invalid JSON", # # "raw": response # } # print(f"The type of response is after{response}") # return response # import json # from model_loader import get_model # from processor_utils import 
load_input # from prompt import get_part_classifier_prompt, get_part_prompt # def _run_model(image, prompt_text, model, processor, device): # messages = [ # { # "role": "user", # "content": [ # {"type": "image", "image": image}, # {"type": "text", "text": prompt_text} # ] # } # ] # text = processor.apply_chat_template( # messages, # tokenize=False, # add_generation_prompt=True # ) # inputs = processor( # text=[text], # images=[image], # return_tensors="pt" # ).to(device) # output = model.generate( # **inputs, # max_new_tokens=400, # do_sample=False # ) # generated_ids = output[0][inputs.input_ids.shape[-1]:] # response = processor.decode(generated_ids, skip_special_tokens=True).strip() # return response # def _extract_json_block(text): # start = text.find("{") # end = text.rfind("}") + 1 # if start == -1 or end == 0: # return None # return text[start:end] # def classify_page(image, model, processor, device): # raw = _run_model(image, get_part_classifier_prompt(), model, processor, device) # raw = raw.strip().upper() # valid_parts = {"PART-1", "PART-2", "PART-3", "PART-4", "PART-5", "PART-6"} # for part in valid_parts: # if part in raw: # return part # return "UNKNOWN" # def extract_part_json(image, part_name, model, processor, device): # raw = _run_model(image, get_part_prompt(part_name), model, processor, device) # json_block = _extract_json_block(raw) # if not json_block: # return { # "status": "error", # "part": part_name, # "raw_output": raw, # "parsed": None # } # try: # parsed = json.loads(json_block) # return { # "status": "success", # "part": part_name, # "raw_output": raw, # "parsed": parsed # } # except json.JSONDecodeError: # return { # "status": "error", # "part": part_name, # "raw_output": raw, # "parsed": None # } # def process_document(file_path): # model, processor, device = get_model() # pages = load_input(file_path) # page_results = [] # for idx, image in enumerate(pages, start=1): # part_name = classify_page(image, model, processor, device) # if 
"""Per-page document extraction pipeline.

Each page image is classified into one of six known document parts, then a
part-specific prompt is run to extract a JSON payload from the page.  The
per-page payloads are merged into a single flat dict for the document.
"""

import json
import logging
import time

from model_loader import get_model
from processor_utils import load_input
from prompt import get_part_classifier_prompt, get_part_prompt

logger = logging.getLogger(__name__)

# Generation budgets per prompt type.  The classifier only needs to emit a
# short label; extraction prompts get part-specific budgets sized to the
# expected payload for that part.
_TOKEN_LIMITS = {
    "CLASSIFIER": 20,
    "PART-1": 1200,
    "PART-2": 700,
    "PART-3": 1800,
    "PART-4": 500,
    "PART-5": 300,
    "PART-6": 100,
}
_DEFAULT_TOKEN_LIMIT = 600
# Extra budget granted on the single retry after an unparseable first
# response (truncated JSON is the usual failure mode).
_RETRY_EXTRA_TOKENS = 600

# Labels the classifier is allowed to return, checked in this order.
_VALID_PARTS = ("PART-1", "PART-2", "PART-3", "PART-4", "PART-5", "PART-6")


def _get_max_tokens(part_name):
    """Return the max_new_tokens budget for *part_name* (600 if unknown)."""
    return _TOKEN_LIMITS.get(part_name, _DEFAULT_TOKEN_LIMIT)


def _clean_raw_text(text):
    """Strip a surrounding markdown code fence (```json ... ```) if present."""
    text = text.strip()
    # Check the more specific fence first so "```json" is not left half-eaten.
    for fence in ("```json", "```"):
        if text.startswith(fence):
            text = text[len(fence):].strip()
            break
    if text.endswith("```"):
        text = text[:-3].strip()
    return text


def _run_model(image, prompt_text, model, processor, device, max_new_tokens):
    """Run one image+text prompt through the model and return cleaned text.

    Greedy decoding (do_sample=False) keeps the output deterministic so the
    JSON contract is reproducible across runs.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    chat_text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(
        text=[chat_text],
        images=[image],
        return_tensors="pt",
    ).to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    # Drop the echoed prompt tokens; keep only the newly generated tail.
    generated_ids = output[0][inputs.input_ids.shape[-1]:]
    response = processor.decode(generated_ids, skip_special_tokens=True).strip()
    return _clean_raw_text(response)


def _extract_json_block(text):
    """Return the outermost {...} span of *text*, or None if absent."""
    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end == 0 or end <= start:
        return None
    return text[start:end]


def classify_page(image, model, processor, device):
    """Classify a page image as one of the known parts, or "UNKNOWN"."""
    raw = _run_model(
        image,
        get_part_classifier_prompt(),
        model,
        processor,
        device,
        max_new_tokens=_get_max_tokens("CLASSIFIER"),
    ).upper()
    for part in _VALID_PARTS:
        if part in raw:
            return part
    return "UNKNOWN"


def _parse_attempt(raw, part_name):
    """Try to parse model output *raw*; return a success dict or None."""
    json_block = _extract_json_block(raw)
    if not json_block:
        return None
    try:
        parsed = json.loads(json_block)
    except json.JSONDecodeError:
        return None
    return {
        "status": "success",
        "part": part_name,
        "raw_output": raw,
        "parsed": parsed,
    }


def extract_part_json(image, part_name, model, processor, device):
    """Extract the JSON payload for *part_name* from a page image.

    Runs the part-specific prompt once; on a parse failure retries a single
    time with a larger token budget.  Returns a dict with "status", "part",
    "raw_output" and "parsed" keys; on double failure "status" is "error",
    "parsed" is None and "raw_output" holds the last attempt's text.
    """
    base_tokens = _get_max_tokens(part_name)
    raw = _run_model(
        image,
        get_part_prompt(part_name),
        model,
        processor,
        device,
        max_new_tokens=base_tokens,
    )
    result = _parse_attempt(raw, part_name)
    if result is not None:
        return result

    # Retry once with extra budget in case the first response was truncated.
    raw = _run_model(
        image,
        get_part_prompt(part_name),
        model,
        processor,
        device,
        max_new_tokens=base_tokens + _RETRY_EXTRA_TOKENS,
    )
    result = _parse_attempt(raw, part_name)
    if result is not None:
        return result

    return {
        "status": "error",
        "part": part_name,
        "raw_output": raw,
        "parsed": None,
    }


def merge_page_results(page_results):
    """Merge parsed per-page dicts into one flat dict.

    Pages are applied in order, so a key extracted on a later page silently
    overwrites the same key from an earlier page.
    """
    final_json = {}
    for item in page_results:
        if item["status"] == "success" and item["parsed"]:
            final_json.update(item["parsed"])
    return final_json


def process_document(file_path):
    """Classify and extract every page of *file_path*; return the merged JSON.

    Pages that cannot be classified are recorded as errors (part "UNKNOWN")
    and skipped; they contribute nothing to the merged result.
    """
    model, processor, device = get_model()
    pages = load_input(file_path)

    page_results = []
    for idx, image in enumerate(pages, start=1):
        logger.info("classifying page %d", idx)
        started = time.perf_counter()
        part_name = classify_page(image, model, processor, device)
        logger.info("classifier took %.2f sec", time.perf_counter() - started)

        if part_name == "UNKNOWN":
            page_results.append({
                "page_number": idx,
                "status": "error",
                "part": "UNKNOWN",
                "raw_output": "",
                "parsed": None,
            })
            continue

        logger.info("extracting %s from page %d", part_name, idx)
        started = time.perf_counter()
        result = extract_part_json(image, part_name, model, processor, device)
        logger.info("extraction took %.2f sec", time.perf_counter() - started)
        logger.debug("page %d result: %s", idx, result)

        result["page_number"] = idx
        page_results.append(result)

    return merge_page_results(page_results)