Gaurav vashistha committed on
Commit
a8c659a
·
1 Parent(s): e4a4025

refactor: update agent and server code, move agent.py to root

Browse files
Files changed (3) hide show
  1. agent.py +144 -132
  2. continuity_agent/agent.py +0 -256
  3. server.py +1 -1
agent.py CHANGED
@@ -1,21 +1,23 @@
1
  import os
 
 
 
 
 
2
  from typing import TypedDict, Optional
3
  from langgraph.graph import StateGraph, END
4
- from langchain_google_genai import ChatGoogleGenerativeAI
5
  from google import genai
 
6
  from gradio_client import Client, handle_file
7
- import shutil
8
- import requests
9
- import tempfile
10
- import os
11
- import shutil
12
- import requests
13
- import tempfile
14
-
15
  from dotenv import load_dotenv
16
 
 
17
  load_dotenv()
18
 
 
 
 
 
19
  # State Definition
20
  class ContinuityState(TypedDict):
21
  video_a_url: str
@@ -27,27 +29,25 @@ class ContinuityState(TypedDict):
27
  video_a_local_path: Optional[str]
28
  video_c_local_path: Optional[str]
29
 
30
- # Node 1: Analyst
 
 
 
 
 
 
 
 
 
 
31
  def analyze_videos(state: ContinuityState) -> dict:
32
- print("--- 🧐 Analyst Node (Director) ---")
33
 
34
  video_a_url = state['video_a_url']
35
  video_c_url = state['video_c_url']
36
 
37
- # Initialize Google GenAI Client
38
- client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
39
-
40
  try:
41
- # Download videos to temp files for analysis
42
- def download_to_temp(url):
43
- print(f"Downloading: {url}")
44
- resp = requests.get(url, stream=True)
45
- resp.raise_for_status()
46
- suffix = os.path.splitext(url.split("/")[-1])[1] or ".mp4"
47
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
48
- shutil.copyfileobj(resp.raw, f)
49
- return f.name
50
-
51
  path_a = state.get('video_a_local_path')
52
  if not path_a:
53
  path_a = download_to_temp(video_a_url)
@@ -55,94 +55,103 @@ def analyze_videos(state: ContinuityState) -> dict:
55
  path_c = state.get('video_c_local_path')
56
  if not path_c:
57
  path_c = download_to_temp(video_c_url)
58
-
59
- print("Uploading videos to Gemini...")
60
- file_a = client.files.upload(file=path_a)
61
- file_c = client.files.upload(file=path_c)
62
-
63
- # Wait for processing? Usually quick for small files, but good practice to check state if needed.
64
- # For simplicity in this agent, assuming ready or waiting implicitly.
65
- # (Gemini 1.5 Flash usually processes quickly)
66
-
67
- prompt = """
68
- You are a film director.
69
- Analyze the motion, lighting, and subject of the first video (Video A) and the second video (Video C).
70
- Write a detailed visual prompt for a 2-second video (Video B) that smoothly transitions from the end of A to the start of C.
71
- Target Output: A single concise descriptive paragraph for the video generation model.
72
- """
73
-
74
- print("Generating transition prompt...")
75
- response = client.models.generate_content(
76
- model="gemini-1.5-flash",
77
- contents=[prompt, file_a, file_c]
78
- )
79
-
80
- transition_prompt = response.text
81
- print(f"Generated Prompt: {transition_prompt}")
82
-
83
- # Cleanup uploaded files from local ? (Files on server stay for 48h or until deleted)
84
- # client.files.delete(name=file_a.name)
85
- # client.files.delete(name=file_c.name)
86
-
87
- # We also need these local paths for the Generator node to extract frames!
88
- # Pass them in state or re-download? Better to pass paths if possible, but
89
- # State definition expects URLs. We can add temp paths to state or re-download.
90
- # Let's add temp paths to state for efficiency.
91
-
92
- return {
93
- "scene_analysis": transition_prompt,
94
- "veo_prompt": transition_prompt,
95
- "video_a_local_path": path_a,
96
- "video_c_local_path": path_c
97
- }
98
-
99
  except Exception as e:
100
- print(f"Error in Analyst: {e}")
101
- return {"scene_analysis": f"Error: {str(e)}", "veo_prompt": "Error"}
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- # Node 2: Generator (Wan 2.2 First Last Frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def generate_video(state: ContinuityState) -> dict:
106
- print("--- 🎥 Generator Node (Wan 2.2) ---")
107
 
108
  prompt = state.get('veo_prompt', "")
109
  path_a = state.get('video_a_local_path')
110
  path_c = state.get('video_c_local_path')
111
 
112
  if not path_a or not path_c:
113
- # Fallback if dependencies failed or state clean
114
- # Re-download logic would go here, but assuming flow works
115
  return {"generated_video_url": "Error: Missing local video paths"}
116
 
117
  try:
118
- # Extract Frames
119
  import cv2
120
  from PIL import Image
121
 
122
  def get_frame(video_path, location="last"):
123
  cap = cv2.VideoCapture(video_path)
124
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
125
- if location == "last":
126
- cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
127
- else: # first
128
- cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
129
-
130
  ret, frame = cap.read()
131
  cap.release()
132
-
133
- if ret:
134
- # Convert BGR to RGB
135
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
136
- return Image.fromarray(frame_rgb)
137
- else:
138
- raise ValueError(f"Could not extract frame from {video_path}")
139
 
140
- print("Extracting frames...")
141
  img_start = get_frame(path_a, "last")
142
  img_end = get_frame(path_c, "first")
143
 
144
- # Save frames to temp files for Gradio Client (it handles file paths better than PIL objects usually)
145
- # Although client.predict might take PIL, handle_file is safer with paths.
146
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f_start:
147
  img_start.save(f_start, format="PNG")
148
  start_path = f_start.name
@@ -151,61 +160,64 @@ def generate_video(state: ContinuityState) -> dict:
151
  img_end.save(f_end, format="PNG")
152
  end_path = f_end.name
153
 
154
- # Call Wan 2.2
155
- print("Initializing Wan Client...")
156
- client = Client("multimodalart/wan-2-2-first-last-frame")
157
-
158
- print(f"Generating transition with prompt: {prompt[:50]}...")
159
- # predict(start_image, end_image, prompt, negative_prompt, duration, steps, guide, guide2, seed, rand, api_name)
160
- result = client.predict(
161
- start_image_pil=handle_file(start_path),
162
- end_image_pil=handle_file(end_path),
163
- prompt=prompt,
164
- negative_prompt="blurry, distorted, low quality, static",
165
- duration_seconds=2.1,
166
- steps=20, # Default is often around 20-30 for good quality
167
- guidance_scale=5.0,
168
- guidance_scale_2=5.0,
169
- seed=42,
170
- randomize_seed=True,
171
- api_name="/generate_video"
172
- )
173
-
174
- # Clean up temp frames and videos
175
  try:
