Nomnommish commited on
Commit
8345cc5
·
verified ·
1 Parent(s): c1606bd

Update xai_client.py

Browse files
Files changed (1) hide show
  1. xai_client.py +78 -12
xai_client.py CHANGED
@@ -15,7 +15,8 @@ DEFAULT_IMAGE_MODEL = "grok-imagine-image"
15
  DEFAULT_VIDEO_MODEL = "grok-imagine-video"
16
 
17
  IMAGE_ASPECT_RATIOS = [
18
- "auto", "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "2:1", "1:2", "19.5:9", "9:19.5", "20:9", "9:20"
 
19
  ]
20
  IMAGE_RESOLUTIONS = ["1k", "2k"]
21
  VIDEO_ASPECT_RATIOS = ["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3"]
@@ -64,12 +65,14 @@ def get_space_base_url(request: gr.Request | None) -> str | None:
64
  space_host = os.getenv("SPACE_HOST")
65
  if space_host:
66
  return f"https://{space_host}".rstrip("/")
 
67
  if request is not None:
68
  headers = getattr(request, "headers", None) or {}
69
  host = headers.get("x-forwarded-host") or headers.get("host")
70
  proto = headers.get("x-forwarded-proto") or "https"
71
  if host:
72
  return f"{proto}://{host}".rstrip("/")
 
73
  return None
74
 
75
 
@@ -84,22 +87,25 @@ def list_xai_models(api_key: str):
84
  headers = auth_headers(api_key)
85
  image_models = [DEFAULT_IMAGE_MODEL]
86
  video_models = [DEFAULT_VIDEO_MODEL]
 
87
  try:
88
  r = requests.get(f"{API_BASE}/image-generation-models", headers=headers, timeout=60)
89
  if r.ok:
90
  image_models = [m["id"] for m in r.json().get("models", []) if m.get("id")] or image_models
91
  except Exception:
92
  pass
 
93
  try:
94
  r = requests.get(f"{API_BASE}/video-generation-models", headers=headers, timeout=60)
95
  if r.ok:
96
  video_models = [m["id"] for m in r.json().get("models", []) if m.get("id")] or video_models
97
  except Exception:
98
  pass
 
99
  return (
100
  gr.update(choices=image_models, value=image_models[0]),
101
  gr.update(choices=video_models, value=video_models[0]),
102
- f"Loaded {len(image_models)} image model(s) and {len(video_models)} video model(s)."
103
  )
104
 
105
 
@@ -111,16 +117,19 @@ def generate_t2i(api_key, model, prompt, n, aspect_ratio, resolution, progress=g
111
  "n": int(n),
112
  "response_format": "b64_json",
113
  }
 
114
  if not payload["prompt"]:
115
  raise gr.Error("Please enter a prompt.")
116
  if aspect_ratio:
117
  payload["aspect_ratio"] = aspect_ratio
118
  if resolution:
119
  payload["resolution"] = resolution
 
120
  progress(0.2, desc="Generating images...")
121
  resp = requests.post(f"{API_BASE}/images/generations", headers=headers, json=payload, timeout=300)
122
  if not resp.ok:
123
  raise gr.Error(f"xAI image generation failed:\n{safe_json_error(resp)}")
 
124
  gallery, paths = [], []
125
  for i, item in enumerate(resp.json().get("data", []), start=1):
126
  if item.get("b64_json"):
