"""Classify document pages into parts and extract a JSON object per page.

Each page image is first classified as one of ``PART-1`` .. ``PART-6`` and
then run through a part-specific extraction prompt. The per-page results are
merged into a single dict keyed by part name.
"""

import json
import time

from model_loader import get_model
from processor_utils import load_input
from prompt import get_part_classifier_prompt, get_part_prompt

# Generation budget (max new tokens) per prompt kind.
_MAX_TOKENS_BY_PART = {
    "CLASSIFIER": 20,
    "PART-1": 1200,
    "PART-2": 700,
    "PART-3": 1800,
    "PART-4": 500,
    "PART-5": 300,
    "PART-6": 100,
}
_DEFAULT_MAX_TOKENS = 600

# Extra tokens granted to the single extraction retry.
_RETRY_EXTRA_TOKENS = 600

# Part labels the classifier may return, in the order they are searched for.
_VALID_PARTS = ("PART-1", "PART-2", "PART-3", "PART-4", "PART-5", "PART-6")

# Values that count as "nothing extracted yet" when merging page results.
_EMPTY_VALUES = ("", None, [], {})


def _get_max_tokens(part_name):
    """Return the generation budget for *part_name* (600 if it is unknown)."""
    return _MAX_TOKENS_BY_PART.get(part_name, _DEFAULT_MAX_TOKENS)


def _clean_raw_text(text):
    """Strip surrounding whitespace and a wrapping Markdown code fence.

    A leading ```` ```json ```` or ```` ``` ```` marker and a trailing
    ```` ``` ```` marker are removed; everything in between is returned
    stripped.
    """
    text = text.strip()

    if text.startswith("```json"):
        text = text[len("```json"):].strip()
    elif text.startswith("```"):
        text = text[len("```"):].strip()

    if text.endswith("```"):
        text = text[:-3].strip()

    return text


