UbaidMajied committed on
Commit
373bc71
·
verified ·
1 Parent(s): a444054

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +432 -0
app.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END
2
+ from typing import TypedDict, Annotated, List
3
+ import operator
4
+ from langgraph.checkpoint.sqlite import SqliteSaver
5
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage
6
+ from langchain_core.runnables import chain
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.pydantic_v1 import BaseModel, Field
9
+ from langchain_core.output_parsers import JsonOutputParser
10
+ import base64
11
+ from langchain.chains import TransformChain
12
+ from google.colab import userdata
13
+ from IPython import display
14
+ import gradio as gr
15
+ from openai import OpenAI
16
+ from pydub import AudioSegment
17
+ from pathlib import Path
18
+ import os
19
+
20
+
21
# The OpenAI SDK reads OPENAI_API_KEY from the environment. Check it up front
# so a missing key produces an actionable error instead of the opaque
# "TypeError: str expected, not NoneType" that assigning None into os.environ
# raises.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if _openai_api_key is None:
    raise RuntimeError(
        "The OPENAI_API_KEY environment variable is not set; "
        "define it before starting the app."
    )
os.environ["OPENAI_API_KEY"] = _openai_api_key
22
+
23
def encode_image(image_path: str) -> str:
    """Read the file at *image_path* and return its bytes as base64 text.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file contents, decoded as a UTF-8 string.
    """
    raw_bytes = Path(image_path).read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
27
+
28
+
29
def load_image(inputs: dict) -> dict:
    """Encode the image referenced by ``inputs["image_path"]`` as base64.

    Args:
        inputs: Mapping that must contain the key ``"image_path"``.

    Returns:
        A one-entry dict ``{"image": <base64 string>}``.
    """
    return {"image": encode_image(inputs["image_path"])}
34
+
35
+
36
def get_open_ai_api_key() -> str:
    """Return the value stored under ``'OPEN_AI_API_KEY'`` in Colab ``userdata``.

    NOTE(review): the secret name used here ('OPEN_AI_API_KEY') differs from
    the environment variable read elsewhere in this file ('OPENAI_API_KEY');
    confirm which spelling is intended.
    """
    secret_name = 'OPEN_AI_API_KEY'
    return userdata.get(secret_name)
38
+
39
+
40
# Module-wide OpenAI client used for the chat-completion and text-to-speech
# calls below. Constructed without arguments, so the SDK picks up its
# credentials from the environment.
client = OpenAI()
41
+
42
# NOTE(review): this redefines the `encode_image` declared earlier in this
# file with the same behavior; one of the two definitions should be removed.
def encode_image(image_path: str) -> str:
    """Return the contents of the file at *image_path* as a base64 string.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file contents, decoded as UTF-8 text.
    """
    with open(image_path, "rb") as image_file:
        payload = image_file.read()
    return base64.b64encode(payload).decode('utf-8')
45
+
46
+
47
# NOTE(review): this redefines the `load_image` declared earlier in this file
# with the same behavior; one of the two definitions should be removed.
def load_image(inputs: dict) -> dict:
    """Load the image named by ``inputs["image_path"]`` and base64-encode it.

    Returns:
        ``{"image": <base64 string>}``.
    """
    source_path = inputs["image_path"]
    return {"image": encode_image(source_path)}
52
+
53
+
54
# Output schema for the opening question. The class docstring and the field
# description are part of the schema handed to the parser, so their wording
# is runtime text and must not be edited casually.
class GenerateQuestion(BaseModel):
    """Information about an image."""
    question: str = Field(
        description="A single, open-ended question to start the convesation",
    )


# JSON parser bound to the GenerateQuestion schema.
QUESTION_PARSER = JsonOutputParser(pydantic_object=GenerateQuestion)
59
+
60
+
61
# Output schema for the early follow-up turns. The class docstring and the
# field description are part of the schema handed to the parser (runtime text).
class GenerateQuestion2(BaseModel):
    """Information about an image and the user's responses."""
    acknowledgement_followback_question: str = Field(
        description="An acknowledgement to user's most recent input and a follow-up question to gather more information about the photograph.",
    )


# JSON parser bound to the GenerateQuestion2 schema.
QUESTION_PARSER_2 = JsonOutputParser(pydantic_object=GenerateQuestion2)
66
+
67
+
68
# Output schema for the later, conversation-expanding turns. The class
# docstring and the field description are part of the schema handed to the
# parser (runtime text).
class GenerateQuestion3(BaseModel):
    """Information about an image and the user's responses."""
    acknowledgement_followback_question: str = Field(
        description="An acknowledgement to user's most recent input and a follow-up question to expand on the conversation.",
    )