@@ -129,35 +138,47 @@ def generate_t2i(api_key, model, prompt, n, aspect_ratio, resolution, progress=g
129
  out = download_url_to_temp(item["url"], ".png")
130
  else:
131
  continue
 
132
  paths.append(out)
133
  gallery.append((out, f"Image {i}"))
 
134
  if not paths:
135
  raise gr.Error("xAI returned no images.")
 
136
  progress(1.0, desc="Done")
137
  return gallery, paths[0], paths, f"Generated {len(paths)} image(s)."
138
 
139
 
140
  def edit_like_i2i(api_key, model, prompt, input_image_path, aspect_ratio, progress=gr.Progress(track_tqdm=False)):
141
  headers = auth_headers(api_key)
 
142
  if not input_image_path:
143
  raise gr.Error("Please upload an image.")
144
  if not (prompt or "").strip():
145
  raise gr.Error("Please enter a prompt.")
 
146
  payload = {
147
  "model": model or DEFAULT_IMAGE_MODEL,
148
  "prompt": prompt.strip(),
149
- "image": {"url": file_to_data_uri(input_image_path), "type": "image_url"},
 
 
 
150
  "response_format": "b64_json",
151
  }
 
152
  if aspect_ratio:
153
  payload["aspect_ratio"] = aspect_ratio
 
154
  progress(0.2, desc="Editing image...")
155
  resp = requests.post(f"{API_BASE}/images/edits", headers=headers, json=payload, timeout=300)
156
  if not resp.ok:
157
  raise gr.Error(f"xAI image edit failed:\n{safe_json_error(resp)}")
 
158
  data = resp.json().get("data", [])
159
  if not data:
160
  raise gr.Error("xAI returned no image.")
 
161
  item = data[0]
162
  if item.get("b64_json"):
163
  out = download_bytes_to_temp(base64.b64decode(item["b64_json"]), ".png")
@@ -165,6 +186,7 @@ def edit_like_i2i(api_key, model, prompt, input_image_path, aspect_ratio, progre
165
  out = download_url_to_temp(item["url"], ".png")
166
  else:
167
  raise gr.Error("xAI returned no output image payload.")
 
168
  progress(1.0, desc="Done")
169
  return out, out, f"Completed: {Path(out).name}"
170
 
@@ -172,29 +194,50 @@ def edit_like_i2i(api_key, model, prompt, input_image_path, aspect_ratio, progre
172
  def poll_video_result(api_key, request_id, timeout_seconds, poll_interval, progress):
173
  headers = {"Authorization": f"Bearer {api_key.strip()}"}
174
  started = time.time()
 
175
  while True:
176
  if time.time() - started > timeout_seconds:
177
  raise gr.Error("Timed out waiting for xAI video generation.")
 
178
  resp = requests.get(f"{API_BASE}/videos/{request_id}", headers=headers, timeout=120)
179
  if not resp.ok:
180
  raise gr.Error(f"xAI video polling failed:\n{safe_json_error(resp)}")
 
181
  data = resp.json()
182
  status = data.get("status", "unknown")
183
  progress(None, desc=f"Video status: {status}")
 
184
  if status == "done":
185
- video_url = (data.get("video") or {}).get("url")
 
186
  if not video_url:
187
  raise gr.Error("xAI returned no video URL.")
188
- return video_url, (data.get("video") or {}).get("duration"), data
189
- if status == "expired":
190
- raise gr.Error("xAI request expired before retrieval.")
 
 
191
  time.sleep(int(poll_interval))
192
 
193
 
194
- def generate_i2v(api_key, model, prompt, uploaded_image_path, use_last_t2i_image, last_t2i_first_image, duration, aspect_ratio, resolution, timeout_seconds, poll_interval, progress=gr.Progress(track_tqdm=False)):
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  source_image = last_t2i_first_image if use_last_t2i_image and last_t2i_first_image else uploaded_image_path
196
  if not source_image:
197
  raise gr.Error("Upload an image or use the first T2I result.")
 
198
  headers = auth_headers(api_key)
199
  payload = {
200
  "model": model or DEFAULT_VIDEO_MODEL,
@@ -203,43 +246,66 @@ def generate_i2v(api_key, model, prompt, uploaded_image_path, use_last_t2i_image
203
  "duration": int(duration),
204
  "resolution": resolution,
205
  }
 
206
  if not payload["prompt"]:
207
  raise gr.Error("Please enter an I2V prompt.")
208
  if aspect_ratio:
209
  payload["aspect_ratio"] = aspect_ratio
 
210
  progress(0.2, desc="Submitting I2V...")
211
  resp = requests.post(f"{API_BASE}/videos/generations", headers=headers, json=payload, timeout=300)
212
  if not resp.ok:
213
  raise gr.Error(f"xAI I2V request failed:\n{safe_json_error(resp)}")
 
214
  request_id = resp.json().get("request_id")
215
  if not request_id:
216
  raise gr.Error("xAI did not return request_id.")
217
- video_url, actual_duration, _ = poll_video_result(api_key, request_id, int(timeout_seconds), int(poll_interval), progress)
 
 
 
218
  out = download_url_to_temp(video_url, ".mp4")
219
  return out, out, f"I2V complete. Request ID: {request_id}. Duration: {actual_duration}s"
220
 
221
 
222
- def generate_v2v(api_key, model, prompt, uploaded_video_path, timeout_seconds, poll_interval, request, progress=gr.Progress(track_tqdm=False)):
 
 
 
 
 
 
 
 
 
223
  if not uploaded_video_path:
224
  raise gr.Error("Please upload an MP4 source video.")
225
  if Path(uploaded_video_path).suffix.lower() != ".mp4":
226
  raise gr.Error("xAI V2V expects .mp4 input.")
 
227
  headers = auth_headers(api_key)
228
  public_video_url = local_file_to_public_url(uploaded_video_path, request)
 
229
  payload = {
230
  "model": model or DEFAULT_VIDEO_MODEL,
231
  "prompt": (prompt or "").strip(),
232
  "video_url": public_video_url,
233
  }
 
234
  if not payload["prompt"]:
235
  raise gr.Error("Please enter a V2V prompt.")
 
236
  progress(0.2, desc="Submitting V2V...")
237
  resp = requests.post(f"{API_BASE}/videos/generations", headers=headers, json=payload, timeout=300)
238
  if not resp.ok:
239
  raise gr.Error(f"xAI V2V request failed:\n{safe_json_error(resp)}")
 
240
  request_id = resp.json().get("request_id")
241
  if not request_id:
242
  raise gr.Error("xAI did not return request_id.")
243
- video_url, actual_duration, _ = poll_video_result(api_key, request_id, int(timeout_seconds), int(poll_interval), progress)
 
 
 
244
  out = download_url_to_temp(video_url, ".mp4")
245
- return out, out, f"V2V complete. Request ID: {request_id}. Duration: {actual_duration}s"
 
15
  DEFAULT_VIDEO_MODEL = "grok-imagine-video"
16
 
17
  IMAGE_ASPECT_RATIOS = [
18
+ "auto", "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3",
19
+ "2:1", "1:2", "19.5:9", "9:19.5", "20:9", "9:20"
20
  ]
21
  IMAGE_RESOLUTIONS = ["1k", "2k"]
22
  VIDEO_ASPECT_RATIOS = ["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3"]
 
65
  space_host = os.getenv("SPACE_HOST")
66
  if space_host:
67
  return f"https://{space_host}".rstrip("/")
68
+
69
  if request is not None:
70
  headers = getattr(request, "headers", None) or {}
71
  host = headers.get("x-forwarded-host") or headers.get("host")
72
  proto = headers.get("x-forwarded-proto") or "https"
73
  if host:
74
  return f"{proto}://{host}".rstrip("/")
75
+
76
  return None
77
 
78
 
 
87
  headers = auth_headers(api_key)
88
  image_models = [DEFAULT_IMAGE_MODEL]
89
  video_models = [DEFAULT_VIDEO_MODEL]
90
+
91
  try:
92
  r = requests.get(f"{API_BASE}/image-generation-models", headers=headers, timeout=60)
93
  if r.ok:
94
  image_models = [m["id"] for m in r.json().get("models", []) if m.get("id")] or image_models
95
  except Exception:
96
  pass
97
+
98
  try:
99
  r = requests.get(f"{API_BASE}/video-generation-models", headers=headers, timeout=60)
100
  if r.ok:
101
  video_models = [m["id"] for m in r.json().get("models", []) if m.get("id")] or video_models
102
  except Exception:
103
  pass
104
+
105
  return (
106
  gr.update(choices=image_models, value=image_models[0]),
107
  gr.update(choices=video_models, value=video_models[0]),
108
+ f"Loaded {len(image_models)} image model(s) and {len(video_models)} video model(s).",
109
  )
110
 
111
 
 
117
  "n": int(n),
118
  "response_format": "b64_json",
119
  }
120
+
121
  if not payload["prompt"]:
122
  raise gr.Error("Please enter a prompt.")
123
  if aspect_ratio:
124
  payload["aspect_ratio"] = aspect_ratio
125
  if resolution:
126
  payload["resolution"] = resolution
127
+
128
  progress(0.2, desc="Generating images...")
129
  resp = requests.post(f"{API_BASE}/images/generations", headers=headers, json=payload, timeout=300)
130
  if not resp.ok:
131
  raise gr.Error(f"xAI image generation failed:\n{safe_json_error(resp)}")
132
+
133
  gallery, paths = [], []
134
  for i, item in enumerate(resp.json().get("data", []), start=1):
135
  if item.get("b64_json"):
 
138
  out = download_url_to_temp(item["url"], ".png")
139
  else:
140
  continue
141
+
142
  paths.append(out)
143
  gallery.append((out, f"Image {i}"))
144
+
145
  if not paths:
146
  raise gr.Error("xAI returned no images.")
147
+
148
  progress(1.0, desc="Done")
149
  return gallery, paths[0], paths, f"Generated {len(paths)} image(s)."
150
 
151
 
152
  def edit_like_i2i(api_key, model, prompt, input_image_path, aspect_ratio, progress=gr.Progress(track_tqdm=False)):
153
  headers = auth_headers(api_key)
154
+
155
  if not input_image_path:
156
  raise gr.Error("Please upload an image.")
157
  if not (prompt or "").strip():
158
  raise gr.Error("Please enter a prompt.")
159
+
160
  payload = {
161
  "model": model or DEFAULT_IMAGE_MODEL,
162
  "prompt": prompt.strip(),
163
+ "image": {
164
+ "url": file_to_data_uri(input_image_path),
165
+ "type": "image_url",
166
+ },
167
  "response_format": "b64_json",
168
  }
169
+
170
  if aspect_ratio:
171
  payload["aspect_ratio"] = aspect_ratio
172
+
173
  progress(0.2, desc="Editing image...")
174
  resp = requests.post(f"{API_BASE}/images/edits", headers=headers, json=payload, timeout=300)
175
  if not resp.ok:
176
  raise gr.Error(f"xAI image edit failed:\n{safe_json_error(resp)}")
177
+
178
  data = resp.json().get("data", [])
179
  if not data:
180
  raise gr.Error("xAI returned no image.")
181
+
182
  item = data[0]
183
  if item.get("b64_json"):
184
  out = download_bytes_to_temp(base64.b64decode(item["b64_json"]), ".png")
 
186
  out = download_url_to_temp(item["url"], ".png")
187
  else:
188
  raise gr.Error("xAI returned no output image payload.")
189
+
190
  progress(1.0, desc="Done")
191
  return out, out, f"Completed: {Path(out).name}"
192
 
 
194
  def poll_video_result(api_key, request_id, timeout_seconds, poll_interval, progress):
195
  headers = {"Authorization": f"Bearer {api_key.strip()}"}
196
  started = time.time()
197
+
198
  while True:
199
  if time.time() - started > timeout_seconds:
200
  raise gr.Error("Timed out waiting for xAI video generation.")
201
+
202
  resp = requests.get(f"{API_BASE}/videos/{request_id}", headers=headers, timeout=120)
203
  if not resp.ok:
204
  raise gr.Error(f"xAI video polling failed:\n{safe_json_error(resp)}")
205
+
206
  data = resp.json()
207
  status = data.get("status", "unknown")
208
  progress(None, desc=f"Video status: {status}")
209
+
210
  if status == "done":
211
+ video = data.get("video") or {}
212
+ video_url = video.get("url")
213
  if not video_url:
214
  raise gr.Error("xAI returned no video URL.")
215
+ return video_url, video.get("duration"), data
216
+
217
+ if status in {"failed", "error", "cancelled", "expired"}:
218
+ raise gr.Error(f"xAI video job ended with status: {status}\n{json.dumps(data, indent=2)}")
219
+
220
  time.sleep(int(poll_interval))
221
 
222
 
223
+ def generate_i2v(
224
+ api_key,
225
+ model,
226
+ prompt,
227
+ uploaded_image_path,
228
+ use_last_t2i_image,
229
+ last_t2i_first_image,
230
+ duration,
231
+ aspect_ratio,
232
+ resolution,
233
+ timeout_seconds,
234
+ poll_interval,
235
+ progress=gr.Progress(track_tqdm=False),
236
+ ):
237
  source_image = last_t2i_first_image if use_last_t2i_image and last_t2i_first_image else uploaded_image_path
238
  if not source_image:
239
  raise gr.Error("Upload an image or use the first T2I result.")
240
+
241
  headers = auth_headers(api_key)
242
  payload = {
243
  "model": model or DEFAULT_VIDEO_MODEL,
 
246
  "duration": int(duration),
247
  "resolution": resolution,
248
  }
249
+
250
  if not payload["prompt"]:
251
  raise gr.Error("Please enter an I2V prompt.")
252
  if aspect_ratio:
253
  payload["aspect_ratio"] = aspect_ratio
254
+
255
  progress(0.2, desc="Submitting I2V...")
256
  resp = requests.post(f"{API_BASE}/videos/generations", headers=headers, json=payload, timeout=300)
257
  if not resp.ok:
258
  raise gr.Error(f"xAI I2V request failed:\n{safe_json_error(resp)}")
259
+
260
  request_id = resp.json().get("request_id")
261
  if not request_id:
262
  raise gr.Error("xAI did not return request_id.")
263
+
264
+ video_url, actual_duration, _ = poll_video_result(
265
+ api_key, request_id, int(timeout_seconds), int(poll_interval), progress
266
+ )
267
  out = download_url_to_temp(video_url, ".mp4")
268
  return out, out, f"I2V complete. Request ID: {request_id}. Duration: {actual_duration}s"
269
 
270
 
271
+ def generate_v2v(
272
+ api_key,
273
+ model,
274
+ prompt,
275
+ uploaded_video_path,
276
+ timeout_seconds,
277
+ poll_interval,
278
+ request: gr.Request,
279
+ progress=gr.Progress(track_tqdm=False),
280
+ ):
281
  if not uploaded_video_path:
282
  raise gr.Error("Please upload an MP4 source video.")
283
  if Path(uploaded_video_path).suffix.lower() != ".mp4":
284
  raise gr.Error("xAI V2V expects .mp4 input.")
285
+
286
  headers = auth_headers(api_key)
287
  public_video_url = local_file_to_public_url(uploaded_video_path, request)
288
+
289
  payload = {
290
  "model": model or DEFAULT_VIDEO_MODEL,
291
  "prompt": (prompt or "").strip(),
292
  "video_url": public_video_url,
293
  }
294
+
295
  if not payload["prompt"]:
296
  raise gr.Error("Please enter a V2V prompt.")
297
+
298
  progress(0.2, desc="Submitting V2V...")
299
  resp = requests.post(f"{API_BASE}/videos/generations", headers=headers, json=payload, timeout=300)
300
  if not resp.ok:
301
  raise gr.Error(f"xAI V2V request failed:\n{safe_json_error(resp)}")
302
+
303
  request_id = resp.json().get("request_id")
304
  if not request_id:
305
  raise gr.Error("xAI did not return request_id.")
306
+
307
+ video_url, actual_duration, _ = poll_video_result(
308
+ api_key, request_id, int(timeout_seconds), int(poll_interval), progress
309
+ )
310
  out = download_url_to_temp(video_url, ".mp4")
311
+ return out, out, f"V2V complete. Request ID: {request_id}. Duration: {actual_duration}s"