| import logging |
| from typing import Union, List, Optional |
|
|
| import torch |
| from PIL import Image |
|
|
# Module-level logger; generate_cinematic_prompt uses it to warn when prompt
# enhancement is skipped and the original prompts are returned.
logger = logging.getLogger(__name__)
|
|
# System prompt for text-to-video enhancement. generate_cinematic_prompt passes
# it as the "system" chat message when no conditioning items are supplied.
# The text is sent to the model verbatim, so edits here change model behavior.
T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Do not change the user input intent, just enhance it.
Keep within 150 words.
For best results, build your prompts using this structure:
Start with main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""
|
|
# System prompt for image-to-video enhancement. generate_cinematic_prompt passes
# it as the "system" chat message when a single conditioning item at frame 0 is
# supplied; the accompanying user message then carries both the user prompt and
# an image caption. The text is sent to the model verbatim, so edits here
# change model behavior.
I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Keep within 150 words.
For best results, build your prompts using this structure:
Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input.
Start with main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Align to the image caption if it contradicts the user text input.
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""
|
|
|
|
def tensor_to_pil(tensor):
    """Convert an image tensor with values in [-1, 1] to an 8-bit PIL image.

    Args:
        tensor: A 3-D ``torch.Tensor`` whose first axis is moved to the last
            position before conversion (presumably a channels-first
            ``(C, H, W)`` image — confirm against callers). Every value must
            lie in ``[-1, 1]``.

    Returns:
        A ``PIL.Image.Image`` built from the rescaled ``uint8`` array.

    Raises:
        ValueError: If any value is outside ``[-1, 1]`` (NaN included).
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``;
    # out-of-range values would otherwise wrap around in the uint8 cast below.
    # Written as ``not (... and ...)`` so NaN values are rejected as well.
    if not (tensor.min() >= -1 and tensor.max() <= 1):
        raise ValueError("tensor values must be in the range [-1, 1]")

    # ``Tensor.numpy()`` refuses tensors that require grad, so drop autograd
    # tracking first.
    image = tensor.detach()

    # NumPy has no bfloat16 dtype; widen only that case so other dtypes keep
    # their original arithmetic.
    if image.dtype == torch.bfloat16:
        image = image.float()

    # Rescale [-1, 1] -> [0, 1].
    image = (image + 1) / 2

    # Move the first axis last, the layout ``Image.fromarray`` expects.
    image = image.permute(1, 2, 0)

    # Scale to [0, 255]; the uint8 cast truncates the fractional part.
    pixels = (image.cpu().numpy() * 255).astype("uint8")

    return Image.fromarray(pixels)
|
|
|
|
def generate_cinematic_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompt: Union[str, List[str]],
    conditioning_items: Optional[List] = None,
    max_new_tokens: int = 256,
) -> List[str]:
    """Rewrite user prompts into detailed cinematic prompts.

    Without conditioning items the prompts are enhanced from text alone
    (``T2V_CINEMATIC_PROMPT``). With exactly one conditioning item whose
    ``media_frame_number`` is 0, its first frames are captioned and the
    captions are combined with the prompts (``I2V_CINEMATIC_PROMPT``).

    Args:
        image_caption_model: Captioning model; only used on the image path.
        image_caption_processor: Processor paired with the captioning model.
        prompt_enhancer_model: Language model that generates the new prompts.
        prompt_enhancer_tokenizer: Tokenizer paired with the enhancer model.
        prompt: A single prompt or a list of prompts.
        conditioning_items: Optional conditioning items. ``None`` or an empty
            list selects text-only enhancement.
        max_new_tokens: Generation budget for each enhanced prompt.

    Returns:
        The enhanced prompts, one per input prompt. If more than one
        conditioning item is given, or the single item does not start at
        frame 0, a warning is logged and the original prompts are returned.

    Raises:
        ValueError: If the number of extracted first frames differs from the
            number of prompts.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    # An empty list carries no conditioning, so treat it like ``None`` instead
    # of failing with IndexError on ``conditioning_items[0]`` below.
    if conditioning_items is None or len(conditioning_items) == 0:
        return _generate_t2v_prompt(
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            max_new_tokens,
            T2V_CINEMATIC_PROMPT,
        )

    if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0:
        logger.warning(
            "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts"
        )
        return prompts

    first_frames = _get_first_frames_from_conditioning_item(conditioning_items[0])

    # Explicit check rather than ``assert`` so a mismatch is still caught under
    # ``python -O``; otherwise ``zip`` downstream would silently drop entries.
    if len(first_frames) != len(prompts):
        raise ValueError("Number of conditioning frames must match number of prompts")

    return _generate_i2v_prompt(
        image_caption_model,
        image_caption_processor,
        prompt_enhancer_model,
        prompt_enhancer_tokenizer,
        prompts,
        first_frames,
        max_new_tokens,
        I2V_CINEMATIC_PROMPT,
    )
|
|
|
|
def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
    """Return one PIL image per leading-axis entry of the item's media tensor.

    ``conditioning_item.media_item`` is indexed as a 5-D tensor. For every
    entry along axis 0, index 0 of the third axis is selected (presumably the
    first frame of a ``(batch, channels, frames, height, width)`` layout —
    confirm against the conditioning-item definition) and converted with
    ``tensor_to_pil``.
    """
    media = conditioning_item.media_item
    # Iterating the tensor walks axis 0; each ``sample`` is one 4-D entry.
    return [tensor_to_pil(sample[:, 0, :, :]) for sample in media]