# JSON parser bound to the GenerateQuestion3 schema.
QUESTION_PARSER_3 = JsonOutputParser(pydantic_object=GenerateQuestion3)
73
+
74
# Output schema for the critique step: a critique plus a (possibly revised)
# question. The class docstring and field descriptions are part of the schema
# handed to the parser (runtime text).
class GenerateCritique(BaseModel):
    """Information about an image."""
    critique: str = Field(description="A Critique")
    question: str = Field(
        description="A revised reply and follow up question, if necessary",
    )


# JSON parser bound to the GenerateCritique schema.
CRITIQUE_PARSER = JsonOutputParser(pydantic_object=GenerateCritique)
80
+
81
@chain
def image_model(inputs: dict) -> str | list[str] | dict:
    """Send a prompt, the parser's format instructions and an image to gpt-4o.

    Args:
        inputs: Mapping providing ``"temperature"``, ``"prompt"``,
            ``"parser"`` (an object exposing ``get_format_instructions()``)
            and ``"image"`` (base64 image data embedded as a JPEG data URL).

    Returns:
        The ``content`` of the model's reply message.
    """
    llm = ChatOpenAI(temperature=inputs["temperature"], model="gpt-4o", max_tokens=1024)
    content_parts = [
        {"type": "text", "text": inputs["prompt"]},
        {"type": "text", "text": inputs["parser"].get_format_instructions()},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{inputs['image']}"}},
    ]
    reply = llm.invoke([HumanMessage(content=content_parts)])
    return reply.content
94
+
95
+
96
# Chain step that turns {"image_path": ...} into {"image": <base64>} by
# running `load_image`.
load_image_chain = TransformChain(
    transform=load_image,
    input_variables=["image_path"],
    output_variables=["image"],
)
101
+
102
+
103
def fast_thinking(image_path: str, prompt: str, parser, temperature) -> str:
    """Ask gpt-4o a single question about an image and return its raw reply.

    Args:
        image_path: Path of the image file to send.
        prompt: Text prompt sent alongside the image.
        parser: Accepted for call-site compatibility but not used here; the
            reply is returned as plain text, not parsed.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The text content of the first completion choice.
    """
    import mimetypes  # local import: only needed to label the data URL

    encoded_image = encode_image(image_path)

    # Label the data URL with the file's real image type (uploads may be PNG,
    # WebP, ...). Fall back to the previously hard-coded JPEG when the type
    # cannot be determined from the file name.
    guessed_type = mimetypes.guess_type(str(image_path))[0]
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        mime_type = "image/jpeg"

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{encoded_image}",
                            "detail": "auto",
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt,
                    },
                ],
            },
        ],
        temperature=temperature,
        max_tokens=1024,
    )
    return response.choices[0].message.content
132
+
133
def get_story(image_path: str, prompt: str, temperature) -> str:
    """Ask gpt-4o to write a story about an image and return the raw reply.

    This issues the same single-image, single-prompt request as
    `fast_thinking`, so it delegates to it rather than duplicating the
    request-building code.

    Args:
        image_path: Path of the image file to send.
        prompt: Story-generation prompt sent alongside the image.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The text content of the model's reply.
    """
    # `fast_thinking` does not use its parser argument, so None is sufficient.
    return fast_thinking(image_path, prompt, None, temperature)
162
+
163
+
164
+
165
class AgentState(TypedDict):
    # State dictionary passed between the graph nodes.
    image_path: str        # image the questions are about
    prompt: str            # prompt used by the first question node
    critique_prompt: str   # template with a "{question}" placeholder
    question1: str         # output of the first question node
    question2: str         # question produced by the critique node
    critique: str          # critique produced by the critique node
    temperature: float     # sampling temperature for the model calls
173
+
174
+
175
def generate_question_node1(state: AgentState):
    """Graph node: generate the first question for the image in *state*.

    `fast_thinking` returns the model's reply as plain text, so the reply is
    stored directly. Indexing it with ``["question"]``, as before, raised
    ``TypeError`` because a ``str`` cannot be indexed by a string key.

    Returns:
        A partial state update ``{"question1": <model reply>}``.
    """
    question = fast_thinking(
        state["image_path"],
        state["prompt"],
        QUESTION_PARSER_2,
        state["temperature"],
    )
    return {"question1": question}
