ibyteohdear commited on
Commit
52d433d
·
verified ·
1 Parent(s): da96fc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -67
app.py CHANGED
@@ -46,6 +46,74 @@ text_to_image_client = InferenceClient(
46
  api_key=token
47
  )
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  @tool
50
  def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
51
  """
@@ -54,7 +122,8 @@ def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
54
  nsfw_detection_input (Image.Image): The image to check.
55
  Returns:
56
  str: Highest score result.
57
- """
 
58
  try:
59
 
60
  tmp_path = pil_to_tempfile(nsfw_detection_input)
@@ -105,7 +174,6 @@ def image_tool(prompt: str) -> str:
105
 
106
  except Exception as e:
107
  image_output = None
108
- print(f"Image generation failed: {e}")
109
  return f"Image generation failed: {e}"
110
 
111
  @tool
@@ -135,6 +203,7 @@ model = InferenceClientModel(
135
  agent = CodeAgent(
136
  model=model,
137
  tools=[
 
138
  image_tool,
139
  nsfw_detection_tool,
140
  search_tool,
@@ -150,6 +219,8 @@ agent.prompt_templates["system_prompt"] += """
150
  - search_tool(query: str) -> str
151
  - Search the web and return the most relevant results.
152
  - Used for sentiment analysis
 
 
153
  - image_tool(prompt: str) -> str
154
  - Generate an image from a text prompt, if successfull or not you will be notified by the return string.
155
  - nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str
@@ -162,110 +233,154 @@ agent.prompt_templates["system_prompt"] += """
162
  - You must use this answer in final_answer
163
  """
164
 
165
- def run_agent(query, nsfw_detection_input):
166
  global image_output
 
167
  image_output = None
 
168
 
169
  yield None, "⏳ Jerry is thinking... please wait"
170
 
171
  try:
172
  response = agent.run(
173
  query if query else "",
174
- additional_args={"nsfw_detection_input": nsfw_detection_input}
 
 
 
 
 
 
 
 
175
  )
176
 
177
- yield image_output, str(response)
178
 
179
  except Exception as e:
180
- yield None, f"��� Agent Error: {str(e)}"
 
 
 
 
 
 
181
 
182
  with gr.Blocks(title="Jerry AI Assistant") as demo:
183
  gr.Markdown("# 🤖 Jerry - Your AI Assistant")
184
-
185
  agent_response = gr.Textbox(
186
  label="Response",
187
  lines=5,
188
  interactive=False
189
  )
190
-
191
  with gr.Tab("💬 Chat"):
192
- with gr.Row():
193
- query_chat = gr.Textbox(
194
- lines=3,
195
- label="Ask me anything...",
196
- placeholder="Generate an image of a cat, analyze its sentiment, etc.",
197
- scale=4
198
- )
199
-
200
- with gr.Row():
201
- run_chat_btn = gr.Button("🚀 Run", variant="primary", scale=1)
202
- clear_chat_btn = gr.Button("🗑️ Clear", scale=0)
203
-
204
- gr.Examples(
205
- examples=[
206
- "How do i cook a curry quickly",
207
- "Analyze the sentiment: This is terrible service",
208
- "Translate this text to English 读写汉字 - 学中文",
209
- ],
210
- inputs=[query_chat],
211
- label="💡 Try these:"
212
  )
213
-
 
 
 
214
  hidden_image_chat = gr.Image(visible=False)
215
-
216
  run_chat_btn.click(
217
  fn=run_agent,
218
- inputs=[query_chat, hidden_image_chat],
 
 
 
 
 
 
 
 
 
219
  outputs=[hidden_image_chat, agent_response]
220
  )
221
-
222
- with gr.Tab("🎨 Image Tools"):
223
- with gr.Row():
224
- nsfw_detection_input = gr.Image(
225
- label="Upload for NSFW check",
226
- type="pil",
227
- height=300
228
- )
229
- image_output = gr.Image(
230
- label="Generated Image",
231
- height=300
232
- )
233
-
234
- with gr.Row():
235
- query_img = gr.Textbox(
236
- lines=2,
237
- label="Image generation prompt",
238
- placeholder="A beautiful sunset over mountains..."
239
- )
240
-
241
  with gr.Row():
242
- check_nsfw_btn = gr.Button("🔍 Check NSFW")
243
- run_img_btn = gr.Button("🎨 Generate Image", variant="primary")
244
-
245
- gr.Examples(
246
- examples=[
247
- "A cyberpunk cat with neon glowing eyes",
248
- "A serene Japanese garden with cherry blossoms",
249
- "A futuristic city with flying cars at sunset",
250
- "A magical forest with bioluminescent plants",
251
- "A steampunk robot drinking tea in a Victorian parlor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  ],
253
- inputs=[query_img],
254
- label="🎨 Try these prompts:"
255
  )
256
-
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  hidden_text_img = gr.Textbox(visible=False)
258
  hidden_image_img = gr.Image(visible=False)
259
 
260
  check_nsfw_btn.click(
261
  fn=run_agent,
262
- inputs=[hidden_text_img, nsfw_detection_input],
 
 
 
 
 
 
 
 
 
263
  outputs=[hidden_image_img, agent_response]
264
  )
265
-
266
  run_img_btn.click(
267
  fn=run_agent,
268
- inputs=[query_img, hidden_image_img],
 
 
 
 
 
 
 
 
 
269
  outputs=[image_output, agent_response]
270
  )
271
 
 
46
  api_key=token
47
  )
48
 
49
+ def resize_and_crop(image, target_res=(832, 480)):
50
+ tw, th = target_res
51
+ iw, ih = image.size
52
+
53
+ scale = max(tw / iw, th / ih)
54
+ nw, nh = int(iw * scale), int(ih * scale)
55
+
56
+ image = image.resize((nw, nh), Image.LANCZOS)
57
+
58
+ left = (nw - tw) // 2
59
+
60
+ if ih > iw:
61
+ top = int((nh - th) * 0.25)
62
+ else:
63
+ top = (nh - th) // 2
64
+
65
+ right = left + tw
66
+ bottom = top + th
67
+
68
+ return image.crop((left, top, right, bottom))
69
+
70
+ def aligned_num_frames(duration, fps=16):
71
+ n = int(duration * fps)
72
+ return ((n - 1) // 4) * 4 + 1
73
+
74
+ video_output = None
75
+
76
+ @tool
77
+ def video_tool(video_image_input, video_prompt, video_duration, video_steps, video_guidance, video_randomize) -> str:
78
+ """
79
+ Generates a video from a starting image and a text prompt using Wan 2.1 via fal-ai.
80
+ Args:
81
+ video_image_input (Image.Image): The source image to be animated.
82
+ video_prompt (str): A text description of the motion or scene to generate.
83
+ video_duration (float): Length of the video in seconds.
84
+ video_steps (int): The number of diffusion inference steps (higher is better quality).
85
+ video_guidance (float): Classifier-free guidance scale for prompt adherence.
86
+ video_randomize (bool): Whether to use a random seed for varied results.
87
+
88
+ Returns:
89
+ str: A confirmation message.
90
+ """
91
+ try:
92
+ FPS = 16
93
+ num_frames = aligned_num_frames(video_duration, FPS)
94
+ seed = random.randint(0, 1_000_000_000) if video_randomize else 42
95
+
96
+ video_bytes = client.image_to_video(
97
+ image=video_image_input.resize((832, 480)),
98
+ prompt=video_prompt,
99
+ negative_prompt="low quality, deformed",
100
+ num_frames=num_frames,
101
+ num_inference_steps=int(video_steps),
102
+ seed=seed,
103
+ guidance_scale=float(video_guidance),
104
+ )
105
+
106
+ out = tempfile.mktemp(suffix=".mp4")
107
+ with open(out, "wb") as f:
108
+ f.write(video_bytes)
109
+
110
+ gc.collect()
111
+ video_output = out
112
+ return "Video successfully generated and stored for Gradio UI."
113
+ except Exception as e:
114
+ video_output = None
115
+ return f"Video generation failed: {e}"
116
+
117
  @tool
118
  def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
119
  """
 
122
  nsfw_detection_input (Image.Image): The image to check.
123
  Returns:
124
  str: Highest score result.
125
+ """
126
+ global video_output
127
  try:
128
 
129
  tmp_path = pil_to_tempfile(nsfw_detection_input)
 
174
 
175
  except Exception as e:
176
  image_output = None
 
177
  return f"Image generation failed: {e}"
178
 
179
  @tool
 
203
  agent = CodeAgent(
204
  model=model,
205
  tools=[
206
+ video_tool,
207
  image_tool,
208
  nsfw_detection_tool,
209
  search_tool,
 
219
  - search_tool(query: str) -> str
220
  - Search the web and return the most relevant results.
221
  - Used for sentiment analysis
222
+ - video_tool(video_image_input: Image.Imagem, video_prompt: str, video_duration: float, video_steps: int, video_guidance: float, video_randomize: bool) -> str
223
+ - Generate a video from a text prompt and an image input, if successfull or not you will be notified by the return string.
224
  - image_tool(prompt: str) -> str
225
  - Generate an image from a text prompt, if successfull or not you will be notified by the return string.
226
  - nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str
 
233
  - You must use this answer in final_answer
234
  """
235
 
236
+ def run_agent(query, nsfw_detection_input, video_image_input, video_prompt, video_duration, video_steps, video_guidance, video_randomize):
237
  global image_output
238
+ global video_output
239
  image_output = None
240
+ video_output = None
241
 
242
  yield None, "⏳ Jerry is thinking... please wait"
243
 
244
  try:
245
  response = agent.run(
246
  query if query else "",
247
+ additional_args={
248
+ "nsfw_detection_input": nsfw_detection_input,
249
+ "video_image_input": video_image_input,
250
+ "video_prompt": video_prompt,
251
+ "video_duration": video_duration,
252
+ "video_steps": video_steps,
253
+ "video_guidance": video_guidance,
254
+ "video_randomize": video_randomize,
255
+ }
256
  )
257
 
258
+ yield image_output, video_output, str(response)
259
 
260
  except Exception as e:
261
+ yield None, None, f" Agent Error: {str(e)}"
262
+
263
+
264
+ hidden_none_img = gr.Image(visible=False)
265
+ hidden_none_txt = gr.Textbox(visible=False)
266
+ hidden_none_float = gr.Number(visible=False, value=0)
267
+ hidden_none_bool = gr.Checkbox(visible=False, value=False)
268
 
269
  with gr.Blocks(title="Jerry AI Assistant") as demo:
270
  gr.Markdown("# 🤖 Jerry - Your AI Assistant")
271
+
272
  agent_response = gr.Textbox(
273
  label="Response",
274
  lines=5,
275
  interactive=False
276
  )
277
+
278
  with gr.Tab("💬 Chat"):
279
+ query_chat = gr.Textbox(
280
+ lines=3,
281
+ label="Ask me anything...",
282
+ placeholder="Generate an image of a cat, analyze sentiment, etc."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  )
284
+
285
+ run_chat_btn = gr.Button("🚀 Run", variant="primary")
286
+ clear_chat_btn = gr.Button("🗑️ Clear")
287
+
288
  hidden_image_chat = gr.Image(visible=False)
289
+
290
  run_chat_btn.click(
291
  fn=run_agent,
292
+ inputs=[
293
+ query_chat,
294
+ hidden_none_img,
295
+ hidden_none_img,
296
+ hidden_none_txt,
297
+ hidden_none_float,
298
+ hidden_none_float,
299
+ hidden_none_float,
300
+ hidden_none_bool,
301
+ ],
302
  outputs=[hidden_image_chat, agent_response]
303
  )
304
+
305
+ clear_chat_btn.click(
306
+ lambda: ("", ""),
307
+ outputs=[query_chat, agent_response]
308
+ )
309
+
310
+ with gr.Tab("🎨 Video Tools"):
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  with gr.Row():
312
+ with gr.Column():
313
+ video_image_input = gr.Image(type="pil", label="Input Image")
314
+ video_prompt = gr.Textbox(lines=3, label="Prompt")
315
+
316
+ video_duration = gr.Slider(1, 4, value=4, step=0.1, label="Duration (s)")
317
+ video_steps = gr.Slider(4, 20, value=20, step=1, label="Steps")
318
+ video_guidance = gr.Slider(1.0, 6.0, value=3.0, step=0.1, label="Guidance")
319
+ video_randomize = gr.Checkbox(value=True, label="Randomize Seed")
320
+
321
+ gen_btn = gr.Button("🎬 Generate Video", variant="primary")
322
+
323
+ with gr.Column():
324
+ output_vid = gr.Video(label="Generated Video")
325
+
326
+ gen_btn.click(
327
+ fn=run_agent,
328
+ inputs=[
329
+ hidden_none_txt,
330
+ hidden_none_img,
331
+ video_image_input,
332
+ video_prompt,
333
+ video_duration,
334
+ video_steps,
335
+ video_guidance,
336
+ video_randomize,
337
  ],
338
+ outputs=[hidden_none_img, output_vid, agent_response]
 
339
  )
340
+
341
+ with gr.Tab("🎨 Image Tools"):
342
+ nsfw_detection_input = gr.Image(type="pil", label="Upload for NSFW Check")
343
+ image_output = gr.Image(label="Generated Image")
344
+
345
+ query_img = gr.Textbox(
346
+ lines=2,
347
+ label="Image generation prompt",
348
+ placeholder="A cyberpunk cat with neon eyes..."
349
+ )
350
+
351
+ check_nsfw_btn = gr.Button("🔍 Check NSFW")
352
+ run_img_btn = gr.Button("🎨 Generate Image", variant="primary")
353
+
354
  hidden_text_img = gr.Textbox(visible=False)
355
  hidden_image_img = gr.Image(visible=False)
356
 
357
  check_nsfw_btn.click(
358
  fn=run_agent,
359
+ inputs=[
360
+ hidden_text_img,
361
+ nsfw_detection_input,
362
+ hidden_none_img,
363
+ hidden_none_txt,
364
+ hidden_none_float,
365
+ hidden_none_float,
366
+ hidden_none_float,
367
+ hidden_none_bool,
368
+ ],
369
  outputs=[hidden_image_img, agent_response]
370
  )
371
+
372
  run_img_btn.click(
373
  fn=run_agent,
374
+ inputs=[
375
+ query_img,
376
+ hidden_none_img,
377
+ hidden_none_img,
378
+ hidden_none_txt,
379
+ hidden_none_float,
380
+ hidden_none_float,
381
+ hidden_none_float,
382
+ hidden_none_bool,
383
+ ],
384
  outputs=[image_output, agent_response]
385
  )
386