176
- os.remove(start_path)
177
- os.remove(end_path)
178
- os.remove(path_a)
179
- os.remove(path_c)
180
- except:
181
- pass
182
-
183
- # Parse valid result
184
- # Expected: ({'video': path, ...}, seed) or just path depending on version
185
- # Based on inspection: (generated_video_mp4, seed)
186
- video_out = result[0]
187
- if isinstance(video_out, dict) and 'video' in video_out:
188
- return {"generated_video_url": video_out['video']}
189
- elif isinstance(video_out, str) and os.path.exists(video_out):
190
- return {"generated_video_url": video_out}
191
- else:
192
- return {"generated_video_url": f"Error: Unexpected output {result}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  except Exception as e:
195
- print(f"Error in Generator: {e}")
196
  return {"generated_video_url": f"Error: {str(e)}"}
197
 
198
 
199
  # Graph Construction
200
  workflow = StateGraph(ContinuityState)
201
-
202
  workflow.add_node("analyst", analyze_videos)
203
- # workflow.add_node("prompter", draft_prompt) # Skipped, Analyst does extraction + prompting
204
  workflow.add_node("generator", generate_video)
205
-
206
  workflow.set_entry_point("analyst")
207
-
208
  workflow.add_edge("analyst", "generator")
209
  workflow.add_edge("generator", END)
210
-
211
- app = workflow.compile()
 
1
  import os
2
+ import time
3
+ import shutil
4
+ import requests
5
+ import tempfile
6
+ import logging
7
  from typing import TypedDict, Optional
8
  from langgraph.graph import StateGraph, END
 
9
  from google import genai
10
+ from groq import Groq
11
  from gradio_client import Client, handle_file
 
 
 
 
 
 
 
 
12
  from dotenv import load_dotenv
13
 
14
+ # Load environment variables
15
  load_dotenv()
16
 
17
+ # Configure Logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
  # State Definition
22
  class ContinuityState(TypedDict):
23
  video_a_url: str
 
29
  video_a_local_path: Optional[str]
30
  video_c_local_path: Optional[str]
31
 
32
+ # --- HELPER FUNCTIONS ---
33
+ def download_to_temp(url):
34
+ logger.info(f"Downloading: {url}")
35
+ resp = requests.get(url, stream=True)
36
+ resp.raise_for_status()
37
+ suffix = os.path.splitext(url.split("/")[-1])[1] or ".mp4"
38
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
39
+ shutil.copyfileobj(resp.raw, f)
40
+ return f.name
41
+
42
+ # --- NODE 1: ANALYST ---
43
  def analyze_videos(state: ContinuityState) -> dict:
44
+ logger.info("--- 🧐 Analyst Node (Director) ---")
45
 
46
  video_a_url = state['video_a_url']
47
  video_c_url = state['video_c_url']
48
 
49
+ # 1. Prepare Files
 
 
50
  try:
 
 
 
 
 
 
 
 
 
 
51
  path_a = state.get('video_a_local_path')
52
  if not path_a:
53
  path_a = download_to_temp(video_a_url)
 
55
  path_c = state.get('video_c_local_path')
56
  if not path_c:
57
  path_c = download_to_temp(video_c_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  except Exception as e:
59
+ logger.error(f"Download failed: {e}")
60
+ return {"scene_analysis": "Error downloading", "veo_prompt": "Smooth cinematic transition"}
61
 
62
+ # 2. Try Gemini 2.0 (With Retry)
63
+ client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
64
+ transition_prompt = None
65
+
66
+ retries = 3
67
+ for attempt in range(retries):
68
+ try:
69
+ logger.info(f"Uploading videos to Gemini... (Attempt {attempt+1})")
70
+ file_a = client.files.upload(file=path_a)
71
+ file_c = client.files.upload(file=path_c)
72
+
73
+ prompt_text = """
74
+ You are a film director.
75
+ Analyze the motion, lighting, and subject of the first video (Video A) and the second video (Video C).
76
+ Write a detailed visual prompt for a 2-second video (Video B) that smoothly transitions from the end of A to the start of C.
77
+ Target Output: A single concise descriptive paragraph for the video generation model.
78
+ """
79
+
80
+ logger.info("Generating transition prompt...")
81
+ # Using 2.0 Flash as per your logs (or 1.5-flash if preferred)
82
+ response = client.models.generate_content(
83
+ model="gemini-2.0-flash-exp",
84
+ contents=[prompt_text, file_a, file_c]
85
+ )
86
+ transition_prompt = response.text
87
+ logger.info(f"Generated Prompt: {transition_prompt}")
88
+ break # Success
89
+
90
+ except Exception as e:
91
+ if "429" in str(e) or "RESOURCE_EXHAUSTED" in str(e):
92
+ wait = 30 * (attempt + 1)
93
+ logger.warning(f"⚠️ Gemini Quota 429. Retrying in {wait}s...")
94
+ time.sleep(wait)
95
+ else:
96
+ logger.error(f"⚠️ Gemini Error: {e}")
97
+ break
98
 
99
+ # 3. Fallback: Groq (If Gemini failed)
100
+ if not transition_prompt:
101
+ logger.info("Switching to Llama 3.2 (Groq) Fallback...")
102
+ try:
103
+ groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
104
+ # We can't easily send videos, so we generate a prompt based on general best practices
105
+ fallback_prompt = "Create a smooth, cinematic visual transition that bridges two scenes with matching lighting and motion blur."
106
+
107
+ completion = groq_client.chat.completions.create(
108
+ model="llama-3.2-11b-vision-preview",
109
+ messages=[
110
+ {"role": "user", "content": f"Refine this into a video generation prompt: {fallback_prompt}"}
111
+ ]
112
+ )
113
+ transition_prompt = completion.choices[0].message.content
114
+ except Exception as e:
115
+ logger.error(f"❌ Groq also failed: {e}")
116
+ transition_prompt = "Smooth cinematic transition with motion blur matching the scenes."
117
+
118
+ return {
119
+ "scene_analysis": transition_prompt,
120
+ "veo_prompt": transition_prompt,
121
+ "video_a_local_path": path_a,
122
+ "video_c_local_path": path_c
123
+ }
124
+
125
+ # --- NODE 2: GENERATOR ---
126
  def generate_video(state: ContinuityState) -> dict:
127
+ logger.info("--- 🎥 Generator Node ---")
128
 
129
  prompt = state.get('veo_prompt', "")
130
  path_a = state.get('video_a_local_path')
131
  path_c = state.get('video_c_local_path')
132
 
133
  if not path_a or not path_c:
 
 
134
  return {"generated_video_url": "Error: Missing local video paths"}
135
 
136
  try:
137
+ # Extract Frames (simplified for brevity, ensuring libraries are imported)
138
  import cv2
139
  from PIL import Image
140
 
141
  def get_frame(video_path, location="last"):
142
  cap = cv2.VideoCapture(video_path)
143
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
144
+ if location == "last": cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
145
+ else: cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
 
 
 
146
  ret, frame = cap.read()
147
  cap.release()
148
+ if ret: return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
149
+ raise ValueError(f"Could not extract frame from {video_path}")
 
 
 
 
 
150
 
151
+ logger.info("Extracting frames...")
152
  img_start = get_frame(path_a, "last")
153
  img_end = get_frame(path_c, "first")
154
 
 
 
155
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f_start:
156
  img_start.save(f_start, format="PNG")
157
  start_path = f_start.name
 
160
  img_end.save(f_end, format="PNG")
161
  end_path = f_end.name
162
 
163
+ # --- ATTEMPT 1: WAN 2.2 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  try:
165
+ logger.info("Initializing Wan Client...")
166
+ client = Client("multimodalart/wan-2-2-first-last-frame")
167
+
168
+ logger.info(f"Generating with Wan 2.2... Prompt: {prompt[:30]}...")
169
+ result = client.predict(
170
+ start_image_pil=handle_file(start_path),
171
+ end_image_pil=handle_file(end_path),
172
+ prompt=prompt,
173
+ negative_prompt="blurry, distorted, low quality, static",
174
+ duration_seconds=2.1,
175
+ steps=20,
176
+ guidance_scale=5.0,
177
+ guidance_scale_2=5.0,
178
+ seed=42,
179
+ randomize_seed=True,
180
+ api_name="/generate_video"
181
+ )
182
+ # Handle Wan output format
183
+ video_out = result[0]
184
+ if isinstance(video_out, dict) and 'video' in video_out:
185
+ return {"generated_video_url": video_out['video']}
186
+ elif isinstance(video_out, str) and os.path.exists(video_out):
187
+ return {"generated_video_url": video_out}
188
+
189
+ except Exception as e:
190
+ logger.warning(f"⚠️ Wan 2.2 Failed: {e}")
191
+
192
+ # --- ATTEMPT 2: SVD FALLBACK ---
193
+ logger.info("🔄 Switching to SVD Fallback...")
194
+ try:
195
+ # FIXED REPO ID
196
+ client = Client("multimodalart/stable-video-diffusion")
197
+
198
+ # SVD uses one image, we'll use the start frame
199
+ result = client.predict(
200
+ handle_file(start_path),
201
+ 0.0, 0.0, 1, 25, # resized_width, resized_height, motion_bucket_id, fps
202
+ api_name="/predict"
203
+ )
204
+ logger.info(f"✅ SVD Generated: {result}")
205
+ return {"generated_video_url": result} # SVD usually returns path string
206
+
207
+ except Exception as e:
208
+ logger.error(f"❌ All Generators Failed. Error: {e}")
209
+ return {"generated_video_url": f"Error: {str(e)}"}
210
 
211
  except Exception as e:
212
+ logger.error(f"Error in Generator Setup: {e}")
213
  return {"generated_video_url": f"Error: {str(e)}"}
214
 
215
 
216
  # Graph Construction
217
  workflow = StateGraph(ContinuityState)
 
218
  workflow.add_node("analyst", analyze_videos)
 
219
  workflow.add_node("generator", generate_video)
 
220
  workflow.set_entry_point("analyst")
 
221
  workflow.add_edge("analyst", "generator")
222
  workflow.add_edge("generator", END)
223
+ app = workflow.compile()
 
continuity_agent/agent.py DELETED
@@ -1,256 +0,0 @@
1
- import os
2
- import time
3
- import shutil
4
- import cv2
5
- import numpy as np
6
- import base64
7
- import tempfile
8
- from groq import Groq
9
- from google import genai
10
- from gradio_client import Client, handle_file
11
- from dotenv import load_dotenv
12
-
13
- load_dotenv()
14
-
15
- # --- HELPER: Filmstrip Engine ---
16
- def create_filmstrip(video_path, samples=5, is_start=False):
17
- """Extracts frames and stitches them into a filmstrip for Vision analysis."""
18
- try:
19
- cap = cv2.VideoCapture(video_path)
20
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
21
- fps = cap.get(cv2.CAP_PROP_FPS)
22
- duration = total_frames / fps
23
-
24
- # Determine extraction points
25
- if is_start: # First 2 seconds
26
- start_f = 0
27
- end_f = int(min(total_frames, 2 * fps))
28
- if end_f <= start_f: end_f = total_frames # Handle short videos
29
- else: # Last 2 seconds
30
- start_f = int(max(0, total_frames - 2 * fps))
31
- end_f = total_frames
32
- if start_f >= end_f: start_f = 0
33
-
34
- frame_indices = np.linspace(start_f, end_f - 1, samples, dtype=int)
35
- frames = []
36
-
37
- for idx in frame_indices:
38
- cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
39
- ret, frame = cap.read()
40
- if ret:
41
- # Resize for token efficiency (Height 300px)
42
- h, w, _ = frame.shape
43
- scale = 300 / h
44
- new_w = int(w * scale)
45
- frame = cv2.resize(frame, (new_w, 300))
46
- frames.append(frame)
47
- cap.release()
48
-
49
- if not frames:
50
- raise ValueError("No frames extracted")
51
-
52
- # Stitch horizontally
53
- filmstrip = cv2.hconcat(frames)
54
-
55
- # Use a consistent temp file pattern or unique name
56
- temp_dir = tempfile.gettempdir()
57
- output_path = os.path.join(temp_dir, f"temp_strip_{int(time.time())}_{'start' if is_start else 'end'}.jpg")
58
- cv2.imwrite(output_path, filmstrip)
59
- return output_path
60
- except Exception as e:
61
- print(f"⚠️ Filmstrip failed: {e}")
62
- return None
63
-
64
- # --- PHASE 1: ANALYZE ONLY ---
65
- def analyze_only(video_a_path: str, video_c_path: str):
66
- print(f"🎬 Analyst: Processing videos...")
67
-
68
- # Generate Filmstrips
69
- strip_a = create_filmstrip(video_a_path, is_start=False)
70
- strip_c = create_filmstrip(video_c_path, is_start=True)
71
-
72
- if not strip_a or not strip_c:
73
- return {
74
- "prompt": "Cinematic transition between scenes.",
75
- "video_a_path": video_a_path,
76
- "video_c_path": video_c_path,
77
- "status": "warning",
78
- "detail": "Could not create filmstrips"
79
- }
80
-
81
- prompt = "Smooth cinematic transition." # Default safety
82
-
83
- # 1. Try Gemini 2.0 (Primary)
84
- try:
85
- print("🤖 Engaging Gemini 2.0...")
86
- client = genai.Client(api_key=os.environ.get("GOOGLE_API_KEY")) # Using correct env var name
87
-
88
- file_a = client.files.upload(file=strip_a)
89
- file_c = client.files.upload(file=strip_c)
90
-
91
- system_prompt = """
92
- You are an expert film editor. Analyze these two 'filmstrips'.
93
- Image 1 shows the end of the first clip (time flows left-to-right).
94
- Image 2 shows the start of the next clip.
95
- Describe the motion, lighting, and subject connection required to seamlessly bridge A to C in a cinematic way.
96
- Output a SINGLE concise paragraph for a video generation model.
97
- """
98
-
99
- response = client.models.generate_content(
100
- model="gemini-2.0-flash",
101
- contents=[system_prompt, file_a, file_c]
102
- )
103
-
104
- if response.text:
105
- prompt = response.text
106
-
107
- # raise Exception("Force Fallback for Testing") # Commented out for production use unless specifically testing
108
-
109
- except Exception as e:
110
- print(f"⚠️ Gemini Quota/Error: {e}. Switching to Llama 3.2 (Groq)...")
111
-
112
- # 2. Try Groq (Fallback)
113
- try:
114
- groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
115
-
116
- def encode_image(image_path):
117
- with open(image_path, "rb") as image_file:
118
- return base64.b64encode(image_file.read()).decode('utf-8')
119
-
120
- b64_a = encode_image(strip_a)
121
- b64_c = encode_image(strip_c)
122
-
123
- completion = groq_client.chat.completions.create(
124
- model="llama-3.2-11b-vision-instruct",
125
- messages=[
126
- {
127
- "role": "user",
128
- "content": [
129
- {"type": "text", "text": "These images show the END of Clip A and START of Clip C. Describe a smooth visual transition to bridge them."},
130
- {
131
- "type": "image_url",
132
- "image_url": {"url": f"data:image/jpeg;base64,{b64_a}"}
133
- },
134
- {
135
- "type": "image_url",
136
- "image_url": {"url": f"data:image/jpeg;base64,{b64_c}"}
137
- }
138
- ]
139
- }
140
- ],
141
- temperature=0.7,
142
- max_tokens=500
143
- )
144
- prompt = completion.choices[0].message.content
145
- except Exception as groq_e:
146
- print(f"❌ Groq also failed: {groq_e}. Using default prompt.")
147
-
148
- # Cleanup
149
- try:
150
- if os.path.exists(strip_a): os.remove(strip_a)
151
- if os.path.exists(strip_c): os.remove(strip_c)
152
- except:
153
- pass
154
-
155
- return {
156
- "prompt": prompt,
157
- "video_a_path": video_a_path,
158
- "video_c_path": video_c_path,
159
- "status": "success"
160
- }
161
-
162
- # --- PHASE 2: GENERATE ONLY ---
163
- def generate_only(prompt: str, video_a_path: str, video_c_path: str):
164
- print(f"🎥 Generator: Action! Prompt: {prompt[:50]}...")
165
-
166
- # 1. Primary: Wan 2.2
167
- try:
168
- # Extract Frames for Wan
169
- # We need to save temporary frames because handle_file expects a path
170
- def get_frame(v_path, at_start):
171
- cap = cv2.VideoCapture(v_path)
172
- if not at_start:
173
- total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
174
- cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, total-1))
175
- ret, frame = cap.read()
176
- cap.release()
177
- if not ret: raise ValueError("Frame extract failed")
178
-
179
- # Resize safe for Wan
180
- h, w = frame.shape[:2]
181
- if h > 480:
182
- scale = 480/h
183
- frame = cv2.resize(frame, (int(w*scale), 480))
184
-
185
- t_path = os.path.join(tempfile.gettempdir(), f"wan_frame_{int(time.time())}_{'s' if at_start else 'e'}.png")
186
- cv2.imwrite(t_path, frame)
187
- return t_path
188
-
189
- f_start = get_frame(video_a_path, False) # Last frame of A
190
- f_end = get_frame(video_c_path, True) # First frame of C
191
-
192
- client = Client("multimodalart/wan-2-2-first-last-frame", token=os.environ.get("HF_TOKEN"))
193
-
194
- print("Generating with Wan 2.2...")
195
- result = client.predict(
196
- start_image_pil=handle_file(f_start),
197
- end_image_pil=handle_file(f_end),
198
- prompt=prompt,
199
- negative_prompt="blurry, distorted, low quality, static",
200
- duration_seconds=2.1,
201
- steps=20,
202
- guidance_scale=5.0,
203
- guidance_scale_2=5.0,
204
- seed=42,
205
- randomize_seed=True,
206
- api_name="/generate_video"
207
- )
208
-
209
- # Cleanup temp
210
- try:
211
- os.remove(f_start)
212
- os.remove(f_end)
213
- except: pass
214
-
215
- # Parse result
216
- video_out = result[0]
217
- if isinstance(video_out, dict) and 'video' in video_out:
218
- return {"video_url": video_out['video']}
219
- elif isinstance(video_out, str):
220
- return {"video_url": video_out}
221
- else:
222
- raise ValueError(f"Unknown Wan output: {result}")
223
-
224
- except Exception as e:
225
- print(f"⚠️ Wan 2.2 Failed (Quota/Error): {e}")
226
- print("🔄 Switching to SVD Fallback...")
227
-
228
- # 2. Fallback: SVD (Image-to-Video)
229
- try:
230
- client_svd = Client("stabilityai/stable-video-diffusion-img2vid-xt-1-1")
231
-
232
- # Extract last frame of A for SVD input
233
- cap = cv2.VideoCapture(video_a_path)
234
- total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
235
- cap.set(cv2.CAP_PROP_POS_FRAMES, total-1)
236
- ret, frame = cap.read()
237
- cap.release()
238
-
239
- # Resize for SVD (1024x576 recommended or similar 16:9)
240
- frame = cv2.resize(frame, (1024, 576))
241
-
242
- svd_input_path = os.path.join(tempfile.gettempdir(), "svd_input.jpg")
243
- cv2.imwrite(svd_input_path, frame)
244
-
245
- print("Generating with SVD...")
246
- result = client_svd.predict(
247
- svd_input_path,
248
- 0.0, 127, 6,
249
- api_name="/predict"
250
- )
251
- return {"video_url": result}
252
-
253
- except Exception as svd_e:
254
- err_msg = f"All Generators Failed. Wan: {e}, SVD: {svd_e}"
255
- print(f"❌ {err_msg}")
256
- return {"video_url": f"Error: {err_msg}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server.py CHANGED
@@ -71,7 +71,7 @@ async def generate_endpoint(
71
 
72
  # Call Agent
73
  result = generate_only(prompt, video_a_path, video_c_path)
74
- gen_path = result.get("video_url")
75
 
76
  if not gen_path or "Error" in gen_path:
77
  raise HTTPException(status_code=500, detail=f"Generation failed: {gen_path}")
 
71
 
72
  # Call Agent
73
  result = generate_only(prompt, video_a_path, video_c_path)
74
+ gen_path = result.get("generated_video_url")
75
 
76
  if not gen_path or "Error" in gen_path:
77
  raise HTTPException(status_code=500, detail=f"Generation failed: {gen_path}")