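"""Document-extraction pipeline for a vision-language model.

Each page of the input document is first classified into one of the known
sections (PART-1 ... PART-6) with a short classifier prompt, then re-prompted
with a part-specific prompt to extract structured JSON. The per-page results
are merged into a single JSON record.
"""
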
import json
import time

from model_loader import get_model
from processor_utils import load_input
from prompt import get_part_classifier_prompt, get_part_prompt

def _get_max_tokens(part_name):
    """Return the generation budget (max_new_tokens) for a given prompt type."""
    limits = {
        "CLASSIFIER": 20,
        "PART-1": 1200,
        "PART-2": 700,
        "PART-3": 1800,
        "PART-4": 500,
        "PART-5": 300,
        "PART-6": 100
    }
    return limits.get(part_name, 600)

def _clean_raw_text(text):
    """Strip Markdown code fences (```json ... ``` or ``` ... ```) from model output."""
    text = text.strip()
    if text.startswith("```json"):
        text = text[len("```json"):].strip()
    elif text.startswith("```"):
        text = text[len("```"):].strip()
    if text.endswith("```"):
        text = text[:-3].strip()
    return text
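
# Illustration with a hypothetical fenced reply:
#   _clean_raw_text('```json\n{"name": "A"}\n```')  ->  '{"name": "A"}'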

def _run_model(image, prompt_text, model, processor, device, max_new_tokens):
    """Run one image + prompt through the model and return the cleaned text output."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text}
            ]
        }
    ]
    # Build the chat-formatted prompt string (not tokenized yet).
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt"
    ).to(device)
    # Greedy decoding keeps the output deterministic.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False
    )
    # Drop the prompt tokens so only newly generated tokens are decoded.
    generated_ids = output[0][inputs.input_ids.shape[-1]:]
    response = processor.decode(
        generated_ids,
        skip_special_tokens=True
    ).strip()
    return _clean_raw_text(response)

def _extract_json_block(text):
    """Return the substring from the first '{' to the last '}', or None if absent."""
    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end == 0 or end <= start:
        return None
    return text[start:end]
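
# Illustration with a hypothetical model reply:
#   _extract_json_block('Here is the result: {"name": "A"} Done')  ->  '{"name": "A"}'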

def classify_page(image, model, processor, device):
    """Classify a page image into one of the known parts, or 'UNKNOWN'."""
    raw = _run_model(
        image,
        get_part_classifier_prompt(),
        model,
        processor,
        device,
        max_new_tokens=_get_max_tokens("CLASSIFIER")
    ).upper()
    valid_parts = ["PART-1", "PART-2", "PART-3", "PART-4", "PART-5", "PART-6"]
    for part in valid_parts:
        if part in raw:
            return part
    return "UNKNOWN"

def extract_part_json(image, part_name, model, processor, device):
    """Extract the JSON for one part; retry once with a larger budget if parsing fails."""
    max_tokens = _get_max_tokens(part_name)
    raw = _run_model(
        image,
        get_part_prompt(part_name),
        model,
        processor,
        device,
        max_new_tokens=max_tokens
    )
    json_block = _extract_json_block(raw)
    if json_block:
        try:
            parsed = json.loads(json_block)
            return {
                "status": "success",
                "part": part_name,
                "raw_output": raw,
                "parsed": parsed
            }
        except json.JSONDecodeError:
            pass
    # Retry once with a larger token budget, in case the JSON was truncated.
    retry_raw = _run_model(
        image,
        get_part_prompt(part_name),
        model,
        processor,
        device,
        max_new_tokens=max_tokens + 600
    )
    retry_json_block = _extract_json_block(retry_raw)
    if retry_json_block:
        try:
            parsed = json.loads(retry_json_block)
            return {
                "status": "success",
                "part": part_name,
                "raw_output": retry_raw,
                "parsed": parsed
            }
        except json.JSONDecodeError:
            pass
    # Both attempts failed to produce valid JSON; retry_raw is always defined here.
    return {
        "status": "error",
        "part": part_name,
        "raw_output": retry_raw,
        "parsed": None
    }

def merge_page_results(page_results):
    """Shallow-merge the parsed JSON from all successful pages into one dict."""
    final_json = {}
    for item in page_results:
        if item["status"] != "success" or not item["parsed"]:
            continue
        parsed = item["parsed"]
        for key, value in parsed.items():
            final_json[key] = value
    return final_json
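
# Note: the merge is shallow and keyed on the top-level JSON keys, so if two
# pages emit the same top-level key (e.g. both map to the same part), the
# later page's value overwrites the earlier one.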

def process_document(file_path):
    """Classify every page of the input document and merge the extracted JSON."""
    model, processor, device = get_model()
    pages = load_input(file_path)
    page_results = []
    for idx, image in enumerate(pages, start=1):
        print(f"Classification pass called for page {idx}")
        start = time.time()
        part_name = classify_page(image, model, processor, device)
        end = time.time()
        print(f"Classification took {end - start:.2f} sec")
        if part_name == "UNKNOWN":
            page_results.append({
                "page_number": idx,
                "status": "error",
                "part": "UNKNOWN",
                "raw_output": "",
                "parsed": None
            })
            continue
        print(f"Extraction pass called for page {idx} ({part_name})")
        start = time.time()
        result = extract_part_json(image, part_name, model, processor, device)
        end = time.time()
        print(f"Extraction took {end - start:.2f} sec")
        print(result)
        result["page_number"] = idx
        page_results.append(result)
    final_json = merge_page_results(page_results)
    return final_json
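
# Minimal usage sketch. "sample.pdf" is a placeholder path; load_input() is
# assumed to accept it and return a list of page images (one per PDF page).
if __name__ == "__main__":
    result = process_document("sample.pdf")
    print(json.dumps(result, indent=2, ensure_ascii=False))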