Gaurav vashistha committed on
Commit
0c79677
·
1 Parent(s): 8e0ba30

Fix infinite loop: implement GetOpRequest proxy for Veo polling

Browse files
Files changed (1) hide show
  1. agent.py +66 -117
agent.py CHANGED
@@ -13,6 +13,12 @@ logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
 
 
 
 
 
 
 
16
  def get_file_hash(filepath):
17
  hash_md5 = hashlib.md5()
18
  with open(filepath, "rb") as f:
@@ -30,153 +36,96 @@ def get_or_upload_file(client, filepath):
30
  return f
31
  except Exception:
32
  pass
33
- logger.info(f"⬆️ Uploading new file: {file_hash}")
34
  return client.files.upload(file=filepath, config={'display_name': file_hash})
35
 
36
 
37
  def analyze_only(path_a, path_c, job_id=None):
38
- update_job_status(job_id, "analyzing", 10, "Director checking file cache...")
39
  client = genai.Client(api_key=Settings.GOOGLE_API_KEY)
40
-
41
  try:
42
  file_a = get_or_upload_file(client, path_a)
43
  file_c = get_or_upload_file(client, path_c)
44
-
45
  while file_a.state.name == "PROCESSING" or file_c.state.name == "PROCESSING":
46
- update_job_status(job_id, "analyzing", 20, "Google processing video...")
47
  time.sleep(2)
48
  file_a = client.files.get(name=file_a.name)
49
  file_c = client.files.get(name=file_c.name)
50
 
51
- prompt = "You are a VFX Director. Analyze Video A and Video C. Return JSON with keys: analysis_a, analysis_c, visual_prompt_b."
52
-
53
  res = client.models.generate_content(
54
  model="gemini-2.0-flash",
55
- contents=[prompt, file_a, file_c],
56
  config=types.GenerateContentConfig(response_mime_type="application/json")
57
  )
58
-
59
  text = res.text.strip()
60
- if text.startswith("```json"): text = text[7:]
61
- elif text.startswith("```"): text = text[3:]
62
- if text.endswith("```"): text = text[:-3]
63
-
64
- try: data = json.loads(text.strip())
65
- except: data = {}
66
- if isinstance(data, list): data = data[0] if len(data) > 0 else {}
67
- return {
68
- "analysis_a": data.get("analysis_a", ""),
69
- "analysis_c": data.get("analysis_c", ""),
70
- "prompt": data.get("visual_prompt_b", text),
71
- "status": "success"
72
- }
73
  except Exception as e:
74
- logger.error(f"Analysis failed: {e}")
75
  return {"detail": str(e), "status": "error"}
76
 
77
 
78
  def generate_only(prompt, path_a, path_c, job_id, style, audio, neg, guidance, motion):
79
- update_job_status(job_id, "generating", 50, "Production started (Veo 3.1)...")
80
  full_prompt = f"{style} style. {prompt} Soundtrack: {audio}"
81
- if neg:
82
- full_prompt += f" --no {neg}"
83
-
84
- job_failed = False
85
  try:
86
- if Settings.GCP_PROJECT_ID:
87
- client = genai.Client(vertexai=True, project=Settings.GCP_PROJECT_ID, location=Settings.GCP_LOCATION)
88
-
89
- # 1. Start Job
90
- op = client.models.generate_videos(
91
- model='veo-3.1-generate-preview',
92
- prompt=full_prompt,
93
- config=types.GenerateVideosConfig(number_of_videos=1)
94
- )
 
 
 
 
 
 
 
 
 
 
95
 
96
- # 2. Extract ID String
97
- op_name = str(op)
98
- if hasattr(op, 'name'): op_name = op.name
99
- elif isinstance(op, dict) and 'name' in op: op_name = op['name']
 
 
 
100
 
101
- logger.info(f"Tracking ID: {op_name}")
 
 
 
 
102
 
