Spaces:
Sleeping
Sleeping
| from langgraph.graph import StateGraph, END, START, add_messages | |
| from typing import TypedDict, Any, Annotated | |
| from PIL import Image | |
| from src.utils.utils_segment import ( | |
| preprocess, | |
| postprocess, | |
| extract_text, | |
| draw_bounding_boxes, | |
| ) | |
| from src.inference.segment_inference import model | |
| from src.config.llm import llm | |
| from src.prompt.promt import format_prompt | |
| from langchain_core.output_parsers import JsonOutputParser | |
| parser = JsonOutputParser() | |
| class State(TypedDict): | |
| image: Any | |
| image_origin: Any | |
| outputs_from_inference: Any | |
| text_extracted_from_ocr: Any | |
| threshold_confidence: float | |
| threshold_iou: float | |
| cropped_images: Any | |
| parser_output: bool | |
| image_with_bounding_boxes: Any | |
| _image: Annotated[Any, add_messages] | |
| crop_image: Any | |
| class N: | |
| PRE_PROCESS = "PRE_PROCESS" | |
| POST_PROCESS = "POST_PROCESS" | |
| INFERENCE = "INFERENCE" | |
| EXTRACT_TEXT_FROM_OCR = "EXTRACT_TEXT_FROM_OCR" | |
| PARSER_WITH_LLM = "PARSER_WITH_LLM" | |
| IMAGE_WITH_BOUNDING_BOXES = "IMAGE_WITH_BOUNDING_BOXES" | |
| workflow = StateGraph(State) | |
| def pre_process_fn(state: State): | |
| preprocess_img = preprocess(state["image_origin"]) | |
| print("preprocess_img", preprocess_img.shape) | |
| image_for_display = (preprocess_img[0] * 255).astype("uint8") | |
| image_for_display = image_for_display.transpose(1, 2, 0) | |
| image_show = Image.fromarray(image_for_display) | |
| return {"image": preprocess_img, "_image": image_show} | |
| def inference_fn(state: State): | |
| image = state["image"] | |
| outputs = model.run(None, {"images": image}) | |
| return {"outputs_from_inference": outputs} | |
| def post_process_fn(state: State): | |
| outputs = state["outputs_from_inference"] | |
| threshold_confidence = state["threshold_confidence"] | |
| threshold_iou = state["threshold_iou"] | |
| post_process_output = postprocess(outputs, threshold_confidence, threshold_iou) | |
| return { | |
| "outputs_from_inference": post_process_output, | |
| } | |
| def extract_text_from_ocr_fn(state: State): | |
| image_origin = state["image_origin"] | |
| output_from_inference = state["outputs_from_inference"] | |
| text = extract_text(output_from_inference, image_origin) | |
| return {"text_extracted_from_ocr": text} | |
| def draw_bounding_boxes_fn(state: State): | |
| image = state["image_origin"] | |
| outputs = state["outputs_from_inference"] | |
| image_with_bounding_boxes = draw_bounding_boxes(image, outputs) | |
| return {"image_with_bounding_boxes": image_with_bounding_boxes} | |
| def parser_output_fn(state: State): | |
| text_extracted_from_ocr = state["text_extracted_from_ocr"] | |
| chain = format_prompt | llm | parser | |
| response = chain.invoke({"user_input": text_extracted_from_ocr}) | |
| print(response) | |
| return {"parser_output": response} | |
| #NODE | |
| workflow.add_node(N.PRE_PROCESS, pre_process_fn) | |
| workflow.add_node(N.INFERENCE, inference_fn) | |
| workflow.add_node(N.POST_PROCESS, post_process_fn) | |
| workflow.add_node(N.EXTRACT_TEXT_FROM_OCR, extract_text_from_ocr_fn) | |
| workflow.add_node(N.IMAGE_WITH_BOUNDING_BOXES, draw_bounding_boxes_fn) | |
| workflow.add_node(N.PARSER_WITH_LLM, parser_output_fn) | |
| #EDGE | |
| workflow.add_edge(START, N.PRE_PROCESS) | |
| workflow.add_edge(N.PRE_PROCESS, N.INFERENCE) | |
| workflow.add_edge(N.INFERENCE, N.POST_PROCESS) | |
| workflow.add_edge(N.POST_PROCESS, N.IMAGE_WITH_BOUNDING_BOXES) | |
| workflow.add_edge(N.IMAGE_WITH_BOUNDING_BOXES, N.EXTRACT_TEXT_FROM_OCR) | |
| workflow.add_conditional_edges( | |
| N.EXTRACT_TEXT_FROM_OCR, | |
| lambda state: N.PARSER_WITH_LLM if state["parser_output"] else END, | |
| { | |
| N.PARSER_WITH_LLM: N.PARSER_WITH_LLM, | |
| END: END, | |
| }, | |
| ) | |
| workflow.add_edge(N.PARSER_WITH_LLM, END) | |
| app = workflow.compile() | |