ibyteohdear committed on
Commit
4c2b737
·
verified ·
1 Parent(s): ffd9403

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -28
app.py CHANGED
@@ -11,6 +11,7 @@ from huggingface_hub import InferenceClient
11
  import tempfile
12
  from PIL import Image
13
  import concurrent.futures
 
14
 
15
  def run_with_timeout(func, timeout=300, *args, **kwargs):
16
  """Run function with timeout"""
@@ -28,7 +29,28 @@ def pil_to_tempfile(image):
28
  image.save(tmp_path, format="PNG")
29
  return tmp_path
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  token = os.getenv("HF_TOKEN")
 
 
32
 
33
  client = InferenceClient(token=token)
34
 
@@ -48,25 +70,6 @@ text_to_image_client = InferenceClient(
48
  api_key=token
49
  )
50
 
51
- def resize_and_crop(image, target_res=(832, 480)):
52
- tw, th = target_res
53
- iw, ih = image.size
54
- scale = max(tw / iw, th / ih)
55
- nw, nh = int(iw * scale), int(ih * scale)
56
- image = image.resize((nw, nh), Image.LANCZOS)
57
- left = (nw - tw) // 2
58
- if ih > iw:
59
- top = int((nh - th) * 0.25)
60
- else:
61
- top = (nh - th) // 2
62
- right = left + tw
63
- bottom = top + th
64
- return image.crop((left, top, right, bottom))
65
-
66
- def aligned_num_frames(duration, fps=16):
67
- n = int(duration * fps)
68
- return ((n - 1) // 4) * 4 + 1
69
-
70
  image_output = None
71
  video_output = None
72
 
@@ -81,7 +84,8 @@ def video_tool(
81
  prompt: str = "high quality, detailed, sharp, cinematic",
82
  duration: float = 4,
83
  steps: int = 20,
84
- guidance: float = 3.0
 
85
  ) -> str:
86
  """
87
  Generates a video from a starting image using Wan 2.1.
@@ -97,6 +101,8 @@ def video_tool(
97
  global video_output
98
 
99
  try:
 
 
100
  FPS = 12
101
  num_frames = aligned_num_frames(duration, FPS)
102
 
@@ -113,10 +119,10 @@ def video_tool(
113
 
114
  video_bytes = run_with_timeout(generate_video, timeout=300)
115
 
 
116
  out = tempfile.mktemp(suffix=".mp4")
117
  with open(out, "wb") as f:
118
- f.write(video_bytes)
119
-
120
  video_output = out
121
  return "Video successfully generated and stored for Gradio UI."
122
 
@@ -125,7 +131,7 @@ def video_tool(
125
  return f"Video generation failed: {e}"
126
 
127
  @tool
128
- def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
129
  """
130
  Suitable for filtering through score explicit or inappropriate content in images.
131
  Args:
@@ -134,6 +140,8 @@ def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
134
  str: Highest score result.
135
  """
136
  try:
 
 
137
  tmp_path = pil_to_tempfile(nsfw_detection_input)
138
 
139
  outputs = client.image_classification(
@@ -149,13 +157,14 @@ def nsfw_detection_tool(nsfw_detection_input: Image.Image) -> str:
149
  f"Confidence: {top_result.score:.2%}"
150
  )
151
 
 
152
  return verdict
153
 
154
  except Exception as e:
155
  return f"NSFW detection failed: {e}"
156
 
157
  @tool
158
- def image_tool(prompt: str) -> str:
159
  """
160
  Generate an image from text using SD3-Medium.
161
  Args:
@@ -166,6 +175,8 @@ def image_tool(prompt: str) -> str:
166
  global image_output
167
 
168
  try:
 
 
169
  def generate_image():
170
  return text_to_image_client.text_to_image(
171
  prompt=prompt,
@@ -176,6 +187,7 @@ def image_tool(prompt: str) -> str:
176
  height=992
177
  )
178
 
 
179
  image = run_with_timeout(generate_image, timeout=300)
180
  image_output = image
181
  return "Image successfully generated and stored for Gradio UI."
@@ -186,7 +198,7 @@ def image_tool(prompt: str) -> str:
186
 
187
 
188
  @tool
189
- def search_tool(query: str) -> str:
190
  """
191
  Search the web and return the most relevant results.
192
  Args:
@@ -195,7 +207,11 @@ def search_tool(query: str) -> str:
195
  str: The search results.
196
  """
197
  try:
 
 
198
  web_search_tool = DuckDuckGoSearchTool(max_results=5, rate_limit=2.0)
 
 
199
  results = web_search_tool(query)
200
  return results
201
 
@@ -252,13 +268,14 @@ def run_agent(
252
  video_prompt_param="",
253
  video_duration_param=4.0,
254
  video_steps_param=20,
255
- video_guidance_param=3.0,
 
256
  ):
257
  global image_output, video_output
258
  image_output = None
259
  video_output = None
260
 
261
- yield None, None, "Jerry is thinking... please wait"
262
 
263
  try:
264
 
@@ -285,6 +302,8 @@ def run_agent(
285
  }
286
  )
287
 
 
 
288
  yield image_output, video_output, str(response)
289
 
290
  except Exception as e:
@@ -429,7 +448,7 @@ with gr.Blocks(title="Jerry AI Assistant") as demo:
429
  ],
430
  outputs=[image_output_display, gr.Video(visible=False), agent_response]
431
  )
