File size: 3,642 Bytes
28b686e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from langgraph.graph import StateGraph, END, START, add_messages
from typing import TypedDict, Any, Annotated
from PIL import Image
from src.utils.utils_segment import (
    preprocess,
    postprocess,
    extract_text,
    draw_bounding_boxes,
)
from src.inference.segment_inference import model
from src.config.llm import llm
from src.prompt.promt import format_prompt
from langchain_core.output_parsers import JsonOutputParser

# Shared LangChain output parser that turns the LLM's text reply into parsed
# JSON; used as the final stage of the prompt | llm | parser chain in
# parser_output_fn.
parser = JsonOutputParser()


class State(TypedDict):
    """Shared state flowing through the segmentation / OCR / LLM-parsing graph.

    Each node returns a partial dict containing only the keys it updates.
    """

    # Model-ready tensor produced by pre_process_fn (via `preprocess`) and fed
    # to the model under the input name "images".
    image: Any
    # Original image supplied by the caller; used for preprocessing, OCR and
    # bounding-box drawing.
    image_origin: Any
    # Raw model outputs written by inference_fn, then replaced by the
    # `postprocess` result in post_process_fn.
    outputs_from_inference: Any
    # Result of `extract_text` over the post-processed outputs.
    text_extracted_from_ocr: Any
    # Thresholds forwarded to `postprocess` (confidence score and IoU).
    threshold_confidence: float
    threshold_iou: float
    # NOTE(review): not read or written by any node visible in this module.
    cropped_images: Any
    # Dual use: the caller sets it truthy to route to the LLM parser, and
    # parser_output_fn then overwrites it with the parsed JSON response, so it
    # is not strictly a bool.
    parser_output: Any
    # Original image annotated by `draw_bounding_boxes`.
    image_with_bounding_boxes: Any
    # PIL preview of the preprocessed image. Kept as a plain last-value key:
    # the previous `add_messages` reducer only accepts chat messages and would
    # fail as soon as a second write had to be merged with a stored image.
    _image: Any
    # NOTE(review): not read or written by any node visible in this module.
    crop_image: Any


class N:
    """Names under which each pipeline node is registered on ``workflow``.

    Listed in the order the nodes run.
    """

    PRE_PROCESS: str = "PRE_PROCESS"
    INFERENCE: str = "INFERENCE"
    POST_PROCESS: str = "POST_PROCESS"
    IMAGE_WITH_BOUNDING_BOXES: str = "IMAGE_WITH_BOUNDING_BOXES"
    EXTRACT_TEXT_FROM_OCR: str = "EXTRACT_TEXT_FROM_OCR"
    PARSER_WITH_LLM: str = "PARSER_WITH_LLM"


# Graph builder over the shared State schema; nodes and edges are registered
# further down and the graph is compiled into `app` at the end of the module.
workflow = StateGraph(State)


def pre_process_fn(state: State):
    """Preprocess the caller's image and build a displayable preview.

    Returns a partial state update: the model-ready array under ``"image"``
    and a PIL preview of its first batch item under ``"_image"``.
    """
    tensor = preprocess(state["image_origin"])
    print("preprocess_img", tensor.shape)
    # Take batch item 0, rescale by 255 and move channels last (CHW -> HWC)
    # for PIL. NOTE(review): assumes `preprocess` returns floats in [0, 1];
    # values outside that range would wrap when cast to uint8 — confirm.
    preview = Image.fromarray((tensor[0] * 255).astype("uint8").transpose(1, 2, 0))
    return {"image": tensor, "_image": preview}


def inference_fn(state: State):
    """Run the segmentation model on the preprocessed image.

    The tensor is fed under the input name ``"images"``, and the raw result of
    ``model.run`` is stored under ``"outputs_from_inference"``.
    """
    # NOTE(review): presumably `model` is an onnxruntime InferenceSession, where
    # passing None as output names requests every output — confirm.
    feeds = {"images": state["image"]}
    return {"outputs_from_inference": model.run(None, feeds)}


def post_process_fn(state: State):
    """Apply ``postprocess`` to the raw model outputs.

    The confidence and IoU thresholds are taken from state. The result
    overwrites ``"outputs_from_inference"``, so later nodes read the
    post-processed outputs rather than the raw ones.
    """
    raw = state["outputs_from_inference"]
    conf_thr = state["threshold_confidence"]
    iou_thr = state["threshold_iou"]
    return {"outputs_from_inference": postprocess(raw, conf_thr, iou_thr)}


def extract_text_from_ocr_fn(state: State):
    """Run ``extract_text`` on the post-processed outputs and the original image.

    The result is stored under ``"text_extracted_from_ocr"``.
    """
    origin = state["image_origin"]
    detections = state["outputs_from_inference"]
    return {"text_extracted_from_ocr": extract_text(detections, origin)}


def draw_bounding_boxes_fn(state: State):
    """Render the post-processed outputs onto the original image.

    The annotated image is stored under ``"image_with_bounding_boxes"``.
    """
    origin = state["image_origin"]
    detections = state["outputs_from_inference"]
    annotated = draw_bounding_boxes(origin, detections)
    return {"image_with_bounding_boxes": annotated}


def parser_output_fn(state: State):
    """Turn the OCR text into structured data with the prompt | LLM | JSON chain.

    The OCR text is passed as the prompt variable ``user_input``. The parsed
    response is printed and then stored under ``"parser_output"``, replacing
    the routing flag the caller supplied.
    """
    ocr_text = state["text_extracted_from_ocr"]
    structured = (format_prompt | llm | parser).invoke({"user_input": ocr_text})
    print(structured)
    return {"parser_output": structured}

def _route_after_ocr(state: State) -> str:
    """Pick the node to run after OCR.

    Goes to the LLM parser only when the caller set a truthy ``parser_output``
    flag; otherwise the run ends. ``.get`` avoids a KeyError when the caller
    does not supply the flag in the input state.
    """
    return N.PARSER_WITH_LLM if state.get("parser_output") else END


# Nodes
workflow.add_node(N.PRE_PROCESS, pre_process_fn)
workflow.add_node(N.INFERENCE, inference_fn)
workflow.add_node(N.POST_PROCESS, post_process_fn)
workflow.add_node(N.EXTRACT_TEXT_FROM_OCR, extract_text_from_ocr_fn)
workflow.add_node(N.IMAGE_WITH_BOUNDING_BOXES, draw_bounding_boxes_fn)
workflow.add_node(N.PARSER_WITH_LLM, parser_output_fn)

# Edges: preprocess -> inference -> postprocess -> draw boxes -> OCR,
# then optionally the LLM parser.
workflow.add_edge(START, N.PRE_PROCESS)
workflow.add_edge(N.PRE_PROCESS, N.INFERENCE)
workflow.add_edge(N.INFERENCE, N.POST_PROCESS)
workflow.add_edge(N.POST_PROCESS, N.IMAGE_WITH_BOUNDING_BOXES)
workflow.add_edge(N.IMAGE_WITH_BOUNDING_BOXES, N.EXTRACT_TEXT_FROM_OCR)
workflow.add_conditional_edges(
    N.EXTRACT_TEXT_FROM_OCR,
    _route_after_ocr,
    {
        N.PARSER_WITH_LLM: N.PARSER_WITH_LLM,
        END: END,
    },
)
workflow.add_edge(N.PARSER_WITH_LLM, END)


app = workflow.compile()