File size: 15,060 Bytes
b12e499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404465b
 
 
 
b12e499
869d082
 
 
 
 
 
 
 
 
 
 
b12e499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869d082
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""
LLM-powered script generation for EceMotion Pictures.
Generates intelligent, structure-aware commercial scripts with timing markers.
"""

import logging
import random
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

from config import (
    MODEL_LLM, MODEL_CONFIGS, VOICE_STYLES, STRUCTURE_TEMPLATES, TAGLINES,
    get_safe_model_name
)

logger = logging.getLogger(__name__)

@dataclass
class ScriptSegment:
    """Represents a segment of the commercial script with timing information.

    Instances are produced both by the LLM response parser and by the
    template-based generator in this module.
    """
    # Spoken copy for this segment.
    text: str
    # Estimated speaking time in seconds.
    duration_estimate: float
    segment_type: str  # "hook", "flow", "benefit", "cta"
    # Bracketed, upper-cased label such as "[HOOK]"; None when no marker is attached.
    timing_marker: Optional[str] = None

@dataclass
class GeneratedScript:
    """Complete generated script with all segments and metadata.

    Returned by both the LLM-backed and the template-based generation paths.
    """
    # Ordered script segments; may be empty when an LLM response contains no
    # recognised segment markers.
    segments: List[ScriptSegment]
    # Sum of the segments' duration estimates, in seconds.
    total_duration: float
    # Closing slogan, either extracted from the response or picked from TAGLINES.
    tagline: str
    # Voice style passed in by the caller, stored unchanged.
    voice_style: str
    # Total number of whitespace-separated words across all segments.
    word_count: int
    # Raw LLM output, or a short placeholder description for template scripts.
    raw_script: str

class LLMScriptGenerator:
    """Generates commercial scripts using large language models with fallbacks."""
    
    def __init__(self, model_name: str = MODEL_LLM):
        """Set up the generator and attempt to load the requested LLM.

        Args:
            model_name: Requested model identifier; it is passed through
                ``get_safe_model_name(model_name, "llm")`` before use.

        Loading problems are handled inside ``_try_init_llm``; afterwards
        ``llm_available`` tells whether a model is ready.
        """
        # Begin in the "nothing loaded" state; a successful load flips the flag.
        self.llm_available = False
        self.tokenizer = None
        self.model = None

        self.model_name = get_safe_model_name(model_name, "llm")
        # Per-model generation settings; an empty dict means "use defaults".
        self.model_config = MODEL_CONFIGS.get(self.model_name, {})

        self._try_init_llm()
    
    def _try_init_llm(self):
        """Pick and run the loader matching ``self.model_name``.

        The model family is detected by a case-insensitive substring match
        ("dialo" is checked before "qwen"). Unknown names and any exception
        leave ``llm_available`` set to False.
        """
        try:
            lowered = self.model_name.lower()
            # Checked in order; the first matching family wins.
            loaders = (
                ("dialo", self._init_dialogpt),
                ("qwen", self._init_qwen),
            )
            for marker, loader in loaders:
                if marker in lowered:
                    loader()
                    break
            else:
                logger.warning(f"Unknown LLM model: {self.model_name}, using fallback")
                self.llm_available = False
        except Exception as e:
            logger.warning(f"Failed to initialize LLM {self.model_name}: {e}")
            self.llm_available = False
    
    def _init_dialogpt(self):
        """Load the DialoGPT tokenizer and model named by ``self.model_name``.

        Sets ``llm_available`` to True on success; on any failure the error
        is logged and ``llm_available`` is set to False.
        """
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Reuse the EOS token for padding when the tokenizer defines none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            placement = "auto" if self._has_gpu() else "cpu"
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype="auto",
                device_map=placement,
            )

            self.llm_available = True
            logger.info(f"DialoGPT model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load DialoGPT: {e}")
            self.llm_available = False
    
    def _init_qwen(self):
        """Load the Qwen tokenizer and model named by ``self.model_name``.

        Both are loaded with ``trust_remote_code=True``. Sets
        ``llm_available`` to True on success; on any failure the error is
        logged and ``llm_available`` is set to False.
        """
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM

            # NOTE(review): trust_remote_code executes code shipped with the
            # model repository - confirm the configured model is trusted.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, trust_remote_code=True
            )
            # Reuse the EOS token for padding when the tokenizer defines none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            placement = "auto" if self._has_gpu() else "cpu"
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype="auto",
                device_map=placement,
                trust_remote_code=True,
            )

            self.llm_available = True
            logger.info(f"Qwen model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Qwen: {e}")
            self.llm_available = False
    
    def _has_gpu(self) -> bool:
        """Check if GPU is available."""
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False
    
    def _create_system_prompt(self) -> str:
        """Create system prompt for retro commercial script generation."""
        return """You are a professional copywriter specializing in 1980s-style TV commercials. 
Your task is to create engaging, persuasive commercial scripts that capture the authentic retro aesthetic.