432
-
433
  if __name__ == "__main__":
434
  demo.launch(
435
  server_name="0.0.0.0",
 
11
  import tempfile
12
  from PIL import Image
13
  import concurrent.futures
14
+ from fastapi import FastAPI
15
 
16
  def run_with_timeout(func, timeout=300, *args, **kwargs):
17
  """Run function with timeout"""
 
29
  image.save(tmp_path, format="PNG")
30
  return tmp_path
31
 
32
+ def resize_and_crop(image, target_res=(832, 480)):
33
+ tw, th = target_res
34
+ iw, ih = image.size
35
+ scale = max(tw / iw, th / ih)
36
+ nw, nh = int(iw * scale), int(ih * scale)
37
+ image = image.resize((nw, nh), Image.LANCZOS)
38
+ left = (nw - tw) // 2
39
+ if ih > iw:
40
+ top = int((nh - th) * 0.25)
41
+ else:
42
+ top = (nh - th) // 2
43
+ right = left + tw
44
+ bottom = top + th
45
+ return image.crop((left, top, right, bottom))
46
+
47
+ def aligned_num_frames(duration, fps=16):
48
+ n = int(duration * fps)
49
+ return ((n - 1) // 4) * 4 + 1
50
+
51
  token = os.getenv("HF_TOKEN")
52
+ if not token:
53
+ raise RuntimeError("Please set HF_TOKEN environment variable")
54
 
55
  client = InferenceClient(token=token)
56
 
 
70
  api_key=token
71
  )
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  image_output = None
74
  video_output = None
75
 
 
84
  prompt: str = "high quality, detailed, sharp, cinematic",
85
  duration: float = 4,
86
  steps: int = 20,
87
+ guidance: float = 3.0,
88
+ progress: gr.Progress = gr.Progress(),
89
  ) -> str:
90
  """
91
  Generates a video from a starting image using Wan 2.1.
 
101
  global video_output
102
 
103
  try:
104
+ progress(0.125, desc="Performing diffusion… this might take awhile..")
105
+
106
  FPS = 12
107
  num_frames = aligned_num_frames(duration, FPS)
108
 
 
119
 
120
  video_bytes = run_with_timeout(generate_video, timeout=300)
121
 
122
+ progress(0.90, desc="Exporting..")
123
  out = tempfile.mktemp(suffix=".mp4")
124
  with open(out, "wb") as f:
125
+ f.write(video_bytes)
 
126
  video_output = out
127
  return "Video successfully generated and stored for Gradio UI."
128
 
 
131
  return f"Video generation failed: {e}"
132
 
133
  @tool
134
+ def nsfw_detection_tool(nsfw_detection_input: Image.Image, progress: gr.Progress = gr.Progress(),) -> str:
135
  """
136
  Suitable for filtering through score explicit or inappropriate content in images.
137
  Args:
 
140
  str: Highest score result.
141
  """
142
  try:
143
+ progress(0.125, desc="Checking image..")
144
+
145
  tmp_path = pil_to_tempfile(nsfw_detection_input)
146
 
147
  outputs = client.image_classification(
 
157
  f"Confidence: {top_result.score:.2%}"
158
  )
159
 
160
+ progress(0.90, desc="Returning verdict..")
161
  return verdict
162
 
163
  except Exception as e:
164
  return f"NSFW detection failed: {e}"
165
 
166
  @tool
167
+ def image_tool(prompt: str, progress: gr.Progress = gr.Progress(),) -> str:
168
  """
169
  Generate an image from text using SD3-Medium.
170
  Args:
 
175
  global image_output
176
 
177
  try:
178
+ progress(0.125, desc="Generating image.. this might take awhile...")
179
+
180
  def generate_image():
181
  return text_to_image_client.text_to_image(
182
  prompt=prompt,
 
187
  height=992
188
  )
189
 
190
+ progress(0.90, desc="Exporting..")
191
  image = run_with_timeout(generate_image, timeout=300)
192
  image_output = image
193
  return "Image successfully generated and stored for Gradio UI."
 
198
 
199
 
200
  @tool
201
+ def search_tool(query: str, progress: gr.Progress = gr.Progress(),) -> str:
202
  """
203
  Search the web and return the most relevant results.
204
  Args:
 
207
  str: The search results.
208
  """
209
  try:
210
+ progress(0.125, desc="Searchign the web...")
211
+
212
  web_search_tool = DuckDuckGoSearchTool(max_results=5, rate_limit=2.0)
213
+
214
+ progress(0.90, desc="Returning Results..")
215
  results = web_search_tool(query)
216
  return results
217
 
 
268
  video_prompt_param="",
269
  video_duration_param=4.0,
270
  video_steps_param=20,
271
+ video_guidance_param=3.0,
272
+ progress: gr.Progress = gr.Progress(),
273
  ):
274
  global image_output, video_output
275
  image_output = None
276
  video_output = None
277
 
278
+ progress(0.05, desc="Jerry is thinking ")
279
 
280
  try:
281
 
 
302
  }
303
  )
304
 
305
+ progress(1, desc="Done…")
306
+
307
  yield image_output, video_output, str(response)
308
 
309
  except Exception as e:
 
448
  ],
449
  outputs=[image_output_display, gr.Video(visible=False), agent_response]
450
  )
451
+
452
  if __name__ == "__main__":
453
  demo.launch(
454
  server_name="0.0.0.0",