Gaurav vashistha committed on
Commit
62973c2
·
1 Parent(s): 4c4dca8

Use official SDK GenerateVideosOperation for polling instead of custom proxy

Browse files
Files changed (1) hide show
  1. agent.py +31 -30
agent.py CHANGED
@@ -13,12 +13,6 @@ logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
- # --- FIX 1: Proxy Class for SDK Type Compliance (Prevents Infinite Loop) ---
17
- class GetOpRequest:
18
- def __init__(self, name):
19
- self.name = name
20
-
21
-
22
  def get_file_hash(filepath):
23
  hash_md5 = hashlib.md5()
24
  with open(filepath, "rb") as f:
@@ -47,7 +41,7 @@ def analyze_only(path_a, path_c, job_id=None):
47
  try:
48
  file_a = get_or_upload_file(client, path_a)
49
  file_c = get_or_upload_file(client, path_c)
50
-
51
  while file_a.state.name == "PROCESSING" or file_c.state.name == "PROCESSING":
52
  update_job_status(job_id, "analyzing", 20, "Google processing video...")
53
  time.sleep(2)
@@ -64,14 +58,13 @@ def analyze_only(path_a, path_c, job_id=None):
64
  }
65
  """
66
  update_job_status(job_id, "analyzing", 30, "Director drafting creative morph...")
67
-
68
  res = client.models.generate_content(
69
  model="gemini-2.0-flash",
70
  contents=[prompt, file_a, file_c],
71
  config=types.GenerateContentConfig(response_mime_type="application/json")
72
  )
73
 
74
- # --- FIX 2: Robust Response Parsing (Prevents 500 Error) ---
75
  text = res.text.strip()
76
  if text.startswith("```json"): text = text[7:]
77
  elif text.startswith("```"): text = text[3:]
@@ -81,20 +74,16 @@ def analyze_only(path_a, path_c, job_id=None):
81
  data = {}
82
  try:
83
  parsed = json.loads(text)
84
- if isinstance(parsed, list):
85
- data = parsed[0] if len(parsed) > 0 else {}
86
- elif isinstance(parsed, dict):
87
- data = parsed
88
  except json.JSONDecodeError:
89
- logger.warning(f"JSON Parse Failed. Fallback to raw text. Response: {text[:50]}...")
90
- # Do NOT return early here. We must populate default keys below.
91
  pass
92
 
93
- # Return a COMPLETE dictionary so server.py never throws a KeyError
94
  return {
95
  "analysis_a": data.get("analysis_a", "Analysis unavailable."),
96
  "analysis_c": data.get("analysis_c", "Analysis unavailable."),
97
- "prompt": data.get("visual_prompt_b", text), # Use raw text if JSON key missing
98
  "status": "success"
99
  }
100
 
@@ -113,38 +102,50 @@ def generate_only(prompt, path_a, path_c, job_id, style, audio, neg, guidance, m
113
  raise Exception("GCP_PROJECT_ID missing.")
114
  client = genai.Client(vertexai=True, project=Settings.GCP_PROJECT_ID, location=Settings.GCP_LOCATION)
115
 
116
- # Start Job
117
  op = client.models.generate_videos(
118
  model='veo-3.1-generate-preview',
119
  prompt=full_prompt,
120
  config=types.GenerateVideosConfig(number_of_videos=1)
121
  )
122
 
123
- # Fix: Create Proxy Object for Polling
124
  op_name = op.name if hasattr(op, 'name') else str(op)
125
- request_proxy = GetOpRequest(op_name)
126
- logger.info(f"Polling Job: {op_name}")
 
 
 
 
 
 
 
 
127
  start_time = time.time()
128
  while True:
129
  if time.time() - start_time > 600:
130
  raise Exception("Timeout (10m).")
131
 
132
  try:
133
- op = client.operations.get(request_proxy)
 
 
 
 
 
 
 
 
134
  except Exception as e:
135
  logger.warning(f"Polling error: {e}")
136
- time.sleep(20)
137
- continue
138
-
139
- is_done = getattr(op, 'done', False)
140
- if is_done:
141
- logger.info("Generation Done.")
142
- break
143
 
144
  logger.info("Waiting for Veo...")
145
  time.sleep(20)
146
 
147
- # Process Result
148
  res_val = op.result
149
  result = res_val() if callable(res_val) else res_val
150
 
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
 
 
 
 
 
 
16
  def get_file_hash(filepath):
17
  hash_md5 = hashlib.md5()
18
  with open(filepath, "rb") as f:
 
41
  try:
42
  file_a = get_or_upload_file(client, path_a)
43
  file_c = get_or_upload_file(client, path_c)
44
+
45
  while file_a.state.name == "PROCESSING" or file_c.state.name == "PROCESSING":
46
  update_job_status(job_id, "analyzing", 20, "Google processing video...")
47
  time.sleep(2)
 
58
  }
59
  """