Key requirements:
- Use 1980s commercial language and style
- Include clear hooks, benefits, and calls-to-action
- Keep scripts concise and punchy
- Use active voice and emotional appeals
- End with a memorable tagline

Format your response as:
HOOK: [Opening attention-grabber]
FLOW: [Main content following the structure]
BENEFIT: [Key value proposition]
CTA: [Call to action with tagline]

Keep each segment under 2-3 sentences. Use enthusiastic, confident language typical of 1980s advertising."""
    
    def _create_user_prompt(self, brand: str, structure: str, script_prompt: str, 
                          duration: int, voice_style: str) -> str:
        """Create user prompt with specific requirements."""
        return f"""Create a {duration}-second retro commercial script for {brand}.

Structure: {structure}
Script idea: {script_prompt}
Voice style: {voice_style}

Make it authentic to 1980s TV commercials with the energy and style of that era."""
    
    def _parse_script_response(self, response: str) -> List[ScriptSegment]:
        """Parse LLM response into structured script segments."""
        segments = []
        
        # Split by segment markers
        import re
        parts = re.split(r'(HOOK:|FLOW:|BENEFIT:|CTA:)', response)
        
        for i in range(1, len(parts), 2):
            if i + 1 < len(parts):
                segment_type = parts[i].rstrip(':').lower()
                text = parts[i + 1].strip()
                
                if text:
                    # Estimate duration based on word count (150 WPM)
                    word_count = len(text.split())
                    duration = (word_count / 150) * 60  # Convert to seconds
                    
                    segments.append(ScriptSegment(
                        text=text,
                        duration_estimate=duration,
                        segment_type=segment_type,
                        timing_marker=f"[{segment_type.upper()}]"
                    ))
        
        return segments
    
    def _extract_tagline(self, response: str) -> str:
        """Extract tagline from the script response."""
        # Look for tagline in CTA section
        import re
        cta_match = re.search(r'CTA:.*?([A-Z][^.!?]*[.!?])', response, re.DOTALL)
        if cta_match:
            cta_text = cta_match.group(1)
            # Extract the last sentence as potential tagline
            sentences = re.split(r'[.!?]+', cta_text)
            if sentences:
                tagline = sentences[-1].strip()
                if len(tagline) > 5:  # Ensure it's substantial
                    return tagline
        
        # Fallback to predefined taglines
        return random.choice(TAGLINES)
    
    def generate_script_with_llm(self, brand: str, structure: str, script_prompt: str,
                                duration: int, voice_style: str, seed: int = 42) -> GeneratedScript:
        """Generate a script by sampling from the loaded language model.

        Args:
            brand: Brand name the commercial advertises.
            structure: Structure description passed into the user prompt.
            script_prompt: Creative idea passed into the user prompt.
            duration: Target length in seconds (used in the prompt only).
            voice_style: Voice style; used in the prompt and stored on the result.
            seed: Seed applied to both Python's ``random`` and torch, so that
                sampling is repeatable for a given seed.

        Returns:
            A ``GeneratedScript`` built from the parsed response. ``segments``
            is empty when the response contains no segment markers.

        Raises:
            RuntimeError: If no LLM is available.
        """
        if not self.llm_available:
            raise RuntimeError("LLM not available")

        # Imported here because torch is only needed once a model is loaded.
        import torch

        # Seed both RNGs: `random` drives the tagline fallback, while
        # sampling in generate() draws from torch's generator. Seeding only
        # `random` left the generated text unaffected by `seed`.
        random.seed(seed)
        torch.manual_seed(seed)

        system_prompt = self._create_system_prompt()
        user_prompt = self._create_user_prompt(brand, structure, script_prompt, duration, voice_style)

        # DialoGPT gets only the user prompt; other models get a
        # System/User/Assistant transcript.
        if "dialo" in self.model_name.lower():
            text = f"{user_prompt}\n\nResponse:"
        else:
            text = f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"

        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

        # Move inputs to the same device as the model.
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        self.model.eval()
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.model_config.get("max_tokens", 256),
            temperature=self.model_config.get("temperature", 0.7),
            top_p=self.model_config.get("top_p", 0.9),
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            num_return_sequences=1
        )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        prompt_length = inputs['input_ids'].shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        logger.info(f"Generated script response: {response[:200]}...")

        segments = self._parse_script_response(response)
        tagline = self._extract_tagline(response)

        total_duration = sum(segment.duration_estimate for segment in segments)
        word_count = sum(len(segment.text.split()) for segment in segments)

        return GeneratedScript(
            segments=segments,
            total_duration=total_duration,
            tagline=tagline,
            voice_style=voice_style,
            word_count=word_count,
            raw_script=response
        )
    
    def generate_script_with_template(self, brand: str, structure: str, script_prompt: str,
                                    duration: int, voice_style: str, seed: int = 42) -> GeneratedScript:
        """Generate a script from fixed templates (fallback when no LLM is used).

        Args:
            brand: Brand name inserted into the hook, flow and CTA lines.
            structure: Structure description; when ``None`` or blank, a
                random entry of ``STRUCTURE_TEMPLATES`` is used instead.
            script_prompt: Used verbatim as the hook; when ``None`` or blank,
                a default "Introducing <brand>" hook is used.
            duration: Accepted for signature parity with the LLM path; the
                template uses fixed per-segment duration estimates.
            voice_style: Stored unchanged on the result.
            seed: Seed for Python's ``random``, making the random structure
                and tagline choices repeatable.

        Returns:
            A ``GeneratedScript`` with hook, flow, benefit and CTA segments.
        """
        random.seed(seed)

        # This is the last-resort path, so it must tolerate a missing
        # structure instead of failing on None.strip().
        structure_template = (structure or "").strip() or random.choice(STRUCTURE_TEMPLATES)

        # A whitespace-only prompt would otherwise become a blank hook.
        if script_prompt and script_prompt.strip():
            hook_text = script_prompt
        else:
            hook_text = f"Introducing {brand} - the future is here!"

        tagline = random.choice(TAGLINES)

        # (segment type, text, fixed duration estimate in seconds)
        plan = [
            ("hook", hook_text, 2.0),
            ("flow", f"With {structure_template.lower()}, {brand} delivers results like never before.", 3.0),
            ("benefit", "Faster, simpler, cooler - just like your favorite retro tech.", 2.5),
            ("cta", f"Try {brand} today. {tagline}", 2.5),
        ]
        segments = [
            ScriptSegment(
                text=text,
                duration_estimate=seconds,
                segment_type=kind,
                timing_marker=f"[{kind.upper()}]",
            )
            for kind, text, seconds in plan
        ]

        total_duration = sum(segment.duration_estimate for segment in segments)
        word_count = sum(len(segment.text.split()) for segment in segments)

        return GeneratedScript(
            segments=segments,
            total_duration=total_duration,
            tagline=tagline,
            voice_style=voice_style,
            word_count=word_count,
            raw_script=f"Template-based script for {brand}"
        )
    
    def generate_script(self, brand: str, structure: str, script_prompt: str,
                       duration: int, voice_style: str, seed: int = 42) -> GeneratedScript:
        """
        Generate a complete commercial script.

        Uses the LLM when one is available and falls back to the
        template-based generator when it is not, or when LLM generation
        raises. All arguments are forwarded unchanged to the chosen generator.

        Returns:
            The ``GeneratedScript`` from whichever generator produced it.
        """
        if self.llm_available:
            # Only the LLM call is guarded. A failure in the template path is
            # not a "fallback" situation and re-running the same
            # deterministic call could not succeed.
            try:
                return self.generate_script_with_llm(brand, structure, script_prompt, duration, voice_style, seed)
            except Exception as e:
                logger.error(f"Script generation failed: {e}")
                logger.info("Falling back to template-based generation")
        else:
            logger.info("Using template-based script generation (LLM not available)")

        return self.generate_script_with_template(brand, structure, script_prompt, duration, voice_style, seed)
    
    def suggest_scripts(self, structure: str, n: int = 6, seed: int = 0) -> List[str]:
        """
        Generate multiple script suggestions based on structure.
        """
        try:
            suggestions = []
            for i in range(n):
                script = self.generate_script(
                    brand="YourBrand",
                    structure=structure,
                    script_prompt="Create an engaging hook",
                    duration=10,
                    voice_style="Announcer '80s",
                    seed=seed + i
                )
                
                # Extract hook from first segment
                if script.segments:
                    hook = script.segments[0].text
                    suggestions.append(hook)
                else:
                    suggestions.append("Back to '87 - the future is now!")
            
            return suggestions
            
        except Exception as e:
            logger.warning(f"Script suggestion failed: {e}")
            # Fallback to original random generation
            return self._fallback_suggestions(structure, n, seed)
    
    def _fallback_suggestions(self, structure: str, n: int, seed: int) -> List[str]:
        """Fallback to original random script generation."""
        random.seed(seed)
        
        base = (structure or "").lower().strip()
        ideas = []
        
        for _ in range(n):
            style = random.choice(["infomercial", "mall ad", "late-night", "newsflash", "arcade bumper"])
            shot = random.choice(["neon grid", "CRT scanlines", "vaporwave sunset", "shopping mall", "boombox close-up"])
            hook = random.choice([
                "Remember this sound?", "Back to '87.", "Deal of the decade.", 
                "We paused time.", "Be kind, rewind your brand."
            ])
            idea = f"{hook} {style} with {shot}."
            
            # Light correlation with structure
            for kw in ["montage", "testimonial", "news", "unboxing", "before", "after", "countdown", "logo", "cta"]:
                if kw in base and kw not in idea:
                    idea += f" Includes {kw}."
            
            ideas.append(idea)
        
        return ideas

def create_script_generator() -> LLMScriptGenerator:
    """Build an ``LLMScriptGenerator`` using its default model setting."""
    generator = LLMScriptGenerator()
    return generator