# backend/services/vertex_service.py — Vertex AI video generation helpers
import aiohttp
import asyncio
import base64
import re
import logging
from typing import Dict, Any
# Configure root logging once at import time; all request/response bodies below
# are logged at INFO for debuggability.
logging.basicConfig(level=logging.INFO)
# A simple in-memory cache for project IDs
# NOTE(review): keyed by the raw API key string and never evicted — acceptable
# for a small service, revisit if the number of distinct keys grows.
project_id_cache: Dict[str, str] = {}
async def get_project_id(api_key: str) -> str:
    """Discover the numeric GCP project ID associated with *api_key*.

    Sends a deliberately malformed (empty-body) request to the Vertex AI API;
    the error response embeds a ``projects/<number>/`` path from which the
    project number is extracted. Results are memoized per API key in
    ``project_id_cache``.

    Raises:
        Exception: if the error response has an unexpected shape, or if no
            project ID can be found in the error message.
    """
    if api_key in project_id_cache:
        return project_id_cache[api_key]
    url = f"https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.6:streamGenerateContent?key={api_key}"
    headers = {"Content-Type": "application/json"}
    # Intentionally empty body: the API rejects it, and the rejection message
    # is the only thing we need.
    data = {}
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=data) as response:
            error_data = await response.json()
            try:
                message = error_data[0]["error"]["message"]
            except (KeyError, IndexError, TypeError) as exc:
                # Unexpected response shape — surface the raw body instead of
                # an opaque KeyError/IndexError.
                raise Exception(
                    f"Unexpected response while discovering project ID: {error_data}"
                ) from exc
    match = re.search(r"projects/(\d+)/", message)
    if match:
        project_id = match.group(1)
        project_id_cache[api_key] = project_id
        return project_id
    raise Exception("Could not extract project ID")
def _build_video_payload(params: Dict[str, Any]) -> Dict[str, Any]:
    """Build the ``predictLongRunning`` request payload from *params*.

    Does not mutate *params*. Raw ``image``/``video`` bytes are base64-encoded
    alongside their MIME types; generation parameters that are ``None`` are
    omitted from the request.
    """
    instances = [{"prompt": params.get("prompt")}]
    if params.get("image"):
        instances[0]["image"] = {
            "bytesBase64Encoded": base64.b64encode(params["image"]).decode("utf-8"),
            "mimeType": params["image_mime_type"],
        }
    if params.get("video"):
        instances[0]["video"] = {
            "bytesBase64Encoded": base64.b64encode(params["video"]).decode("utf-8"),
            "mimeType": params["video_mime_type"],
        }
    parameters = {
        "aspectRatio": params.get("aspectRatio"),
        "durationSeconds": params.get("durationSeconds"),
        "resolution": params.get("resolution"),
        "generateAudio": params.get("generateAudio"),
        "enhancePrompt": params.get("enhancePrompt"),
        "negativePrompt": params.get("negativePrompt"),
        "personGeneration": params.get("personGeneration"),
        "sampleCount": params.get("sampleCount"),
        "seed": params.get("seed"),
        "safetySetting": "block_none",
    }
    # Remove None values from parameters
    parameters = {k: v for k, v in parameters.items() if v is not None}
    return {"instances": instances, "parameters": parameters}


async def start_video_generation(project_id: str, api_key: str, params: Dict[str, Any]) -> str:
    """Start a long-running Vertex AI video generation job.

    Args:
        project_id: Numeric GCP project ID.
        api_key: API key appended to the request URL.
        params: User-supplied options; must contain ``model``. Not mutated
            (the original ``params.pop("model")`` clobbered the caller's dict).

    Returns:
        The operation name to poll via ``poll_video_status``.

    Raises:
        KeyError: if ``params`` has no ``model`` entry.
        aiohttp.ClientResponseError: on a non-2xx API response.
    """
    location = "us-central1"  # Or make this a parameter
    model_id = params["model"]  # read, don't pop: avoid mutating the caller's dict
    url = f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers/google/models/{model_id}:predictLongRunning?key={api_key}"
    payload = _build_video_payload(params)
    logging.info(f"Sending video generation request to {url} with payload: {payload}")
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload) as response:
            # Read (and log) the body before raise_for_status so error details
            # are captured even on failure responses.
            data = await response.json()
            logging.info(f"Received response: {data}")
            response.raise_for_status()
            return data["name"]
async def poll_video_status(
    project_id: str,
    model_id: str,
    operation_name: str,
    api_key: str,
    poll_interval: float = 5.0,
) -> Dict[str, Any]:
    """Poll ``fetchPredictOperation`` until the operation finishes.

    Args:
        project_id: Numeric GCP project ID.
        model_id: Publisher model ID the operation was started with.
        operation_name: Operation name returned by ``start_video_generation``.
        api_key: API key appended to the request URL.
        poll_interval: Seconds to wait between polls (default 5, matching the
            previous hard-coded behavior).

    Returns:
        The operation's ``response`` payload (may be empty).

    Raises:
        Exception: if the operation reports an error, or if every generated
            video was blocked by responsible-AI filtering.
        aiohttp.ClientResponseError: on a non-2xx API response.

    NOTE(review): loops forever if the operation never completes — consider a
    caller-side timeout.
    """
    location = "us-central1"
    url = f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{location}/publishers/google/models/{model_id}:fetchPredictOperation?key={api_key}"
    payload = {"operationName": operation_name}
    async with aiohttp.ClientSession() as session:
        while True:
            logging.info(f"Polling status from {url} with payload: {payload}")
            async with session.post(url, json=payload) as response:
                # Log the body before raise_for_status so error details are
                # captured even on failure responses.
                data = await response.json()
                logging.info(f"Received polling response: {data}")
                response.raise_for_status()
                if data.get("done"):
                    if "error" in data:
                        raise Exception(data['error']['message'])
                    response_data = data.get("response", {})
                    # If videos are present, return them, even if some were filtered.
                    if "videos" in response_data:
                        return response_data
                    # If no videos, but filtering reasons exist, then all were blocked.
                    if "raiMediaFilteredReasons" in response_data:
                        raise Exception(response_data['raiMediaFilteredReasons'][0])
                    return response_data
            await asyncio.sleep(poll_interval)