Spaces:
Running
Running
| import aiohttp | |
| import asyncio | |
| import base64 | |
| import re | |
| import logging | |
| from typing import Dict, Any | |
# Configure the root logger so the INFO-level request/response traces emitted
# by the functions in this module are actually shown.
logging.basicConfig(level=logging.INFO)

# Process-wide memo: API key -> project ID resolved for that key.
# Nothing in this section evicts entries, so it lives as long as the process.
project_id_cache = dict()
def _extract_project_id(error_data: Any) -> str:
    """Pull the numeric project ID out of a decoded API error payload.

    Args:
        error_data: Decoded JSON body of the error reply. Either a list whose
            first element is ``{"error": {"message": ...}}`` (the shape the
            streaming endpoint is expected to return) or that object directly.

    Returns:
        The digits captured from ``projects/<digits>/`` in the error message.

    Raises:
        Exception: "Could not extract project ID" when the payload has an
            unexpected shape or the message does not name a project.
    """
    # Accept both the list-wrapped and the bare error envelope.
    if isinstance(error_data, list) and error_data:
        envelope = error_data[0]
    else:
        envelope = error_data

    message = None
    if isinstance(envelope, dict):
        error = envelope.get("error")
        if isinstance(error, dict):
            message = error.get("message")

    match = re.search(r"projects/(\d+)/", message) if isinstance(message, str) else None
    if match is None:
        raise Exception("Could not extract project ID")
    return match.group(1)


async def get_project_id(api_key: str) -> str:
    """Resolve the project ID associated with ``api_key``.

    Posts an empty body to the generateContent endpoint using only the API
    key and reads the project ID out of the resulting error message
    (presumably the API rejects the request and names the project in its
    explanation -- confirm against a live response). Successful lookups are
    memoised in ``project_id_cache``.

    Args:
        api_key: API key appended to the request URL as ``?key=``.

    Returns:
        The numeric project ID as a string.

    Raises:
        Exception: "Could not extract project ID" when the reply cannot be
            decoded or does not contain a ``projects/<digits>/`` reference.
    """
    cached = project_id_cache.get(api_key)
    if cached is not None:
        return cached

    url = (
        "https://aiplatform.googleapis.com/v1/publishers/google/models/"
        f"gemini-2.6:streamGenerateContent?key={api_key}"
    )
    headers = {"Content-Type": "application/json"}

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json={}) as response:
            try:
                # content_type=None: decode the body even if the error reply
                # is not labelled application/json.
                error_data = await response.json(content_type=None)
            except ValueError as err:
                # Body was not valid JSON at all.
                raise Exception("Could not extract project ID") from err

    project_id = _extract_project_id(error_data)
    project_id_cache[api_key] = project_id
    return project_id
def _build_video_payload(params: Dict[str, Any]) -> Dict[str, Any]:
    """Translate caller-supplied ``params`` into the request body.

    ``params`` is only read, never modified.

    Args:
        params: Request options. ``prompt`` becomes the instance prompt;
            truthy ``image`` / ``video`` values (bytes-like) are base64
            encoded and require ``image_mime_type`` / ``video_mime_type``;
            the remaining generation options are copied when not ``None``.

    Returns:
        ``{"instances": [...], "parameters": {...}}``.

    Raises:
        KeyError: If media is supplied without its ``*_mime_type`` entry.
    """
    instance: Dict[str, Any] = {"prompt": params.get("prompt")}
    for media in ("image", "video"):
        blob = params.get(media)
        if blob:
            instance[media] = {
                "bytesBase64Encoded": base64.b64encode(blob).decode("utf-8"),
                "mimeType": params[f"{media}_mime_type"],
            }

    option_names = (
        "aspectRatio",
        "durationSeconds",
        "resolution",
        "generateAudio",
        "enhancePrompt",
        "negativePrompt",
        "personGeneration",
        "sampleCount",
        "seed",
    )
    # Drop unset options; falsy-but-set values such as seed=0 are kept.
    parameters = {
        name: params[name] for name in option_names if params.get(name) is not None
    }
    parameters["safetySetting"] = "block_none"

    return {"instances": [instance], "parameters": parameters}