178
+
179
+
180
def question_critique_node(state: AgentState):
    """Graph node: critique ``question1`` and produce a revised question.

    `fast_thinking` returns plain text, so the prompt is extended with
    CRITIQUE_PARSER's format instructions and the reply is parsed into a dict
    before its fields are read. Indexing the raw string, as before, raised
    ``TypeError``.

    Returns:
        A partial state update with ``"critique"`` and ``"question2"``.

    Raises:
        Whatever ``CRITIQUE_PARSER.parse`` raises when the reply is not
        valid JSON.
    """
    critique_prompt = state["critique_prompt"].format(question=state["question1"])
    # Append the instructions after .format() so the braces in the JSON
    # schema are not treated as placeholders.
    critique_prompt += "\n\n" + CRITIQUE_PARSER.get_format_instructions()

    raw_reply = fast_thinking(
        state["image_path"], critique_prompt, CRITIQUE_PARSER, state["temperature"]
    )
    parsed = CRITIQUE_PARSER.parse(raw_reply)

    return {
        "critique": parsed.get("critique", ""),
        # Keep the original question if the model supplied no revision.
        "question2": parsed.get("question", state["question1"]),
    }
184
+
185
+
186
+ # builder = StateGraph(AgentState)
187
+ # builder.add_node("question_generator1", generate_question_node1)
188
+ # builder.add_node("question_critique", question_critique_node)
189
+ # builder.add_edge("question_generator1", "question_critique")
190
+ # builder.set_entry_point("question_generator1")
191
+ # graph = builder.compile()
192
+ # display.Image(graph.get_graph().draw_png())
193
+
194
+
195
def slow_thinking(image_path: str, prompt: str, critique_prompt: str, temperature):
    """Run the two-step question-then-critique graph and return its state.

    The graph runs `generate_question_node1`, then `question_critique_node`,
    then finishes.

    Args:
        image_path: Path of the image the question is about.
        prompt: Prompt for the first question node.
        critique_prompt: Template (with a ``{question}`` placeholder) for the
            critique node.
        temperature: Sampling temperature for both model calls.

    Returns:
        The final graph state as returned by ``graph.invoke``.
    """
    builder = StateGraph(AgentState)
    builder.add_node("question_generator1", generate_question_node1)
    builder.add_node("question_critique", question_critique_node)
    builder.set_entry_point("question_generator1")
    builder.add_edge("question_generator1", "question_critique")
    # Explicitly end the graph after the critique step; without this edge the
    # last node has no successor.
    builder.add_edge("question_critique", END)
    graph = builder.compile()

    initial_state = {
        'image_path': image_path,
        'prompt': prompt,
        'critique_prompt': critique_prompt,
        'temperature': temperature,
    }
    final_state = graph.invoke(
        initial_state,
        config={"configurable": {"thread_id": 1}},
    )
    return final_state
211
+
212
+
213
+
214
def transform_text_to_speech(text: str):
    """Convert *text* to speech and return an autoplaying HTML audio player.

    The speech is requested as MP3 from the ``tts-1`` model, converted to WAV,
    and embedded in the returned HTML as a base64 data URL. The conversion is
    done in memory: the previous implementation wrote fixed ``speech.mp3`` and
    ``speech.wav`` files in the working directory, which concurrent requests
    would overwrite.

    Args:
        text: The text to speak.

    Returns:
        An HTML ``<audio>`` snippet, or an empty string when *text* is empty
        or whitespace-only (nothing to speak, so no API call is made).
    """
    import io  # local import: in-memory buffers for the audio conversion

    if not text or not text.strip():
        return ""

    # Generate speech from the text.
    response = client.audio.speech.create(
        model="tts-1",
        voice="onyx",
        input=text,
    )

    # Convert MP3 to WAV without touching the filesystem.
    mp3_buffer = io.BytesIO(response.content)
    wav_buffer = io.BytesIO()
    AudioSegment.from_mp3(mp3_buffer).export(wav_buffer, format="wav")

    audio_base64 = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')

    # Create an HTML audio player with autoplay.
    audio_html = f"""
    <audio controls autoplay>
    <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
    Your browser does not support the audio element.
    </audio>
    """
    return audio_html