103
- # 3. Poll for Completion
104
- start_time = time.time()
105
- while True:
106
- if time.time() - start_time > 300:
107
- raise Exception("Generation timed out.")
108
-
109
- try:
110
- # FIX: Use positional argument for get()
111
- current_op = client.operations.get(op_name)
112
- except Exception as e:
113
- logger.warning(f"Refresh failed: {e}")
114
- time.sleep(10)
115
- continue
116
-
117
- # Check Status
118
- is_done = False
119
- if hasattr(current_op, 'done'): is_done = current_op.done
120
- elif isinstance(current_op, dict): is_done = current_op.get('done')
121
-
122
- if is_done:
123
- logger.info("Job DONE.")
124
- op = current_op
125
- break
126
-
127
- logger.info("Waiting for Veo...")
128
- time.sleep(10)
129
 
130
- # 4. Get Result
131
- result = None
132
- if hasattr(op, 'result'):
133
- try:
134
- res_val = op.result
135
- if callable(res_val): result = res_val()
136
- else: result = res_val
137
- except: pass
138
- elif isinstance(op, dict):
139
- result = op.get('result')
140
 
141
- generated_videos = None
142
- if result:
143
- if hasattr(result, 'generated_videos'): generated_videos = result.generated_videos
144
- elif isinstance(result, dict): generated_videos = result.get('generated_videos')
145
 
146
- if generated_videos:
147
- vid = generated_videos[0]
148
- bridge_path = None
149
-
150
- uri = getattr(vid.video, 'uri', None) if hasattr(vid, 'video') else vid.get('video', {}).get('uri')
151
- video_bytes = getattr(vid.video, 'video_bytes', None) if hasattr(vid, 'video') else vid.get('video', {}).get('video_bytes')
152
-
153
- if uri:
154
- bridge_path = tempfile.mktemp(suffix=".mp4")
155
- download_blob(uri, bridge_path)
156
- elif video_bytes:
157
- bridge_path = save_video_bytes(video_bytes)
158
-
159
- if bridge_path:
160
- update_job_status(job_id, "stitching", 85, "Stitching...", video_url=bridge_path)
161
- final_cut = os.path.join("outputs", f"{job_id}_merged_temp.mp4")
162
- merged_path = stitch_videos(path_a, bridge_path, path_c, final_cut)
163
-
164
- msg = "Done! (Merged)" if merged_path else "Done! (Bridge Only)"
165
- update_job_status(job_id, "completed", 100, msg, video_url=bridge_path, merged_video_url=merged_path)
166
- return
167
- else:
168
- raise Exception("No videos returned.")
169
  else:
170
- raise Exception("GCP_PROJECT_ID not set.")
171
-
172
  except Exception as e:
173
- logger.error(f"Gen Fatal: {e}")
174
- update_job_status(job_id, "error", 0, f"Error: {e}")
175
- job_failed = True
176
- finally:
177
- if not job_failed:
178
- try:
179
- with open(f"outputs/{job_id}.json", "r") as f:
180
- if json.load(f).get("status") not in ["completed", "error"]:
181
- update_job_status(job_id, "error", 0, "Job timed out.")
182
- except: pass
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
# Minimal stand-in object used when polling a long-running SDK operation.
class GetOpRequest:
    """Lightweight proxy that carries an operation name on a ``name`` attribute.

    Intended to be handed to an API that reads ``.name`` from the object it
    receives rather than accepting a bare operation-name string.

    NOTE(review): only ``name`` is exposed. If the installed SDK reads any
    other attribute or method from the object it is given, pass the original
    operation object instead -- confirm against the SDK version in use.
    """

    def __init__(self, name: str) -> None:
        # Fully-qualified operation name as returned by the job-start call.
        self.name = name

    def __repr__(self) -> str:
        # Readable representation so the proxy is identifiable in log output.
        return f"{type(self).__name__}(name={self.name!r})"
20
+
21
+
22
  def get_file_hash(filepath):
23
  hash_md5 = hashlib.md5()
24
  with open(filepath, "rb") as f:
 
36
  return f
37
  except Exception:
38
  pass
 
39
  return client.files.upload(file=filepath, config={'display_name': file_hash})
40
 
41
 
42
  def analyze_only(path_a, path_c, job_id=None):
43
+ update_job_status(job_id, "analyzing", 10, "Analyzing scenes...")
44
  client = genai.Client(api_key=Settings.GOOGLE_API_KEY)
 
45
  try:
46
  file_a = get_or_upload_file(client, path_a)
47
  file_c = get_or_upload_file(client, path_c)
 
48
  while file_a.state.name == "PROCESSING" or file_c.state.name == "PROCESSING":
 
49
  time.sleep(2)
50
  file_a = client.files.get(name=file_a.name)
51
  file_c = client.files.get(name=file_c.name)
52
 
 
 
53
  res = client.models.generate_content(
54
  model="gemini-2.0-flash",
55
+ contents=["Analyze Video A and C. Return JSON with analysis_a, analysis_c, visual_prompt_b.", file_a, file_c],
56
  config=types.GenerateContentConfig(response_mime_type="application/json")
57
  )
 
58
  text = res.text.strip()
59
+ if "```json" in text:
60
+ text = text.split("```json")[1].split("```")[0]
61
+ data = json.loads(text)
62
+ return {"prompt": data.get("visual_prompt_b", text), "status": "success"}
 
 
 
 
 
 
 
 
 
63
  except Exception as e:
 
64
  return {"detail": str(e), "status": "error"}
65
 
66
 
67
  def generate_only(prompt, path_a, path_c, job_id, style, audio, neg, guidance, motion):
68
+ update_job_status(job_id, "generating", 50, "Production started...")
69
  full_prompt = f"{style} style. {prompt} Soundtrack: {audio}"
 
 
 
 
70
  try:
71
+ if not Settings.GCP_PROJECT_ID:
72
+ raise Exception("GCP_PROJECT_ID missing.")
73
+ client = genai.Client(vertexai=True, project=Settings.GCP_PROJECT_ID, location=Settings.GCP_LOCATION)
74
+
75
+ # Start Job
76
+ op = client.models.generate_videos(
77
+ model='veo-3.1-generate-preview',
78
+ prompt=full_prompt,
79
+ config=types.GenerateVideosConfig(number_of_videos=1)
80
+ )
81
+
82
+ # Extract ID and create the proxy object for the SDK
83
+ op_name = op.name if hasattr(op, 'name') else str(op)
84
+ request_proxy = GetOpRequest(op_name)
85
+ logger.info(f"Polling Job: {op_name}")
86
+ start_time = time.time()
87
+ while True:
88
+ if time.time() - start_time > 600:
89
+ raise Exception("Timeout (10m).")
90
 
91
+ # REFRESH logic using the proxy object
92
+ try:
93
+ op = client.operations.get(request_proxy)
94
+ except Exception as e:
95
+ logger.warning(f"Polling error: {e}")
96
+ time.sleep(20)
97
+ continue
98
 
99
+ # Check Status
100
+ is_done = getattr(op, 'done', False)
101
+ if is_done:
102
+ logger.info("Generation Done.")
103
+ break
104
 
105
+ logger.info("Waiting for Veo...")
106
+ time.sleep(20)
107
+
108
+ # Process Result
109
+ res_val = op.result
110
+ result = res_val() if callable(res_val) else res_val
111
+
112
+ if result and (getattr(result, 'generated_videos', None) or 'generated_videos' in result):
113
+ vid = result.generated_videos[0] if hasattr(result, 'generated_videos') else result['generated_videos'][0]
114
+ bridge_path = tempfile.mktemp(suffix=".mp4")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ if hasattr(vid.video, 'uri') and vid.video.uri:
117
+ download_blob(vid.video.uri, bridge_path)
118
+ else:
119
+ bridge_path = save_video_bytes(vid.video.video_bytes)
 
 
 
 
 
 
120
 
121
+ update_job_status(job_id, "stitching", 85, "Stitching...")
122
+ final_cut = os.path.join("outputs", f"{job_id}_final.mp4")
123
+ merged_path = stitch_videos(path_a, bridge_path, path_c, final_cut)
 
124
 
125
+ msg = "Done! (Merged)" if merged_path else "Done! (Bridge Only)"
126
+ update_job_status(job_id, "completed", 100, msg, video_url=bridge_path, merged_video_url=merged_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  else:
128
+ raise Exception("No video output.")
 
129
  except Exception as e:
130
+ logger.error(f"Error: {e}")
131
+ update_job_status(job_id, "error", 0, str(e))