def _run_model(image, prompt_text, model, processor, device, max_new_tokens):
    """Run one image + text prompt through the model and return its reply.

    Args:
        image: Page image, passed both inside the chat message and to the
            processor call.
        prompt_text: Instruction text sent alongside the image.
        model: Object exposing ``generate(**inputs, ...)``.
        processor: Object exposing ``apply_chat_template``, ``__call__`` and
            ``decode``.
            NOTE(review): presumably a Hugging Face multimodal processor --
            confirm against ``model_loader.get_model``.
        device: Device the processed inputs are moved to.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply with special tokens skipped and any wrapping code
        fence removed (see ``_clean_raw_text``).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
    ).to(device)

    # Greedy decoding: the same inputs and budget give the same output.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )

    # The generated sequence starts with the prompt tokens; drop them so only
    # the newly generated tokens are decoded.
    prompt_length = inputs.input_ids.shape[-1]
    generated_ids = output[0][prompt_length:]

    response = processor.decode(
        generated_ids,
        skip_special_tokens=True,
    ).strip()

    return _clean_raw_text(response)


def _extract_json_block(text):
    """Return the span from the first ``{`` to the last ``}`` in *text*.

    Returns ``None`` when either brace is missing or the last ``}`` comes
    before the first ``{``.
    """
    start = text.find("{")
    last = text.rfind("}")

    if start == -1 or last < start:
        return None

    return text[start:last + 1]


def _parse_json_object(raw):
    """Parse the JSON object embedded in *raw*; return ``None`` on failure."""
    json_block = _extract_json_block(raw)
    if not json_block:
        return None

    try:
        return json.loads(json_block)
    except json.JSONDecodeError:
        return None


def classify_page(image, model, processor, device):
    """Classify a page image as one of the known parts.

    Returns:
        The first label of ``PART-1`` .. ``PART-6`` (checked in that order)
        found in the upper-cased model reply, or ``"UNKNOWN"`` if none occurs.
    """
    raw = _run_model(
        image,
        get_part_classifier_prompt(),
        model,
        processor,
        device,
        max_new_tokens=_get_max_tokens("CLASSIFIER"),
    ).upper()

    for part in _VALID_PARTS:
        if part in raw:
            return part

    return "UNKNOWN"


def extract_part_json(image, part_name, model, processor, device):
    """Extract the JSON object for *part_name* from a page image.

    The model is queried with the part's normal token budget; if the reply
    does not contain a parseable JSON object, it is queried once more with
    600 additional tokens.

    Returns:
        A dict with keys ``status`` (``"success"`` or ``"error"``), ``part``,
        ``raw_output`` (the reply of the last attempt made) and ``parsed``
        (the decoded JSON, or ``None`` on error).
    """
    max_tokens = _get_max_tokens(part_name)
    raw = ""

    # Decoding is greedy, so the retry can only help when the first reply was
    # cut off by the token budget -- hence the larger budget on attempt two.
    for budget in (max_tokens, max_tokens + _RETRY_EXTRA_TOKENS):
        raw = _run_model(
            image,
            get_part_prompt(part_name),
            model,
            processor,
            device,
            max_new_tokens=budget,
        )

        parsed = _parse_json_object(raw)
        if parsed is not None:
            return {
                "status": "success",
                "part": part_name,
                "raw_output": raw,
                "parsed": parsed,
            }

    return {
        "status": "error",
        "part": part_name,
        "raw_output": raw,
        "parsed": None,
    }


def merge_page_results(page_results):
    """Merge successful per-page results into one dict keyed by part name.

    Items whose ``status`` is not ``"success"`` or whose ``parsed`` value is
    empty are skipped. Pages that belong to the same part are combined with
    ``_merge_values``. Parts that end up empty are left out of the result.
    """
    final_json = {part: {} for part in _VALID_PARTS}

    for item in page_results:
        if item["status"] != "success" or not item["parsed"]:
            continue

        part = item["part"]
        # .get() rather than indexing: a part outside the predefined six
        # starts from an empty dict instead of raising KeyError.
        final_json[part] = _merge_values(final_json.get(part, {}), item["parsed"])

    return {key: value for key, value in final_json.items() if value}


def _merge_values(old_value, new_value):
    """Combine two extracted values, preferring data that is already present.

    * two lists are concatenated (old first);
    * two dicts are merged key by key, recursing on keys present in both;
    * otherwise *new_value* is used only when *old_value* is empty
      (``""``, ``None``, ``[]`` or ``{}``); a non-empty *old_value* wins.
    """
    if isinstance(old_value, list) and isinstance(new_value, list):
        return old_value + new_value

    if isinstance(old_value, dict) and isinstance(new_value, dict):
        merged = dict(old_value)
        for key, value in new_value.items():
            merged[key] = _merge_values(merged[key], value) if key in merged else value
        return merged

    if old_value in _EMPTY_VALUES:
        return new_value

    return old_value


def process_document(file_path):
    """Classify and extract every page of *file_path*, then merge the results.

    Pages come from ``load_input`` and are numbered from 1. Pages classified
    as ``"UNKNOWN"`` are recorded as errors and skipped for extraction.
    Progress and timing information is printed to stdout.

    Returns:
        The merged dict produced by ``merge_page_results``.
    """
    model, processor, device = get_model()
    pages = load_input(file_path)

    page_results = []

    for idx, image in enumerate(pages, start=1):
        print("first model has been called for", idx, "image")
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start = time.perf_counter()
        part_name = classify_page(image, model, processor, device)
        end = time.perf_counter()
        print("total time taken by the first model", end - start, "sec")

        if part_name == "UNKNOWN":
            page_results.append({
                "page_number": idx,
                "status": "error",
                "part": "UNKNOWN",
                "raw_output": "",
                "parsed": None,
            })
            continue

        print("second model has been called for", idx, "image")
        start = time.perf_counter()
        result = extract_part_json(image, part_name, model, processor, device)
        end = time.perf_counter()
        print("total time taken by the second model", end - start, "sec")

        result["page_number"] = idx
        page_results.append(result)

    return merge_page_results(page_results)