244
+
245
# Prompt for the very first question. Placeholders: {role}, {rules}.
CONVERSATION_STARTER_PROMPT = """
### Role
{role}

### Context
The user is an older person who has uploaded a photograph. Your goal is to start a meaningful and inviting conversation about the photo.

### Objective
Ask a simple first question that encourages the user to start talking about the photograph based on the below rules.

### Guidelines
Follow these rules while generating the question:
{rules}

### Output
Provide:
- A single, open-ended question based on the above rules.

Note: Output should be in 1 to 2 lines. Please don't generate anything else.
"""

# Prompt for the early follow-up turns. Placeholders: {role}, {history}, {rules}.
CONVERSATION_STARTER2_PROMPT = """
### Role
{role}

### Context
The user is an older person who has uploaded a photo, and you are at the start of a conversation about it.
Here is the conversation history about the photo between the user and you (Good friend):
{history}

### Objective
Respond to user's most recent input in the conversation history above and a follow-up question generated based on below rules.

### Guidelines
Follow these rules while generating the follow up question:
{rules}

### Output
Provide:
- Respond to user's most recent input in the conversation history above and a follow-up question generated based on above rules.

Note: Output should be in 2 to 3 lines. Please don't generate anything else.
"""


# Prompt for the later turns. Placeholders: {role}, {history}, {rules}.
CONVERSATION_EXPANDING_PROMPT = """
### Role
{role}

### Context
The user is an older person who has uploaded a photo, and you are in the middle of a conversation about it.
Here is the conversation history about the photo between the user and you (Good friend), reflecting the ongoing dialogue:
{history}

### Objective
Respond to user's most recent input in the conversation history above and a follow-up question generated based on below rules

### Guidelines
Follow these rules while generating the follow up question:
{rules}

### Output
Provide:
- Respond to user's most recent input in the conversation history above and a follow-up question generated based on above rules.

Note: Output should be in 2 to 3 lines. Please don't generate anything else.
"""


# Prompt for turning the conversation into a story. Placeholder: {conversation}.
generate_story_prompt = """
Given a photograph uploaded by the user and a conversation between a good friend and the user about the photograph:

{conversation}

Instructions:
1. Create a short story that captures the essence of the conversation about the photograph.
2. Do not invent new details—base the story entirely on the provided conversation.

Provide:
1. A concise story in three sentences.

Note: Please generated only story.
"""
328
+
329
+
330
+
331
# Module-level conversation state, mutated by pred() and clear().
# NOTE(review): these are process-wide globals, so every visitor of the app
# shares one conversation; `iter` also shadows the builtin of the same name.
memory = ""      # running transcript ("Good Friend: ..." / "User: ..." lines)
iter = 1         # 1-based number of the next question to generate
image_path = ""  # image the current conversation is about
334
+
335
def pred(image_input, role, conversation_starter_prompt_rules, conversation_starter2_prompt_rules, conversation_expanding_prompt_rules, temperature, reply):
    """Produce the agent's next question for the uploaded photograph.

    Turn 1 asks an opening question. Turns 2-3 use the "starter 2" prompt and
    later turns use the "expanding" prompt; both include the conversation
    history. Uploading a different image restarts the conversation.

    The user's reply is now added to the history before the prompt is built.
    Previously it was appended only after the model call, so the model never
    saw the reply it was asked to respond to.

    Args:
        image_input: File path of the uploaded image (or None if missing).
        role: Text substituted for ``{role}`` in the prompts.
        conversation_starter_prompt_rules: Rules for question 1.
        conversation_starter2_prompt_rules: Rules for questions 2 and 3.
        conversation_expanding_prompt_rules: Rules for later questions.
        temperature: Sampling temperature for the model.
        reply: The user's answer to the previous question (ignored on turn 1).

    Returns:
        ``(thinking_type, agent_text, audio_html)``.
    """
    # Nothing can be asked without an image; answer without touching state
    # or calling the API.
    if not image_input:
        return "Fast", "Please upload an image to start the conversation.", ""

    global memory, iter, image_path

    # A different image starts a fresh conversation.
    if image_path != image_input:
        image_path = image_input
        iter = 1
        memory = ""

    if iter == 1:
        history = memory
        parser = QUESTION_PARSER
        prompt = CONVERSATION_STARTER_PROMPT.format(role=role, rules=conversation_starter_prompt_rules)
    else:
        # Include the latest reply so the model can actually respond to it.
        history = memory + "\n" + "User: " + (reply or "")
        if iter <= 3:
            template = CONVERSATION_STARTER2_PROMPT
            rules = conversation_starter2_prompt_rules
            parser = QUESTION_PARSER_2
        else:
            template = CONVERSATION_EXPANDING_PROMPT
            rules = conversation_expanding_prompt_rules
            parser = QUESTION_PARSER_3
        prompt = template.format(role=role, history=history, rules=rules)

    agent_text = fast_thinking(image_path, prompt, parser, temperature)

    # Commit the turn only after the model call succeeded, so a failed call
    # can be retried without duplicating the user's reply in the transcript.
    memory = history + "\n" + "Good Friend: " + agent_text
    iter += 1
    return "Fast", agent_text, transform_text_to_speech(agent_text)