|
|
|
|
def _generate_t2v_prompt(
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    """Enhance text-only prompts with the prompt-enhancer language model.

    Each prompt becomes a two-message chat (system + user), is rendered with
    the tokenizer's chat template, tokenized as one batch, moved to the
    model's device and handed to ``_generate_and_decode_prompts``.

    Args:
        prompt_enhancer_model: Model used for generation.
        prompt_enhancer_tokenizer: Tokenizer providing the chat template.
        prompts: User prompts to enhance.
        max_new_tokens: Generation budget per prompt.
        system_prompt: Content of the system message.

    Returns:
        The decoded enhanced prompts, one per input prompt.
    """
    chat_texts = []
    for user_prompt in prompts:
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {user_prompt}"},
        ]
        rendered = prompt_enhancer_tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        chat_texts.append(rendered)

    # NOTE(review): no padding is requested here, so a batch of prompts that
    # tokenize to different lengths presumably cannot be stacked into one
    # tensor — confirm before relying on multi-prompt batches.
    encoded = prompt_enhancer_tokenizer(chat_texts, return_tensors="pt")
    model_inputs = encoded.to(prompt_enhancer_model.device)

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )
|
|
|
|
def _generate_i2v_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    first_frames: List[Image.Image],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    """Enhance prompts using both the user text and a caption of the first frame.

    The frames are captioned first; each prompt is then paired with its
    caption in a single user message, rendered with the tokenizer's chat
    template, tokenized as one batch, moved to the enhancer model's device and
    handed to ``_generate_and_decode_prompts``.

    Args:
        image_caption_model: Model that captions the frames.
        image_caption_processor: Processor paired with the captioning model.
        prompt_enhancer_model: Model used for generation.
        prompt_enhancer_tokenizer: Tokenizer providing the chat template.
        prompts: User prompts to enhance.
        first_frames: One image per prompt; pairs are formed with ``zip``, so
            surplus entries on either side are ignored.
        max_new_tokens: Generation budget per prompt.
        system_prompt: Content of the system message.

    Returns:
        The decoded enhanced prompts.
    """
    captions = _generate_image_captions(
        image_caption_model, image_caption_processor, first_frames
    )

    chat_texts = []
    for user_prompt, caption in zip(prompts, captions):
        conversation = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": f"user_prompt: {user_prompt}\nimage_caption: {caption}",
            },
        ]
        rendered = prompt_enhancer_tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        chat_texts.append(rendered)

    # NOTE(review): no padding is requested here, so a batch whose entries
    # tokenize to different lengths presumably cannot be stacked into one
    # tensor — confirm before relying on multi-prompt batches.
    encoded = prompt_enhancer_tokenizer(chat_texts, return_tensors="pt")
    model_inputs = encoded.to(prompt_enhancer_model.device)

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )
|
|
|
|
def _generate_image_captions(
    image_caption_model,
    image_caption_processor,
    images: List[Image.Image],
    system_prompt: str = "<DETAILED_CAPTION>",
) -> List[str]:
    """Caption each image with the captioning model.

    Args:
        image_caption_model: Model whose ``generate`` produces caption tokens.
        image_caption_processor: Processor that prepares the inputs and
            decodes the generated ids.
        images: Images to caption.
        system_prompt: Task prompt repeated once for every image.

    Returns:
        One decoded caption per image, with special tokens removed.
    """
    # One copy of the task prompt per image so text and images line up.
    task_prompts = [system_prompt for _ in images]

    batch = image_caption_processor(task_prompts, images, return_tensors="pt")
    batch = batch.to(image_caption_model.device)

    # Deterministic beam search (no sampling) with a fixed token budget.
    with torch.inference_mode():
        caption_ids = image_caption_model.generate(
            input_ids=batch["input_ids"],
            pixel_values=batch["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

    captions = image_caption_processor.batch_decode(
        caption_ids, skip_special_tokens=True
    )
    return captions
|
|
|
|
def _generate_and_decode_prompts(
    prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
) -> List[str]:
    """Generate continuations for tokenized prompts and decode only the new text.

    Args:
        prompt_enhancer_model: Model whose ``generate`` is called with the
            unpacked ``model_inputs``.
        prompt_enhancer_tokenizer: Tokenizer used to decode the generated ids.
        model_inputs: Tokenized batch; unpacked as keyword arguments and also
            read through its ``input_ids`` attribute.
        max_new_tokens: Generation budget per sequence.

    Returns:
        The decoded continuations, one per input sequence, with special
        tokens removed.
    """
    with torch.inference_mode():
        sequences = prompt_enhancer_model.generate(
            **model_inputs, max_new_tokens=max_new_tokens
        )

        # Drop the leading ``len(prompt_ids)`` ids from each output so only the
        # generated part is decoded.
        completions = []
        for prompt_ids, full_ids in zip(model_inputs.input_ids, sequences):
            prompt_length = len(prompt_ids)
            completions.append(full_ids[prompt_length:])

        return prompt_enhancer_tokenizer.batch_decode(
            completions, skip_special_tokens=True
        )
|
|