def _loggable_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``payload`` with base64 media replaced by a size note."""
    instances = []
    for instance in payload.get("instances", []):
        safe = dict(instance)
        for media in ("image", "video"):
            if media in safe:
                blob = dict(safe[media])
                blob["bytesBase64Encoded"] = (
                    f"<{len(blob['bytesBase64Encoded'])} base64 chars>"
                )
                safe[media] = blob
        instances.append(safe)
    return {**payload, "instances": instances}


async def start_video_generation(
    project_id: str,
    api_key: str,
    params: Dict[str, Any],
    location: str = "us-central1",
) -> str:
    """Submit a long-running video generation request.

    Args:
        project_id: Project ID placed in the endpoint path.
        api_key: API key appended to the request URL as ``?key=``.
        params: Request options; must contain ``model``. See
            ``_build_video_payload`` for the other recognised keys. The dict
            is left unmodified so the caller can still read ``model`` from it.
        location: Region used for both the host name and the resource path.

    Returns:
        The ``name`` field of the JSON reply (presumably the operation name
        to poll -- confirm against the caller).

    Raises:
        KeyError: If ``model`` (or a required mime type) is missing from
            ``params``, or the reply has no ``name``.
        aiohttp.ClientResponseError: If the server answers with an error
            status.
    """
    model_id = params["model"]  # read, not pop: do not mutate the caller's dict
    endpoint = (
        f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}"
        f"/locations/{location}/publishers/google/models/{model_id}:predictLongRunning"
    )
    payload = _build_video_payload(params)

    # Log the key-free endpoint and a media-redacted payload: the API key and
    # raw base64 media must not end up in log output.
    logging.info(
        "Sending video generation request to %s with payload: %s",
        endpoint,
        _loggable_payload(payload),
    )

    async with aiohttp.ClientSession() as session:
        async with session.post(f"{endpoint}?key={api_key}", json=payload) as response:
            # Decode and log the body before checking the status so that the
            # server's error details are visible in the log.
            data = await response.json()
            logging.info("Received response: %s", data)
            response.raise_for_status()
            return data["name"]
def _loggable_operation(data: Any) -> Any:
    """Return ``data`` with any base64 video bytes replaced by a size note.

    Anything that is not a dict carrying ``response.videos`` is returned
    unchanged, so unexpected reply shapes are still logged as received.
    """
    if not isinstance(data, dict):
        return data
    response_data = data.get("response")
    if not isinstance(response_data, dict) or "videos" not in response_data:
        return data

    videos = []
    for video in response_data["videos"]:
        if isinstance(video, dict) and "bytesBase64Encoded" in video:
            video = {
                **video,
                "bytesBase64Encoded": f"<{len(video['bytesBase64Encoded'])} base64 chars>",
            }
        videos.append(video)
    return {**data, "response": {**response_data, "videos": videos}}


def _operation_result(data: Dict[str, Any]) -> Dict[str, Any]:
    """Interpret an operation reply whose ``done`` flag is set.

    Args:
        data: Decoded operation reply.

    Returns:
        The ``response`` object (``{}`` when absent).

    Raises:
        Exception: With the operation's error message when ``error`` is
            present, or with the first filtering reason when no videos were
            returned and ``raiMediaFilteredReasons`` is non-empty.
    """
    if "error" in data:
        raise Exception(data["error"]["message"])

    response_data = data.get("response", {})
    # If videos are present, return them, even if some were filtered.
    if "videos" in response_data:
        return response_data

    # No videos but filtering reasons exist: everything was blocked.
    # An empty reasons list is not treated as a failure.
    reasons = response_data.get("raiMediaFilteredReasons")
    if reasons:
        raise Exception(reasons[0])
    return response_data


async def poll_video_status(
    project_id: str,
    model_id: str,
    operation_name: str,
    api_key: str,
    location: str = "us-central1",
    poll_interval: float = 5,
) -> Dict[str, Any]:
    """Poll a video generation operation until it reports ``done``.

    Args:
        project_id: Project ID placed in the endpoint path.
        model_id: Model ID placed in the endpoint path.
        operation_name: Sent as ``operationName`` in the request body.
        api_key: API key appended to the request URL as ``?key=``.
        location: Region used for both the host name and the resource path.
        poll_interval: Seconds to wait between polls.

    Returns:
        The ``response`` object of the finished operation.

    Raises:
        Exception: If the operation finished with an error, or all results
            were filtered (see ``_operation_result``).
        aiohttp.ClientResponseError: If a poll is answered with an error
            status.

    Note:
        There is no overall deadline; the loop runs until the operation
        reports ``done`` or a request fails.
    """
    endpoint = (
        f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}"
        f"/locations/{location}/publishers/google/models/{model_id}:fetchPredictOperation"
    )
    url = f"{endpoint}?key={api_key}"
    payload = {"operationName": operation_name}

    async with aiohttp.ClientSession() as session:
        while True:
            # Log the key-free endpoint so the API key never reaches the logs.
            logging.info("Polling status from %s with payload: %s", endpoint, payload)
            async with session.post(url, json=payload) as response:
                # Decode and log before the status check so error bodies are
                # visible in the log.
                data = await response.json()
                logging.info(
                    "Received polling response: %s", _loggable_operation(data)
                )
                response.raise_for_status()

            if data.get("done"):
                return _operation_result(data)

            await asyncio.sleep(poll_interval)