369
+ # Slow Thinking
370
+ # else:
371
+ # prompt = CONVERSATION_EXPANDING_PROMPT.format(history=memory)
372
+ # critique_prompt = CONVERSATION_EXPANDING_PROMPT_CRITIQUE.format(question="{question}", history=memory)
373
+ # res = slow_thinking(image_path, prompt, critique_prompt, temperature)
374
+ # question = res['question2']
375
+ # memory += "\n" + "User: " + reply
376
+ # memory += "\n" + "Good Friend: "+ question
377
+ # iter += 1
378
+ # return "Slow", res["question1"], res["critique"], res["question2"]
379
+
380
def generate_story(image_input):
    """Turn the conversation so far into a short spoken story.

    Args:
        image_input: Supplied by the UI but not used; the story is generated
            for the module-level ``image_path`` and ``memory``.

    Returns:
        ``(thinking_type, text, audio_html)``. While fewer than three
        questions have been asked (``iter < 4``) the text is a fixed
        "no content" message instead of a story.
    """
    global memory, iter, image_path, generate_story_prompt

    if iter >= 4:
        story_prompt = generate_story_prompt.format(conversation=memory)
        story = get_story(image_path, story_prompt, 0.5)
        return "Fast", story, transform_text_to_speech(story)

    return "Fast", "No Solid Content to generate a Story", transform_text_to_speech("No Solid Content to generate a Story")
391
+
392
def clear():
    """Reset the conversation state and return blank values for the UI.

    Returns:
        ``(None, "", "", None)``, wired to the image input, thinking type,
        agent output and audio player components.
    """
    global memory, iter, image_path

    memory, iter, image_path = "", 1, ""
    return (None, "", "", None)
402
+
403
+
404
+
405
+ # Gradio Interface
406
with gr.Blocks(title="Experimental Setup for Kitchentable.AI") as demo:
    with gr.Row():
        # Left column: the image plus the prompt settings.
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload an Image")
            role = gr.Textbox(label="Role")
            conversation_starter_prompt_rules = gr.Textbox(label="Conversation starter prompt rules(Generates question 1)")
            conversation_starter2_prompt_rules = gr.Textbox(label="Conversation starter2 prompt rules(Generates questions 2, 3)")
            conversation_expanding_prompt_rules = gr.Textbox(label="Conversation expanding prompt rules(Generates question after 3)")
            temperature = gr.Slider(minimum=0, maximum=0.9999, step=0.01, label="Temperature")

        # Right column: the agent's output and the user's controls.
        with gr.Column():
            thinkingType = gr.Textbox(label="Thinking Type")
            question = gr.Textbox(label="Agent Output")
            audio_output = gr.HTML(label="Audio Player")
            reply = gr.Textbox(label="Your reply to the question")
            # Each button gets its own elem_id: the three buttons previously
            # shared "Submit", producing duplicate ids in the page.
            submit_button = gr.Button("Submit Reply", elem_id="submit-reply")
            Generate_story = gr.Button("Generate Story", elem_id="generate-story")
            reset_setup = gr.Button("Reset Setup", elem_id="reset-setup")
            # critique = gr.Textbox(label="Agent Fast Thinking question Critique")
            # question2 = gr.Textbox(label="Agent Slow Thinking Question")

    submit_button.click(
        pred,
        inputs=[
            image_input,
            role,
            conversation_starter_prompt_rules,
            conversation_starter2_prompt_rules,
            conversation_expanding_prompt_rules,
            temperature,
            reply,
        ],
        outputs=[thinkingType, question, audio_output],
    )
    Generate_story.click(generate_story, inputs=[image_input], outputs=[thinkingType, question, audio_output])
    reset_setup.click(clear, inputs=[], outputs=[image_input, thinkingType, question, audio_output])

# Launch the interface
demo.launch(share=True)
432
+