Text Generation
video-editing
social-media
agent
tool-calling
sft
trl
viralcut
ryu34 commited on
Commit
0307f85
·
verified ·
1 Parent(s): 81a3e4f

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +489 -0
agent.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ViralCut Agent - Runtime
3
+ ========================
4
+ The actual agent that uses the fine-tuned model to edit videos autonomously.
5
+
6
+ This connects the trained model to real tools:
7
+ - FFmpeg for video editing
8
+ - DuckDuckGo for web search (free, no API key)
9
+ - Whisper for transcription
10
+ - PySceneDetect for shot detection
11
+
12
+ Usage:
13
+ python agent.py --video raw_footage.mp4 --platform tiktok --niche food
14
+ python agent.py --plan --niche "coffee shop" --platform tiktok
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import re
21
+ import subprocess
22
+ import sys
23
+ import tempfile
24
+ from pathlib import Path
25
+
26
+ # ============================================================
27
+ # TOOL IMPLEMENTATIONS
28
+ # ============================================================
29
+
30
+ class FFmpegTool:
31
+ """Execute FFmpeg commands for video/audio processing."""
32
+
33
+ @staticmethod
34
+ def run(command: str, description: str = "") -> str:
35
+ """Execute an FFmpeg command and return result."""
36
+ print(f" 🎬 FFmpeg: {description}")
37
+ print(f" $ {command}")
38
+ try:
39
+ result = subprocess.run(
40
+ command, shell=True, capture_output=True, text=True, timeout=120
41
+ )
42
+ if result.returncode == 0:
43
+ return json.dumps({"status": "success", "message": f"Command completed: {description}"})
44
+ else:
45
+ return json.dumps({"status": "error", "message": result.stderr[:500]})
46
+ except subprocess.TimeoutExpired:
47
+ return json.dumps({"status": "error", "message": "Command timed out after 120s"})
48
+ except Exception as e:
49
+ return json.dumps({"status": "error", "message": str(e)})
50
+
51
+
52
+ class WebSearchTool:
53
+ """Search the web using DuckDuckGo (free, no API key needed)."""
54
+
55
+ @staticmethod
56
+ def search(query: str, search_type: str = "general") -> str:
57
+ """Search the web and return results."""
58
+ print(f" 🔍 Searching: {query} (type: {search_type})")
59
+ try:
60
+ from duckduckgo_search import DDGS
61
+ with DDGS() as ddgs:
62
+ results = []
63
+ for r in ddgs.text(query, max_results=5):
64
+ results.append({
65
+ "title": r.get("title", ""),
66
+ "url": r.get("href", ""),
67
+ "description": r.get("body", "")[:200]
68
+ })
69
+ return json.dumps({"results": results})
70
+ except ImportError:
71
+ return json.dumps({"results": [{"title": "Install duckduckgo-search", "description": "pip install duckduckgo-search"}]})
72
+ except Exception as e:
73
+ return json.dumps({"results": [], "error": str(e)})
74
+
75
+
76
+ class VideoAnalyzer:
77
+ """Analyze video files using ffprobe and PySceneDetect."""
78
+
79
+ @staticmethod
80
+ def analyze(video_path: str, analysis_type: str = "full") -> str:
81
+ """Analyze a video file."""
82
+ print(f" 📊 Analyzing: {video_path} ({analysis_type})")
83
+
84
+ if not os.path.exists(video_path):
85
+ return json.dumps({"error": f"File not found: {video_path}"})
86
+
87
+ result = {}
88
+
89
+ # Get basic info via ffprobe
90
+ try:
91
+ probe = subprocess.run(
92
+ f'ffprobe -v quiet -print_format json -show_format -show_streams "{video_path}"',
93
+ shell=True, capture_output=True, text=True
94
+ )
95
+ if probe.returncode == 0:
96
+ info = json.loads(probe.stdout)
97
+ fmt = info.get("format", {})
98
+ result["duration"] = float(fmt.get("duration", 0))
99
+ result["size_mb"] = round(int(fmt.get("size", 0)) / 1024 / 1024, 1)
100
+
101
+ for stream in info.get("streams", []):
102
+ if stream.get("codec_type") == "video":
103
+ result["resolution"] = f"{stream.get('width')}x{stream.get('height')}"
104
+ result["fps"] = eval(stream.get("r_frame_rate", "30/1"))
105
+ result["codec"] = stream.get("codec_name")
106
+ elif stream.get("codec_type") == "audio":
107
+ result["audio_codec"] = stream.get("codec_name")
108
+ result["audio_channels"] = stream.get("channels")
109
+ except Exception as e:
110
+ result["probe_error"] = str(e)
111
+
112
+ # Scene detection
113
+ if analysis_type in ("full", "scenes"):
114
+ try:
115
+ from scenedetect import open_video, SceneManager
116
+ from scenedetect.detectors import ContentDetector
117
+
118
+ video = open_video(video_path)
119
+ scene_manager = SceneManager()
120
+ scene_manager.add_detector(ContentDetector(threshold=27))
121
+ scene_manager.detect_scenes(video)
122
+ scene_list = scene_manager.get_scene_list()
123
+
124
+ result["scenes"] = []
125
+ for i, (start, end) in enumerate(scene_list):
126
+ result["scenes"].append({
127
+ "scene": i + 1,
128
+ "start": round(start.get_seconds(), 2),
129
+ "end": round(end.get_seconds(), 2),
130
+ "duration": round((end - start).get_seconds(), 2)
131
+ })
132
+ except ImportError:
133
+ result["scenes_note"] = "Install scenedetect: pip install scenedetect[opencv]"
134
+ except Exception as e:
135
+ result["scenes_error"] = str(e)
136
+
137
+ # Transcript via Whisper
138
+ if analysis_type in ("full", "transcript", "audio"):
139
+ try:
140
+ import whisper
141
+ model = whisper.load_model("base")
142
+ transcript = model.transcribe(video_path)
143
+ result["transcript"] = transcript.get("text", "")[:2000]
144
+ result["segments"] = [
145
+ {"start": s["start"], "end": s["end"], "text": s["text"]}
146
+ for s in transcript.get("segments", [])[:50]
147
+ ]
148
+ except ImportError:
149
+ result["transcript_note"] = "Install whisper: pip install openai-whisper"
150
+ except Exception as e:
151
+ result["transcript_error"] = str(e)
152
+
153
+ return json.dumps(result)
154
+
155
+
156
+ class ViralityScorer:
157
+ """Score video content for viral potential."""
158
+
159
+ @staticmethod
160
+ def score(video_path: str, platform: str, niche: str = "") -> str:
161
+ """Score a video's viral potential based on heuristics."""
162
+ print(f" 📈 Scoring virality: {video_path} for {platform}")
163
+
164
+ # Get video info
165
+ try:
166
+ probe = subprocess.run(
167
+ f'ffprobe -v quiet -print_format json -show_format -show_streams "{video_path}"',
168
+ shell=True, capture_output=True, text=True
169
+ )
170
+ info = json.loads(probe.stdout) if probe.returncode == 0 else {}
171
+ except:
172
+ info = {}
173
+
174
+ duration = float(info.get("format", {}).get("duration", 0))
175
+ has_audio = any(s.get("codec_type") == "audio" for s in info.get("streams", []))
176
+
177
+ # Platform-specific optimal durations
178
+ optimal_ranges = {
179
+ "tiktok": (7, 30),
180
+ "instagram_reels": (15, 30),
181
+ "youtube_shorts": (30, 60)
182
+ }
183
+ opt_min, opt_max = optimal_ranges.get(platform, (15, 60))
184
+
185
+ # Score components
186
+ scores = {}
187
+
188
+ # Length score
189
+ if opt_min <= duration <= opt_max:
190
+ scores["length_optimal"] = 90
191
+ elif duration < opt_min:
192
+ scores["length_optimal"] = max(50, 90 - (opt_min - duration) * 5)
193
+ else:
194
+ scores["length_optimal"] = max(40, 90 - (duration - opt_max) * 3)
195
+
196
+ # Audio presence
197
+ scores["audio_match"] = 80 if has_audio else 30
198
+
199
+ # Resolution check
200
+ for s in info.get("streams", []):
201
+ if s.get("codec_type") == "video":
202
+ h = int(s.get("height", 0))
203
+ w = int(s.get("width", 0))
204
+ if h >= 1920 or w >= 1080:
205
+ scores["visual_quality"] = 85
206
+ elif h >= 1080:
207
+ scores["visual_quality"] = 75
208
+ else:
209
+ scores["visual_quality"] = 55
210
+ # Vertical check
211
+ if h > w:
212
+ scores["format_match"] = 90
213
+ else:
214
+ scores["format_match"] = 50
215
+
216
+ scores.setdefault("visual_quality", 60)
217
+ scores.setdefault("format_match", 60)
218
+ scores["hook_strength"] = 70 # Can't assess without content analysis
219
+ scores["pacing"] = 70
220
+ scores["trend_alignment"] = 65
221
+
222
+ overall = round(sum(scores.values()) / len(scores))
223
+
224
+ suggestions = []
225
+ if scores.get("format_match", 0) < 70:
226
+ suggestions.append("Convert to 9:16 vertical format for better reach")
227
+ if scores.get("length_optimal", 0) < 70:
228
+ suggestions.append(f"Adjust length to {opt_min}-{opt_max}s for {platform}")
229
+ if not has_audio:
230
+ suggestions.append("Add audio - videos without sound get 40% less reach")
231
+
232
+ return json.dumps({
233
+ "overall_score": overall,
234
+ "breakdown": scores,
235
+ "suggestions": suggestions
236
+ })
237
+
238
+
239
+ class CaptionGenerator:
240
+ """Generate platform-optimized captions."""
241
+
242
+ @staticmethod
243
+ def generate(video_description: str, platform: str, tone: str = "casual", include_cta: bool = True) -> str:
244
+ """Generate a caption (using the model itself for this in production)."""
245
+ print(f" ✍️ Generating caption for {platform}")
246
+
247
+ hashtag_sets = {
248
+ "tiktok": ["#fyp", "#viral", "#foryou", "#trending"],
249
+ "instagram": ["#reels", "#explore", "#instagood", "#trending"],
250
+ "youtube": ["#shorts", "#subscribe", "#viral"]
251
+ }
252
+
253
+ base_tags = hashtag_sets.get(platform, ["#viral"])
254
+
255
+ # Extract keywords from description for niche hashtags
256
+ words = video_description.lower().split()
257
+ niche_tags = [f"#{w}" for w in words if len(w) > 3 and w.isalpha()][:3]
258
+
259
+ posting_times = {
260
+ "tiktok": "7-9am, 12-1pm, or 7-9pm in your audience timezone",
261
+ "instagram": "6-9am, 12-2pm, or 5-7pm EST",
262
+ "youtube": "2-4pm or 8-10pm EST"
263
+ }
264
+
265
+ return json.dumps({
266
+ "caption": f"[AI will generate based on: {video_description}]",
267
+ "hashtags": " ".join(base_tags + niche_tags),
268
+ "posting_time": posting_times.get(platform, "Check your analytics"),
269
+ "tip": "Reply to every comment in the first hour - algorithm loves engagement"
270
+ })
271
+
272
+
273
+ class AIDetector:
274
+ """Detect AI-generated content."""
275
+
276
+ @staticmethod
277
+ def detect(content_path: str, check_type: str = "video") -> str:
278
+ """Basic AI content detection heuristics."""
279
+ print(f" 🔬 Checking for AI artifacts: {content_path}")
280
+
281
+ if not os.path.exists(content_path):
282
+ return json.dumps({"error": f"File not found: {content_path}"})
283
+
284
+ # Basic file analysis (real detection would use a classifier model)
285
+ size = os.path.getsize(content_path)
286
+
287
+ return json.dumps({
288
+ "file_analyzed": content_path,
289
+ "check_type": check_type,
290
+ "file_size_mb": round(size / 1024 / 1024, 2),
291
+ "note": "Full AI detection requires DeMamba or VideoScore2 model. Basic file analysis only.",
292
+ "recommendations": [
293
+ "Check for morphing objects between frames",
294
+ "Look for impossible reflections or shadows",
295
+ "Verify text is readable and consistent",
296
+ "Check if camera movement is unnaturally smooth"
297
+ ]
298
+ })
299
+
300
+
301
+ # ============================================================
302
+ # AGENT CORE
303
+ # ============================================================
304
+
305
+ TOOL_MAP = {
306
+ "ffmpeg_cmd": lambda args: FFmpegTool.run(**args),
307
+ "web_search": lambda args: WebSearchTool.search(**args),
308
+ "analyze_video": lambda args: VideoAnalyzer.analyze(**args),
309
+ "score_virality": lambda args: ViralityScorer.score(**args),
310
+ "generate_caption": lambda args: CaptionGenerator.generate(**args),
311
+ "detect_ai_slop": lambda args: AIDetector.detect(**args),
312
+ }
313
+
314
+
315
+ class ViralCutAgent:
316
+ """The main agent that orchestrates video editing using the fine-tuned model."""
317
+
318
+ def __init__(self, model_id="ryu34/viralcut-agent", device="auto"):
319
+ print(f"Loading ViralCut Agent from {model_id}...")
320
+
321
+ from transformers import AutoModelForCausalLM, AutoTokenizer
322
+
323
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
324
+ self.model = AutoModelForCausalLM.from_pretrained(
325
+ model_id,
326
+ device_map=device,
327
+ torch_dtype="auto",
328
+ )
329
+ self.model.eval()
330
+
331
+ # Tool definitions for the chat template
332
+ self.tools = [
333
+ {"type": "function", "function": {"name": "ffmpeg_cmd", "description": "Execute FFmpeg command for video/audio processing.", "parameters": {"type": "object", "properties": {"command": {"type": "string"}, "description": {"type": "string"}}, "required": ["command", "description"]}}},
334
+ {"type": "function", "function": {"name": "web_search", "description": "Search web for royalty-free assets and trends.", "parameters": {"type": "object", "properties": {"query": {"type": "string"}, "search_type": {"type": "string", "enum": ["royalty_free_music", "sound_effects", "trending_content", "general"]}}, "required": ["query", "search_type"]}}},
335
+ {"type": "function", "function": {"name": "analyze_video", "description": "Analyze video for scenes, audio, transcript, quality.", "parameters": {"type": "object", "properties": {"video_path": {"type": "string"}, "analysis_type": {"type": "string", "enum": ["full", "scenes", "audio", "transcript", "quality", "pacing"]}}, "required": ["video_path", "analysis_type"]}}},
336
+ {"type": "function", "function": {"name": "score_virality", "description": "Score video viral potential 0-100.", "parameters": {"type": "object", "properties": {"video_path": {"type": "string"}, "platform": {"type": "string", "enum": ["tiktok", "instagram_reels", "youtube_shorts"]}, "niche": {"type": "string"}}, "required": ["video_path", "platform"]}}},
337
+ {"type": "function", "function": {"name": "generate_caption", "description": "Generate platform-optimized caption with hashtags.", "parameters": {"type": "object", "properties": {"video_description": {"type": "string"}, "platform": {"type": "string", "enum": ["tiktok", "instagram", "youtube"]}, "tone": {"type": "string"}, "include_cta": {"type": "boolean"}}, "required": ["video_description", "platform"]}}},
338
+ {"type": "function", "function": {"name": "detect_ai_slop", "description": "Check content for AI-generated artifacts.", "parameters": {"type": "object", "properties": {"content_path": {"type": "string"}, "check_type": {"type": "string", "enum": ["video", "image", "text", "audio"]}}, "required": ["content_path", "check_type"]}}}
339
+ ]
340
+
341
+ print("Agent ready!")
342
+
343
+ def run(self, user_message: str, max_turns: int = 15):
344
+ """Run the agent on a user request, executing tool calls autonomously."""
345
+
346
+ messages = [
347
+ {"role": "system", "content": "You are ViralCut Agent, an autonomous AI video editor and social media content strategist. You transform raw video footage into professional, viral-worthy social media content. Use your tools to analyze, edit, search, and optimize. Think step-by-step. Always use royalty-free content."},
348
+ {"role": "user", "content": user_message}
349
+ ]
350
+
351
+ print(f"\n{'='*60}")
352
+ print(f"🎬 ViralCut Agent")
353
+ print(f"{'='*60}")
354
+ print(f"User: {user_message}\n")
355
+
356
+ for turn in range(max_turns):
357
+ # Generate response
358
+ text = self.tokenizer.apply_chat_template(
359
+ messages, tools=self.tools, tokenize=False, add_generation_prompt=True
360
+ )
361
+ inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
362
+
363
+ with __import__("torch").no_grad():
364
+ outputs = self.model.generate(
365
+ **inputs,
366
+ max_new_tokens=1024,
367
+ temperature=0.7,
368
+ top_p=0.9,
369
+ do_sample=True,
370
+ )
371
+
372
+ response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=False)
373
+
374
+ # Parse response for tool calls or plain text
375
+ tool_calls = self._parse_tool_calls(response)
376
+
377
+ if tool_calls:
378
+ # Add assistant message with tool calls
379
+ messages.append({"role": "assistant", "tool_calls": tool_calls})
380
+
381
+ # Execute each tool call
382
+ for tc in tool_calls:
383
+ func_name = tc["function"]["name"]
384
+ try:
385
+ args = json.loads(tc["function"]["arguments"])
386
+ except:
387
+ args = {}
388
+
389
+ print(f"\n 🔧 Calling: {func_name}")
390
+
391
+ if func_name in TOOL_MAP:
392
+ result = TOOL_MAP[func_name](args)
393
+ else:
394
+ result = json.dumps({"error": f"Unknown tool: {func_name}"})
395
+
396
+ messages.append({"role": "tool", "name": func_name, "content": result})
397
+ print(f" ✅ Result: {result[:200]}...")
398
+ else:
399
+ # Plain text response - agent is done
400
+ clean = self._clean_response(response)
401
+ messages.append({"role": "assistant", "content": clean})
402
+ print(f"\n🤖 Agent: {clean}")
403
+ break
404
+
405
+ return messages
406
+
407
+ def _parse_tool_calls(self, response: str) -> list:
408
+ """Parse tool calls from model output."""
409
+ tool_calls = []
410
+
411
+ # Qwen tool call format: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
412
+ pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
413
+ matches = re.findall(pattern, response, re.DOTALL)
414
+
415
+ for match in matches:
416
+ try:
417
+ data = json.loads(match)
418
+ tool_calls.append({
419
+ "type": "function",
420
+ "function": {
421
+ "name": data.get("name", ""),
422
+ "arguments": json.dumps(data.get("arguments", {}))
423
+ }
424
+ })
425
+ except json.JSONDecodeError:
426
+ continue
427
+
428
+ return tool_calls
429
+
430
+ def _clean_response(self, response: str) -> str:
431
+ """Clean up model response."""
432
+ # Remove special tokens
433
+ for token in ["<|endoftext|>", "<|im_end|>", "<|im_start|>"]:
434
+ response = response.replace(token, "")
435
+ return response.strip()
436
+
437
+
438
+ # ============================================================
439
+ # CLI
440
+ # ============================================================
441
+
442
+ def main():
443
+ parser = argparse.ArgumentParser(description="ViralCut Agent - AI Video Editor")
444
+ parser.add_argument("--video", type=str, help="Path to raw video file")
445
+ parser.add_argument("--platform", type=str, default="tiktok",
446
+ choices=["tiktok", "instagram", "youtube"],
447
+ help="Target platform")
448
+ parser.add_argument("--niche", type=str, default="", help="Content niche")
449
+ parser.add_argument("--plan", action="store_true", help="Generate content plan only (no video needed)")
450
+ parser.add_argument("--model", type=str, default="ryu34/viralcut-agent", help="Model ID")
451
+ parser.add_argument("--check-slop", type=str, nargs="+", help="Check files for AI-generated content")
452
+
453
+ args = parser.parse_args()
454
+
455
+ if args.check_slop:
456
+ # Quick AI slop check without loading the full model
457
+ for f in args.check_slop:
458
+ result = AIDetector.detect(f, "video")
459
+ print(json.dumps(json.loads(result), indent=2))
460
+ return
461
+
462
+ agent = ViralCutAgent(model_id=args.model)
463
+
464
+ if args.plan:
465
+ niche = args.niche or "general"
466
+ agent.run(f"Research current {args.platform} trends for the '{niche}' niche and create a detailed 7-day content plan with hooks, posting times, and viral strategies.")
467
+ elif args.video:
468
+ if not os.path.exists(args.video):
469
+ print(f"Error: Video file not found: {args.video}")
470
+ sys.exit(1)
471
+ niche_str = f" in the {args.niche} niche" if args.niche else ""
472
+ agent.run(f"I have raw footage at {args.video}. Transform it into a professional, viral {args.platform} video{niche_str}. Analyze it, find the best moments, add trending music, professional edits, and optimize for maximum engagement.")
473
+ else:
474
+ # Interactive mode
475
+ print("ViralCut Agent - Interactive Mode")
476
+ print("Type your request (or 'quit' to exit):\n")
477
+ while True:
478
+ try:
479
+ user_input = input("You: ").strip()
480
+ if user_input.lower() in ("quit", "exit", "q"):
481
+ break
482
+ if user_input:
483
+ agent.run(user_input)
484
+ except (KeyboardInterrupt, EOFError):
485
+ break
486
+
487
+
488
+ if __name__ == "__main__":
489
+ main()