Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -46,6 +46,74 @@ text_to_image_client = InferenceClient(
|
|
| 46 |
api_key=token
|
| 47 |
)
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
@tool
|
| 50 |
def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
|
| 51 |
"""
|
|
@@ -54,7 +122,8 @@ def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
|
|
| 54 |
nsfw_detection_input (Image.Image): The image to check.
|
| 55 |
Returns:
|
| 56 |
str: Highest score result.
|
| 57 |
-
"""
|
|
|
|
| 58 |
try:
|
| 59 |
|
| 60 |
tmp_path = pil_to_tempfile(nsfw_detection_input)
|
|
@@ -105,7 +174,6 @@ def image_tool(prompt: str) -> str:
|
|
| 105 |
|
| 106 |
except Exception as e:
|
| 107 |
image_output = None
|
| 108 |
-
print(f"Image generation failed: {e}")
|
| 109 |
return f"Image generation failed: {e}"
|
| 110 |
|
| 111 |
@tool
|
|
@@ -135,6 +203,7 @@ model = InferenceClientModel(
|
|
| 135 |
agent = CodeAgent(
|
| 136 |
model=model,
|
| 137 |
tools=[
|
|
|
|
| 138 |
image_tool,
|
| 139 |
nsfw_detection_tool,
|
| 140 |
search_tool,
|
|
@@ -150,6 +219,8 @@ agent.prompt_templates["system_prompt"] += """
|
|
| 150 |
- search_tool(query: str) -> str
|
| 151 |
- Search the web and return the most relevant results.
|
| 152 |
- Used for sentiment analysis
|
|
|
|
|
|
|
| 153 |
- image_tool(prompt: str) -> str
|
| 154 |
- Generate an image from a text prompt, if successfull or not you will be notified by the return string.
|
| 155 |
- nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str
|
|
@@ -162,110 +233,154 @@ agent.prompt_templates["system_prompt"] += """
|
|
| 162 |
- You must use this answer in final_answer
|
| 163 |
"""
|
| 164 |
|
| 165 |
-
def run_agent(query, nsfw_detection_input):
|
| 166 |
global image_output
|
|
|
|
| 167 |
image_output = None
|
|
|
|
| 168 |
|
| 169 |
yield None, "⏳ Jerry is thinking... please wait"
|
| 170 |
|
| 171 |
try:
|
| 172 |
response = agent.run(
|
| 173 |
query if query else "",
|
| 174 |
-
additional_args={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
|
| 177 |
-
yield image_output, str(response)
|
| 178 |
|
| 179 |
except Exception as e:
|
| 180 |
-
yield None, f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
with gr.Blocks(title="Jerry AI Assistant") as demo:
|
| 183 |
gr.Markdown("# 🤖 Jerry - Your AI Assistant")
|
| 184 |
-
|
| 185 |
agent_response = gr.Textbox(
|
| 186 |
label="Response",
|
| 187 |
lines=5,
|
| 188 |
interactive=False
|
| 189 |
)
|
| 190 |
-
|
| 191 |
with gr.Tab("💬 Chat"):
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
placeholder="Generate an image of a cat, analyze its sentiment, etc.",
|
| 197 |
-
scale=4
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
with gr.Row():
|
| 201 |
-
run_chat_btn = gr.Button("🚀 Run", variant="primary", scale=1)
|
| 202 |
-
clear_chat_btn = gr.Button("🗑️ Clear", scale=0)
|
| 203 |
-
|
| 204 |
-
gr.Examples(
|
| 205 |
-
examples=[
|
| 206 |
-
"How do i cook a curry quickly",
|
| 207 |
-
"Analyze the sentiment: This is terrible service",
|
| 208 |
-
"Translate this text to English 读写汉字 - 学中文",
|
| 209 |
-
],
|
| 210 |
-
inputs=[query_chat],
|
| 211 |
-
label="💡 Try these:"
|
| 212 |
)
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
| 214 |
hidden_image_chat = gr.Image(visible=False)
|
| 215 |
-
|
| 216 |
run_chat_btn.click(
|
| 217 |
fn=run_agent,
|
| 218 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
outputs=[hidden_image_chat, agent_response]
|
| 220 |
)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
)
|
| 229 |
-
image_output = gr.Image(
|
| 230 |
-
label="Generated Image",
|
| 231 |
-
height=300
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
with gr.Row():
|
| 235 |
-
query_img = gr.Textbox(
|
| 236 |
-
lines=2,
|
| 237 |
-
label="Image generation prompt",
|
| 238 |
-
placeholder="A beautiful sunset over mountains..."
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
with gr.Row():
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
],
|
| 253 |
-
|
| 254 |
-
label="🎨 Try these prompts:"
|
| 255 |
)
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
hidden_text_img = gr.Textbox(visible=False)
|
| 258 |
hidden_image_img = gr.Image(visible=False)
|
| 259 |
|
| 260 |
check_nsfw_btn.click(
|
| 261 |
fn=run_agent,
|
| 262 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
outputs=[hidden_image_img, agent_response]
|
| 264 |
)
|
| 265 |
-
|
| 266 |
run_img_btn.click(
|
| 267 |
fn=run_agent,
|
| 268 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
outputs=[image_output, agent_response]
|
| 270 |
)
|
| 271 |
|
|
|
|
| 46 |
api_key=token
|
| 47 |
)
|
| 48 |
|
| 49 |
+
def resize_and_crop(image, target_res=(832, 480)):
    """
    Scale an image so it fully covers ``target_res`` and crop away the overflow.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` the way a
            PIL image does.
        target_res: ``(width, height)`` of the returned image.

    Returns:
        The result of ``image.crop`` on the resized image, i.e. an image of
        exactly ``target_res`` pixels.
    """
    tw, th = target_res
    iw, ih = image.size

    # Taking the larger ratio makes the scaled image cover the target on both axes.
    scale = max(tw / iw, th / ih)

    # iw * (tw / iw) can land a hair below tw in floating point, and int()
    # would then truncate to one pixel short of the crop box. Clamp so the
    # resized image is never smaller than the area that gets cropped out of it.
    nw = max(tw, int(iw * scale))
    nh = max(th, int(ih * scale))

    image = image.resize((nw, nh), Image.LANCZOS)

    # Horizontal overflow is always split evenly between both sides.
    left = (nw - tw) // 2

    if ih > iw:
        # Portrait sources: bias the crop toward the top by trimming only a
        # quarter of the vertical overflow from above.
        top = int((nh - th) * 0.25)
    else:
        top = (nh - th) // 2

    right = left + tw
    bottom = top + th

    return image.crop((left, top, right, bottom))
|
| 69 |
+
|
| 70 |
+
def aligned_num_frames(duration, fps=16):
    """
    Convert a duration in seconds to a frame count of the form ``4 * k + 1``.

    The raw count ``int(duration * fps)`` is rounded down to the nearest
    value of that form.

    Args:
        duration: Clip length in seconds.
        fps: Frames per second used for the conversion.

    Returns:
        int: A frame count that is always at least 1 and satisfies
        ``count % 4 == 1``.
    """
    # A zero or sub-frame duration used to give n == 0 and a result of -3;
    # clamp so the smallest answer is a single frame.
    n = max(1, int(duration * fps))
    return ((n - 1) // 4) * 4 + 1
|
| 73 |
+
|
| 74 |
+
# Module-level slot for the most recent video result. run_agent() resets it to
# None before each run and yields its value once the agent has finished;
# video_tool is meant to leave the path of the generated .mp4 here.
video_output = None
|
| 75 |
+
|
| 76 |
+
@tool
def video_tool(
    video_image_input: Image.Image,
    video_prompt: str,
    video_duration: float,
    video_steps: int,
    video_guidance: float,
    video_randomize: bool,
) -> str:
    """
    Generates a video from a starting image and a text prompt using Wan 2.1 via fal-ai.
    Args:
        video_image_input (Image.Image): The source image to be animated.
        video_prompt (str): A text description of the motion or scene to generate.
        video_duration (float): Length of the video in seconds.
        video_steps (int): The number of diffusion inference steps (higher is better quality).
        video_guidance (float): Classifier-free guidance scale for prompt adherence.
        video_randomize (bool): Whether to use a random seed for varied results.

    Returns:
        str: A confirmation message.
    """
    # Without this declaration the assignments below only bind a local name
    # and the generated file path never reaches run_agent().
    global video_output
    video_output = None

    # Report a missing image explicitly instead of surfacing an AttributeError
    # message from the resize call.
    if video_image_input is None:
        return "Video generation failed: no input image was provided."

    try:
        FPS = 16
        num_frames = aligned_num_frames(video_duration, FPS)
        # A fixed seed keeps results reproducible when randomisation is off.
        seed = random.randint(0, 1_000_000_000) if video_randomize else 42

        # NOTE(review): `client` is not defined in the visible part of the
        # file; presumably a module-level inference client - confirm.
        video_bytes = client.image_to_video(
            # Cover-and-crop keeps the aspect ratio; a bare resize() to
            # 832x480 would stretch any source that is not already 26:15.
            image=resize_and_crop(video_image_input, (832, 480)),
            prompt=video_prompt,
            negative_prompt="low quality, deformed",
            num_frames=num_frames,
            num_inference_steps=int(video_steps),
            seed=seed,
            guidance_scale=float(video_guidance),
        )

        # mktemp() only returns a name and is race-prone; create the file
        # atomically and keep it on disk (delete=False) so the UI can read it.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
            f.write(video_bytes)
            out = f.name

        gc.collect()
        video_output = out
        return "Video successfully generated and stored for Gradio UI."
    except Exception as e:
        # Tool boundary: hand the failure back to the agent as text rather
        # than raising.
        video_output = None
        return f"Video generation failed: {e}"
|
| 116 |
+
|
| 117 |
@tool
|
| 118 |
def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
|
| 119 |
"""
|
|
|
|
| 122 |
nsfw_detection_input (Image.Image): The image to check.
|
| 123 |
Returns:
|
| 124 |
str: Highest score result.
|
| 125 |
+
"""
|
| 126 |
+
global video_output
|
| 127 |
try:
|
| 128 |
|
| 129 |
tmp_path = pil_to_tempfile(nsfw_detection_input)
|
|
|
|
| 174 |
|
| 175 |
except Exception as e:
|
| 176 |
image_output = None
|
|
|
|
| 177 |
return f"Image generation failed: {e}"
|
| 178 |
|
| 179 |
@tool
|
|
|
|
| 203 |
agent = CodeAgent(
|
| 204 |
model=model,
|
| 205 |
tools=[
|
| 206 |
+
video_tool,
|
| 207 |
image_tool,
|
| 208 |
nsfw_detection_tool,
|
| 209 |
search_tool,
|
|
|
|
| 219 |
- search_tool(query: str) -> str
|
| 220 |
- Search the web and return the most relevant results.
|
| 221 |
- Used for sentiment analysis
|
| 222 |
+
- video_tool(video_image_input: Image.Imagem, video_prompt: str, video_duration: float, video_steps: int, video_guidance: float, video_randomize: bool) -> str
|
| 223 |
+
- Generate a video from a text prompt and an image input, if successfull or not you will be notified by the return string.
|
| 224 |
- image_tool(prompt: str) -> str
|
| 225 |
- Generate an image from a text prompt, if successfull or not you will be notified by the return string.
|
| 226 |
- nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str
|
|
|
|
| 233 |
- You must use this answer in final_answer
|
| 234 |
"""
|
| 235 |
|
| 236 |
+
def run_agent(query, nsfw_detection_input, video_image_input, video_prompt, video_duration, video_steps, video_guidance, video_randomize):
    """
    Run the agent for one UI request and stream ``(image, video, text)`` updates.

    Every yield carries the same three values, so each click handler must bind
    exactly three output components in that order.

    Args:
        query: Task text handed to the agent; a falsy value is sent as "".
        nsfw_detection_input: Forwarded to the agent as an additional argument.
        video_image_input: Forwarded to the agent as an additional argument.
        video_prompt: Forwarded to the agent as an additional argument.
        video_duration: Forwarded to the agent as an additional argument.
        video_steps: Forwarded to the agent as an additional argument.
        video_guidance: Forwarded to the agent as an additional argument.
        video_randomize: Forwarded to the agent as an additional argument.

    Yields:
        tuple: ``(None, None, status)`` first, then either
        ``(image_output, video_output, response_text)`` on success or
        ``(None, None, error_text)`` if the agent raised.
    """
    global image_output
    global video_output
    # Clear results left over from a previous request before the agent runs.
    image_output = None
    video_output = None

    # Three values here as well: a two-value status update cannot be mapped
    # onto the same outputs as the three-value results below.
    yield None, None, "⏳ Jerry is thinking... please wait"

    try:
        response = agent.run(
            query if query else "",
            additional_args={
                "nsfw_detection_input": nsfw_detection_input,
                "video_image_input": video_image_input,
                "video_prompt": video_prompt,
                "video_duration": video_duration,
                "video_steps": video_steps,
                "video_guidance": video_guidance,
                "video_randomize": video_randomize,
            }
        )

        yield image_output, video_output, str(response)

    except Exception as e:
        yield None, None, f"❌ Agent Error: {str(e)}"


with gr.Blocks(title="Jerry AI Assistant") as demo:
    gr.Markdown("# 🤖 Jerry - Your AI Assistant")

    # Invisible, always-empty placeholders for the run_agent() arguments a
    # given button does not supply. They are created inside the Blocks context
    # so they are registered with the app, and they are never used as outputs,
    # so they stay empty.
    hidden_none_img = gr.Image(visible=False)
    hidden_none_txt = gr.Textbox(visible=False)
    hidden_none_float = gr.Number(visible=False, value=0)
    hidden_none_bool = gr.Checkbox(visible=False, value=False)
    # Invisible sink for the video slot on tabs that do not show a video.
    hidden_none_vid = gr.Video(visible=False)

    agent_response = gr.Textbox(
        label="Response",
        lines=5,
        interactive=False
    )

    with gr.Tab("💬 Chat"):
        query_chat = gr.Textbox(
            lines=3,
            label="Ask me anything...",
            placeholder="Generate an image of a cat, analyze sentiment, etc."
        )

        run_chat_btn = gr.Button("🚀 Run", variant="primary")
        clear_chat_btn = gr.Button("🗑️ Clear")

        hidden_image_chat = gr.Image(visible=False)

        run_chat_btn.click(
            fn=run_agent,
            inputs=[
                query_chat,
                hidden_none_img,
                hidden_none_img,
                hidden_none_txt,
                hidden_none_float,
                hidden_none_float,
                hidden_none_float,
                hidden_none_bool,
            ],
            # One component per value yielded by run_agent: image, video, text.
            outputs=[hidden_image_chat, hidden_none_vid, agent_response]
        )

        clear_chat_btn.click(
            lambda: ("", ""),
            outputs=[query_chat, agent_response]
        )

    with gr.Tab("🎨 Video Tools"):
        with gr.Row():
            with gr.Column():
                video_image_input = gr.Image(type="pil", label="Input Image")
                video_prompt = gr.Textbox(lines=3, label="Prompt")

                video_duration = gr.Slider(1, 4, value=4, step=0.1, label="Duration (s)")
                video_steps = gr.Slider(4, 20, value=20, step=1, label="Steps")
                video_guidance = gr.Slider(1.0, 6.0, value=3.0, step=0.1, label="Guidance")
                video_randomize = gr.Checkbox(value=True, label="Randomize Seed")

                gen_btn = gr.Button("🎬 Generate Video", variant="primary")

            with gr.Column():
                output_vid = gr.Video(label="Generated Video")

        # Dedicated image sink: writing into hidden_none_img would leak an
        # image into every handler that uses it as an empty input.
        hidden_image_vid = gr.Image(visible=False)

        gen_btn.click(
            fn=run_agent,
            inputs=[
                hidden_none_txt,
                hidden_none_img,
                video_image_input,
                video_prompt,
                video_duration,
                video_steps,
                video_guidance,
                video_randomize,
            ],
            outputs=[hidden_image_vid, output_vid, agent_response]
        )

    with gr.Tab("🎨 Image Tools"):
        nsfw_detection_input = gr.Image(type="pil", label="Upload for NSFW Check")
        # NOTE(review): this rebinds the module-level name `image_output` that
        # run_agent() also uses as its result slot. The click handler below
        # captures the component object before any rebinding, so it still
        # works, but the shared name is easy to trip over.
        image_output = gr.Image(label="Generated Image")

        query_img = gr.Textbox(
            lines=2,
            label="Image generation prompt",
            placeholder="A cyberpunk cat with neon eyes..."
        )

        check_nsfw_btn = gr.Button("🔍 Check NSFW")
        run_img_btn = gr.Button("🎨 Generate Image", variant="primary")

        hidden_text_img = gr.Textbox(visible=False)
        hidden_image_img = gr.Image(visible=False)

        check_nsfw_btn.click(
            fn=run_agent,
            inputs=[
                hidden_text_img,
                nsfw_detection_input,
                hidden_none_img,
                hidden_none_txt,
                hidden_none_float,
                hidden_none_float,
                hidden_none_float,
                hidden_none_bool,
            ],
            outputs=[hidden_image_img, hidden_none_vid, agent_response]
        )

        run_img_btn.click(
            fn=run_agent,
            inputs=[
                query_img,
                hidden_none_img,
                hidden_none_img,
                hidden_none_txt,
                hidden_none_float,
                hidden_none_float,
                hidden_none_float,
                hidden_none_bool,
            ],
            outputs=[image_output, hidden_none_vid, agent_response]
        )
|
| 386 |
|