60
  update_job_status(job_id, "analyzing", 30, "Director drafting creative morph...")
61
+
62
  res = client.models.generate_content(
63
  model="gemini-2.0-flash",
64
  contents=[prompt, file_a, file_c],
65
  config=types.GenerateContentConfig(response_mime_type="application/json")
66
  )
67
 
 
68
  text = res.text.strip()
69
  if text.startswith("```json"): text = text[7:]
70
  elif text.startswith("```"): text = text[3:]
 
74
  data = {}
75
  try:
76
  parsed = json.loads(text)
77
+ if isinstance(parsed, list): data = parsed[0] if len(parsed) > 0 else {}
78
+ elif isinstance(parsed, dict): data = parsed
 
 
79
  except json.JSONDecodeError:
80
+ logger.warning(f"JSON Parse Failed. Fallback to raw text.")
 
81
  pass
82
 
 
83
  return {
84
  "analysis_a": data.get("analysis_a", "Analysis unavailable."),
85
  "analysis_c": data.get("analysis_c", "Analysis unavailable."),
86
+ "prompt": data.get("visual_prompt_b", text),
87
  "status": "success"
88
  }
89
 
 
102
  raise Exception("GCP_PROJECT_ID missing.")
103
  client = genai.Client(vertexai=True, project=Settings.GCP_PROJECT_ID, location=Settings.GCP_LOCATION)
104
 
105
+ # 1. Start Job
106
  op = client.models.generate_videos(
107
  model='veo-3.1-generate-preview',
108
  prompt=full_prompt,
109
  config=types.GenerateVideosConfig(number_of_videos=1)
110
  )
111
 
112
+ # 2. Extract ID String (Critical)
113
  op_name = op.name if hasattr(op, 'name') else str(op)
114
+ logger.info(f"Polling Job ID: {op_name}")
115
+
116
+ # 3. Create Valid SDK Object for Polling
117
+ # We must reconstruct the operation object correctly so .get() works
118
+ # Pass _api_client to ensure it has the context to refresh itself
119
+ polling_op = types.GenerateVideosOperation(
120
+ name=op_name,
121
+ _api_client=client._api_client
122
+ )
123
+
124
  start_time = time.time()
125
  while True:
126
  if time.time() - start_time > 600:
127
  raise Exception("Timeout (10m).")
128
 
129
  try:
130
+ # Refresh logic: Use the official object
131
+ refreshed_op = client.operations.get(polling_op)
132
+
133
+ # Check status
134
+ if hasattr(refreshed_op, 'done') and refreshed_op.done:
135
+ logger.info("Generation Done.")
136
+ op = refreshed_op # Update main op with final result
137
+ break
138
+
139
  except Exception as e:
140
  logger.warning(f"Polling error: {e}")
141
+ # No alternate fallback is implemented: the next iteration retries the same polling_op
142
+ # Fall through to the shared sleep below so a failing endpoint is not spammed
143
+ pass
 
 
 
 
144
 
145
  logger.info("Waiting for Veo...")
146
  time.sleep(20)
147
 
148
+ # 4. Result Extraction
149
  res_val = op.result
150
  result = res_val() if callable(res_val) else res_val
151