wolf1997 commited on
Commit
264fc54
·
verified ·
1 Parent(s): 4b77121

Update receipt_gen_agent.py

Browse files
Files changed (1) hide show
  1. receipt_gen_agent.py +259 -259
receipt_gen_agent.py CHANGED
@@ -1,260 +1,260 @@
1
- from langchain_core.output_parsers import JsonOutputParser
2
- from langchain_core.prompts import PromptTemplate
3
- from dotenv import load_dotenv
4
- import os
5
- from typing import List
6
- from typing_extensions import TypedDict
7
- from langchain_core.messages import HumanMessage
8
- from langchain_google_genai import ChatGoogleGenerativeAI
9
- from langchain.output_parsers import RetryOutputParser
10
- from langgraph.graph import StateGraph, START, END
11
- import base64
12
- from IPython.display import Image as img, display
13
- from langchain_core.runnables.graph import MermaidDrawMethod
14
- from langgraph.checkpoint.memory import MemorySaver
15
- import json
16
- from pydantic import BaseModel, Field
17
- from io import BytesIO
18
- load_dotenv()
19
- GEMINI_API_KEY=os.getenv('google_api_key')
20
-
21
-
22
- GEMINI_MODEL='gemini-2.0-flash'
23
- llm = ChatGoogleGenerativeAI(google_api_key=GEMINI_API_KEY, model=GEMINI_MODEL, temperature=0.3)
24
-
25
- from os import listdir
26
- from os.path import isfile, join
27
-
28
-
29
- class State(TypedDict):
30
- prompt: str
31
- image_number: int
32
- image_data: json
33
- image_byte: str
34
- eval: dict
35
- n_retries:int
36
- image_name: str
37
- image_data_list: list
38
-
39
-
40
- def generate_data_node(state:State):
41
- class Items(BaseModel):
42
- name: str = Field(description='the name of the item')
43
- price : float = Field(description='the price of the item')
44
- quantity: int = Field(description='the quantity of the item')
45
-
46
- class Form(BaseModel):
47
- loc_name: str = Field(description='the name of the location if no name put empty str')
48
- address: str = Field(description='the address of the location if no location put empty str')
49
- date: str = Field(description='the date if no date put empty str')
50
- time: str = Field(description='the time if no time put empty str')
51
- items: List[Items] = Field(description= 'list of the items if no items put empty list')
52
- subtotal: float = Field(description= 'the subtotal if no subtotal put 0')
53
- tax: float = Field(description='the tax, if no tax put 0')
54
- total: float = Field(description='the total amount if no total amount put 0')
55
-
56
-
57
- parser=JsonOutputParser(pydantic_object=Form)
58
- instruction=parser.get_format_instructions()
59
- message = HumanMessage(
60
- content=[
61
- {"type": "text", "text": f"{state.get('prompt')}"+'\n\n'+ instruction},
62
- {
63
- "type": "image_url",
64
- "image_url": {"url": f"data:image/jpeg;base64,{state.get('image_byte')}"},
65
- },
66
- ],
67
- )
68
- response=llm.invoke([message])
69
- try:
70
- response=parser.parse(response.content)
71
- return {'image_data':response}
72
- except:
73
- prompt = PromptTemplate(
74
- template="Answer the user query.\n{format_instructions}\n{query}\n",
75
- input_variables=["query"],
76
- partial_variables={"format_instructions": parser.get_format_instructions()},
77
- )
78
- retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
79
- prompt_value=prompt.format_prompt(query=f'{state.get('prompt')}')
80
- response=retry_parser.parse_with_prompt(response.content, prompt_value)
81
- return {'image_data':response}
82
-
83
- def evaluate_node(state:State):
84
-
85
-
86
- class Decision(BaseModel):
87
- decision: str = Field(description='good or modify if changes have to be made')
88
- comment: str = Field(description='the changes to make')
89
-
90
- parser=JsonOutputParser(pydantic_object=Decision)
91
- prompt = PromptTemplate(
92
- template="Answer the user query.\n{format_instructions}\n{query}\n",
93
- input_variables=["query"],
94
- partial_variables={"format_instructions": parser.get_format_instructions()},
95
- )
96
- data=state.get('image_data')
97
- query=f" is the {data} correct and makes sense tell the llm what to change, ignore missing data, don't make it up, no explanation or decription needed"
98
- chain = prompt | llm
99
- response=chain.invoke({'query':query})
100
- try:
101
- response=parser.parse(response.content)
102
- except:
103
-
104
- retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
105
-
106
- prompt_value = prompt.format_prompt(query=query)
107
- response=retry_parser.parse_with_prompt(response.content, prompt_value)
108
- return {'eval': response}
109
-
110
-
111
- def data_editor_node(state:State):
112
- class Items(BaseModel):
113
- name: str = Field(description='the name of the item')
114
- price : float = Field(description='the price of the item')
115
- quantity: int = Field(description='the quantity of the item')
116
-
117
- class Form(BaseModel):
118
- loc_name: str = Field(description='the name of the location if no name put empty str')
119
- address: str = Field(description='the address of the location if no location put empty str')
120
- date: str = Field(description='the date if no date put empty str')
121
- time: str = Field(description='the time if no time put empty str')
122
- items: List[Items] = Field(description= 'list of the items if no items put empty list')
123
- subtotal: float = Field(description= 'the subtotal if no subtotal put 0')
124
- tax: float = Field(description='the tax, if no tax put 0')
125
- total: float = Field(description='the total amount if no total amount put 0')
126
-
127
-
128
- parser=JsonOutputParser(pydantic_object=Form)
129
- prompt = PromptTemplate(
130
- template="Answer the user query.\n{format_instructions}\n{query}\n",
131
- input_variables=["query"],
132
- partial_variables={"format_instructions": parser.get_format_instructions()},
133
- )
134
-
135
-
136
- data=state.get('image_data')
137
- query=f"modify this dict: {data} based on these comments {state.get('eval').get('comment')}, return a json"
138
- chain = prompt | llm
139
- response=chain.invoke({'query':query})
140
- try:
141
- response=parser.parse(response.content)
142
- except:
143
-
144
- retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
145
-
146
- prompt_value = prompt.format_prompt(query=query)
147
- response=retry_parser.parse_with_prompt(response.content, prompt_value)
148
- return {'image_data': response,
149
- 'n_retries':state.get('n_retries')+1}
150
-
151
-
152
- def should_continue(state:State)-> str:
153
- """
154
- Determine whether the research process should continue based on the current state.
155
-
156
- Args:
157
- state: The current state of the agent.
158
-
159
- Returns:
160
- str: The next state to transition to ("to_add_data", "to_prompt_editor").
161
- """
162
- eval=state.get('eval').get('decision')
163
- if eval =='good':
164
- return 'to_add_data'
165
-
166
- elif eval =='modify' and state.get('n_retries')<2:
167
- return 'to_data_editor'
168
- else:
169
- return 'to_add_data'
170
-
171
-
172
- def add_data_node(state:State):
173
- img_number=state.get('image_number')
174
- return {
175
- 'n_retries':0,
176
-
177
- 'image_name':f'{img_number}_new_receipt.jpg'}
178
-
179
- class receipt_agent:
180
- def __init__(self):
181
- self.agent=self._setup()
182
- def _setup(self):
183
-
184
- agent_builder=StateGraph(State)
185
- agent_builder.add_node('generate_data',generate_data_node)
186
- agent_builder.add_node('evaluate',evaluate_node)
187
- agent_builder.add_node('add_data',add_data_node)
188
- agent_builder.add_node('data_editor',data_editor_node)
189
-
190
- agent_builder.add_edge(START,'generate_data')
191
- agent_builder.add_edge('generate_data','evaluate')
192
- # agent_builder.add_edge('evaluate',END)
193
- agent_builder.add_conditional_edges('evaluate', should_continue, {'to_data_editor':'data_editor', 'to_add_data':'add_data'},)
194
- agent_builder.add_edge('data_editor','evaluate')
195
- agent_builder.add_edge('add_data', END)
196
-
197
-
198
- checkpointer=MemorySaver()
199
-
200
- agent=agent_builder.compile(checkpointer=checkpointer)
201
- return agent
202
-
203
-
204
- def display_graph(self):
205
- return display(
206
- img(
207
- self.agent.get_graph().draw_mermaid_png(
208
- draw_method=MermaidDrawMethod.API,
209
- )
210
- )
211
- )
212
-
213
- def get_state(self, state_val:str):
214
- config = {"configurable": {"thread_id": "1"}}
215
- return self.agent.get_state(config).values[state_val]
216
-
217
- def receipt_gen(self,image):
218
- config = {"configurable": {"thread_id": "1"}}
219
- buffered=BytesIO()
220
-
221
- image.save(buffered, format='JPEG')
222
- image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
223
-
224
- data_list = [f for f in listdir('new_receipt_data') if isfile(join('new_receipt_data', f))]
225
- if not data_list:
226
- data_list=[]
227
- else:
228
- with open(f'new_receipt_data/{data_list[0]}', 'r') as openfile:
229
- # Reading from json file
230
- data_list = json.load(openfile)
231
-
232
- response=self.agent.invoke({'prompt':'analyse this receipt and list the items, return a json',
233
- 'n_retries':0,
234
- 'image_number':len(data_list),
235
- 'image_byte': image_data,
236
- 'image_data_list':data_list}, config)
237
-
238
- image_data=response.get('image_data')
239
- return image_data
240
-
241
- def update_state(self, values:dict):
242
- config = {"configurable": {"thread_id": "1"}}
243
- return self.agent.update_state(config,values=values)
244
-
245
- def confirm(self,image_data):
246
- config = {"configurable": {"thread_id": "1"}}
247
- if image_data:
248
- data_list=self.agent.get_state(config).values['image_data_list']
249
- img_number=self.agent.get_state(config).values['image_number']
250
- image_name=self.agent.get_state(config).values['image_name']
251
- if not data_list:
252
- data_list=[]
253
- data_list.append({'receipt_name':f'{img_number}_new_receipt.jpg',
254
- 'receipt_data':image_data})
255
- self.agent.update_state(config,values={'image_data_list':data_list})
256
-
257
-
258
- return data_list,image_name
259
-
260
 
 
1
+ from langchain_core.output_parsers import JsonOutputParser
2
+ from langchain_core.prompts import PromptTemplate
3
+ from dotenv import load_dotenv
4
+ import os
5
+ from typing import List
6
+ from typing_extensions import TypedDict
7
+ from langchain_core.messages import HumanMessage
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain.output_parsers import RetryOutputParser
10
+ from langgraph.graph import StateGraph, START, END
11
+ import base64
12
+ from IPython.display import Image as img, display
13
+ from langchain_core.runnables.graph import MermaidDrawMethod
14
+ from langgraph.checkpoint.memory import MemorySaver
15
+ import json
16
+ from pydantic import BaseModel, Field
17
+ from io import BytesIO
18
+ load_dotenv()
19
+ GEMINI_API_KEY=os.getenv('google_api_key')
20
+
21
+
22
+ GEMINI_MODEL='gemini-2.0-flash'
23
+ llm = ChatGoogleGenerativeAI(google_api_key=GEMINI_API_KEY, model=GEMINI_MODEL, temperature=0.3)
24
+
25
+ from os import listdir
26
+ from os.path import isfile, join
27
+
28
+
29
+ class State(TypedDict):
30
+ prompt: str
31
+ image_number: int
32
+ image_data: json
33
+ image_byte: str
34
+ eval: dict
35
+ n_retries:int
36
+ image_name: str
37
+ image_data_list: list
38
+
39
+
40
+ def generate_data_node(state:State):
41
+ class Items(BaseModel):
42
+ name: str = Field(description='the name of the item')
43
+ price : float = Field(description='the price of the item')
44
+ quantity: int = Field(description='the quantity of the item')
45
+
46
+ class Form(BaseModel):
47
+ loc_name: str = Field(description='the name of the location if no name put empty str')
48
+ address: str = Field(description='the address of the location if no location put empty str')
49
+ date: str = Field(description='the date if no date put empty str')
50
+ time: str = Field(description='the time if no time put empty str')
51
+ items: List[Items] = Field(description= 'list of the items if no items put empty list')
52
+ subtotal: float = Field(description= 'the subtotal if no subtotal put 0')
53
+ tax: float = Field(description='the tax, if no tax put 0')
54
+ total: float = Field(description='the total amount if no total amount put 0')
55
+
56
+
57
+ parser=JsonOutputParser(pydantic_object=Form)
58
+ instruction=parser.get_format_instructions()
59
+ message = HumanMessage(
60
+ content=[
61
+ {"type": "text", "text": f"{state.get('prompt')}"+'\n\n'+ instruction},
62
+ {
63
+ "type": "image_url",
64
+ "image_url": {"url": f"data:image/jpeg;base64,{state.get('image_byte')}"},
65
+ },
66
+ ],
67
+ )
68
+ response=llm.invoke([message])
69
+ try:
70
+ response=parser.parse(response.content)
71
+ return {'image_data':response}
72
+ except:
73
+ prompt = PromptTemplate(
74
+ template="Answer the user query.\n{format_instructions}\n{query}\n",
75
+ input_variables=["query"],
76
+ partial_variables={"format_instructions": parser.get_format_instructions()},
77
+ )
78
+ retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
79
+ prompt_value=prompt.format_prompt(query=f"{state.get('prompt')}")
80
+ response=retry_parser.parse_with_prompt(response.content, prompt_value)
81
+ return {'image_data':response}
82
+
83
+ def evaluate_node(state:State):
84
+
85
+
86
+ class Decision(BaseModel):
87
+ decision: str = Field(description='good or modify if changes have to be made')
88
+ comment: str = Field(description='the changes to make')
89
+
90
+ parser=JsonOutputParser(pydantic_object=Decision)
91
+ prompt = PromptTemplate(
92
+ template="Answer the user query.\n{format_instructions}\n{query}\n",
93
+ input_variables=["query"],
94
+ partial_variables={"format_instructions": parser.get_format_instructions()},
95
+ )
96
+ data=state.get('image_data')
97
+ query=f" is the {data} correct and makes sense tell the llm what to change, ignore missing data, don't make it up, no explanation or decription needed"
98
+ chain = prompt | llm
99
+ response=chain.invoke({'query':query})
100
+ try:
101
+ response=parser.parse(response.content)
102
+ except:
103
+
104
+ retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
105
+
106
+ prompt_value = prompt.format_prompt(query=query)
107
+ response=retry_parser.parse_with_prompt(response.content, prompt_value)
108
+ return {'eval': response}
109
+
110
+
111
+ def data_editor_node(state:State):
112
+ class Items(BaseModel):
113
+ name: str = Field(description='the name of the item')
114
+ price : float = Field(description='the price of the item')
115
+ quantity: int = Field(description='the quantity of the item')
116
+
117
+ class Form(BaseModel):
118
+ loc_name: str = Field(description='the name of the location if no name put empty str')
119
+ address: str = Field(description='the address of the location if no location put empty str')
120
+ date: str = Field(description='the date if no date put empty str')
121
+ time: str = Field(description='the time if no time put empty str')
122
+ items: List[Items] = Field(description= 'list of the items if no items put empty list')
123
+ subtotal: float = Field(description= 'the subtotal if no subtotal put 0')
124
+ tax: float = Field(description='the tax, if no tax put 0')
125
+ total: float = Field(description='the total amount if no total amount put 0')
126
+
127
+
128
+ parser=JsonOutputParser(pydantic_object=Form)
129
+ prompt = PromptTemplate(
130
+ template="Answer the user query.\n{format_instructions}\n{query}\n",
131
+ input_variables=["query"],
132
+ partial_variables={"format_instructions": parser.get_format_instructions()},
133
+ )
134
+
135
+
136
+ data=state.get('image_data')
137
+ query=f"modify this dict: {data} based on these comments {state.get('eval').get('comment')}, return a json"
138
+ chain = prompt | llm
139
+ response=chain.invoke({'query':query})
140
+ try:
141
+ response=parser.parse(response.content)
142
+ except:
143
+
144
+ retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
145
+
146
+ prompt_value = prompt.format_prompt(query=query)
147
+ response=retry_parser.parse_with_prompt(response.content, prompt_value)
148
+ return {'image_data': response,
149
+ 'n_retries':state.get('n_retries')+1}
150
+
151
+
152
+ def should_continue(state:State)-> str:
153
+ """
154
+ Determine whether the research process should continue based on the current state.
155
+
156
+ Args:
157
+ state: The current state of the agent.
158
+
159
+ Returns:
160
+ str: The next state to transition to ("to_add_data", "to_prompt_editor").
161
+ """
162
+ eval=state.get('eval').get('decision')
163
+ if eval =='good':
164
+ return 'to_add_data'
165
+
166
+ elif eval =='modify' and state.get('n_retries')<2:
167
+ return 'to_data_editor'
168
+ else:
169
+ return 'to_add_data'
170
+
171
+
172
+ def add_data_node(state:State):
173
+ img_number=state.get('image_number')
174
+ return {
175
+ 'n_retries':0,
176
+
177
+ 'image_name':f'{img_number}_new_receipt.jpg'}
178
+
179
+ class receipt_agent:
180
+ def __init__(self):
181
+ self.agent=self._setup()
182
+ def _setup(self):
183
+
184
+ agent_builder=StateGraph(State)
185
+ agent_builder.add_node('generate_data',generate_data_node)
186
+ agent_builder.add_node('evaluate',evaluate_node)
187
+ agent_builder.add_node('add_data',add_data_node)
188
+ agent_builder.add_node('data_editor',data_editor_node)
189
+
190
+ agent_builder.add_edge(START,'generate_data')
191
+ agent_builder.add_edge('generate_data','evaluate')
192
+ # agent_builder.add_edge('evaluate',END)
193
+ agent_builder.add_conditional_edges('evaluate', should_continue, {'to_data_editor':'data_editor', 'to_add_data':'add_data'},)
194
+ agent_builder.add_edge('data_editor','evaluate')
195
+ agent_builder.add_edge('add_data', END)
196
+
197
+
198
+ checkpointer=MemorySaver()
199
+
200
+ agent=agent_builder.compile(checkpointer=checkpointer)
201
+ return agent
202
+
203
+
204
+ def display_graph(self):
205
+ return display(
206
+ img(
207
+ self.agent.get_graph().draw_mermaid_png(
208
+ draw_method=MermaidDrawMethod.API,
209
+ )
210
+ )
211
+ )
212
+
213
+ def get_state(self, state_val:str):
214
+ config = {"configurable": {"thread_id": "1"}}
215
+ return self.agent.get_state(config).values[state_val]
216
+
217
+ def receipt_gen(self,image):
218
+ config = {"configurable": {"thread_id": "1"}}
219
+ buffered=BytesIO()
220
+
221
+ image.save(buffered, format='JPEG')
222
+ image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
223
+
224
+ data_list = [f for f in listdir('new_receipt_data') if isfile(join('new_receipt_data', f))]
225
+ if not data_list:
226
+ data_list=[]
227
+ else:
228
+ with open(f'new_receipt_data/{data_list[0]}', 'r') as openfile:
229
+ # Reading from json file
230
+ data_list = json.load(openfile)
231
+
232
+ response=self.agent.invoke({'prompt':'analyse this receipt and list the items, return a json',
233
+ 'n_retries':0,
234
+ 'image_number':len(data_list),
235
+ 'image_byte': image_data,
236
+ 'image_data_list':data_list}, config)
237
+
238
+ image_data=response.get('image_data')
239
+ return image_data
240
+
241
+ def update_state(self, values:dict):
242
+ config = {"configurable": {"thread_id": "1"}}
243
+ return self.agent.update_state(config,values=values)
244
+
245
+ def confirm(self,image_data):
246
+ config = {"configurable": {"thread_id": "1"}}
247
+ if image_data:
248
+ data_list=self.agent.get_state(config).values['image_data_list']
249
+ img_number=self.agent.get_state(config).values['image_number']
250
+ image_name=self.agent.get_state(config).values['image_name']
251
+ if not data_list:
252
+ data_list=[]
253
+ data_list.append({'receipt_name':f'{img_number}_new_receipt.jpg',
254
+ 'receipt_data':image_data})
255
+ self.agent.update_state(config,values={'image_data_list':data_list})
256
+
257
+
258
+ return data_list,image_name
259
+
260