import re

import torch
import clip
import numpy as np
from numpy.linalg import norm
from PIL import Image


def get_quality_hint_from_metadata(mos, width, height, bitrate, bitdepth, framerate, quality_hints):
    """Map raw video metadata onto the textual hints defined in `quality_hints`."""
    hint = []
    # Normalise a 0-100 MOS onto the 5-point scale before bucketing.
    if mos > 5:
        mos = (mos / 100) * 5
    if mos >= 4.5:
        hint.append(quality_hints["mos"]["excellent"])
    elif 3.5 <= mos < 4.5:
        hint.append(quality_hints["mos"]["good"])
    elif 2.5 <= mos < 3.5:
        hint.append(quality_hints["mos"]["fair"])
    elif 1.5 <= mos < 2.5:
        hint.append(quality_hints["mos"]["bad"])
    else:
        hint.append(quality_hints["mos"]["poor"])

    # Resolution buckets: below VGA, below 720p, and HD or above.
    res = width * height
    if res < 640 * 480:
        hint.append(quality_hints["resolution"]["low"])
    elif res < 1280 * 720:
        hint.append(quality_hints["resolution"]["sd"])
    else:
        hint.append(quality_hints["resolution"]["hd"])

    # Bitrate buckets in bits per second.
    if bitrate < 500_000:
        hint.append(quality_hints["bitrate"]["low"])
    elif bitrate < 1_000_000:
        hint.append(quality_hints["bitrate"]["medium"])
    else:
        hint.append(quality_hints["bitrate"]["high"])

    # A bit depth of 0 (e.g. unreported) falls back to the standard hint.
    if 0 < bitdepth <= 8:
        hint.append(quality_hints["bitdepth"]["low"])
    elif bitdepth == 0:
        hint.append(quality_hints["bitdepth"]["standard"])
    else:
        hint.append(quality_hints["bitdepth"]["high"])

    if framerate < 24:
        hint.append(quality_hints["framerate"]["low"])
    elif framerate > 60:
        hint.append(quality_hints["framerate"]["high"])
    else:
        hint.append(quality_hints["framerate"]["standard"])
    return " ".join(hint)
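# Example (illustrative only, not taken from the original project): `quality_hints`
# is expected to provide a phrase for every branch used above. A minimal sketch:
#
#   quality_hints = {
#       "mos": {"excellent": "excellent quality", "good": "good quality",
#               "fair": "fair quality", "bad": "bad quality", "poor": "poor quality"},
#       "resolution": {"low": "low resolution", "sd": "standard definition", "hd": "high definition"},
#       "bitrate": {"low": "low bitrate", "medium": "medium bitrate", "high": "high bitrate"},
#       "bitdepth": {"low": "8-bit color", "standard": "standard bit depth", "high": "10-bit or higher color"},
#       "framerate": {"low": "low framerate", "standard": "standard framerate", "high": "high framerate"},
#   }
#   hint = get_quality_hint_from_metadata(72, 1920, 1080, 2_500_000, 8, 30, quality_hints)
#   # -> "good quality high definition high bitrate 8-bit color standard framerate"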
def generate_caption(blip_processor, blip_model, device, image, prompt):
    # Conditional captioning: the prompt is passed to BLIP as the text prefix.
    inputs = blip_processor(image, prompt, return_tensors="pt").to(device)
    generated_ids = blip_model.generate(**inputs, max_new_tokens=50)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
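# Usage sketch (assumes a Hugging Face BLIP checkpoint; the exact model and prompt
# used by this project may differ):
#
#   from transformers import BlipProcessor, BlipForConditionalGeneration
#   blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
#   blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
#   caption = generate_caption(blip_processor, blip_model, device, pil_frame, "describe the visual quality of this frame:")
#   # where `pil_frame` is a PIL.Image and `device` is a torch device / device string.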
def tensor_to_pil(image_tensor):
    # Accept either a torch tensor or an array-like image and return a PIL image.
    if isinstance(image_tensor, torch.Tensor):
        arr = image_tensor.cpu().numpy()
    else:
        arr = np.asarray(image_tensor)
    # Drop a singleton batch dimension if one is present.
    if arr.ndim == 4 and arr.shape[0] == 1:
        arr = arr[0]
    arr = arr.astype('uint8')
    return Image.fromarray(arr)
def extract_semantic_captions(blip_processor, blip_model, curr_frame, frag_residual, frag_frame, prompts, device, metadata=None, use_metadata_prompt=False):
    quality_prompt_base = prompts["quality_prompt_base"]
    residual_prompt = prompts["residual_prompt"]
    frag_prompt = prompts["frag_prompt"]

    # Optionally prepend a metadata-derived quality hint to the BLIP prompts.
    quality_hint = ""
    if use_metadata_prompt and metadata:
        mos, width, height, bitrate, bitdepth, framerate = metadata
        quality_hint = get_quality_hint_from_metadata(mos, width, height, bitrate, bitdepth, framerate, quality_hints=prompts["quality_hints"])

    prompt_hints = []
    if quality_hint:
        prompt_hints.append(quality_hint)

    quality_prompt = "\n\n".join(prompt_hints + [quality_prompt_base])
    fragment_prompt = "\n\n".join(prompt_hints)

    captions = {
        "curr_frame_quality": generate_caption(blip_processor, blip_model, device, curr_frame, prompt=quality_prompt),
        "frag_residual": generate_caption(blip_processor, blip_model, device, frag_residual, prompt=(fragment_prompt + "\n\n" + residual_prompt)),
        "frag_frame": generate_caption(blip_processor, blip_model, device, frag_frame, prompt=(fragment_prompt + "\n\n" + frag_prompt)),
    }
    return captions
def clean_caption_text(text):
    # Strip stock-footage boilerplate phrases the captioner can emit, then collapse whitespace.
    text = re.sub(r"- .*?stock videos & royalty-free footage", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()
def dedup_keywords(text, split_tokens=(",", ".", ";")):
    # Normalise all separators to commas, then keep the first occurrence of each phrase.
    for token in split_tokens:
        text = text.replace(token, ",")
    parts = [p.strip().lower() for p in text.split(",") if p.strip()]
    seen = set()
    unique_parts = []
    for part in parts:
        if part not in seen:
            unique_parts.append(part)
            seen.add(part)
    return " ".join(unique_parts)
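# For example, dedup_keywords("blurry, noisy. blurry; dark") returns "blurry noisy dark":
# separators are normalised to commas, phrases are lower-cased, and repeats are dropped.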
def get_clip_text_embedding(clip_model, device, text):
    # Truncate to CLIP's 77-token context so long combined captions do not raise an error.
    text_tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda'):
            text_features = clip_model.encode_text(text_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.squeeze()
def get_clip_image_embedding(clip_model, clip_preprocess, device, image):
    image_input = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        with torch.amp.autocast(device_type='cuda'):
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features.squeeze()
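# Because both text and image features are L2-normalised above, a plain dot product
# between the returned vectors is their cosine similarity. Illustrative sketch
# (`frame` is any PIL image):
#
#   img = get_clip_image_embedding(clip_model, clip_preprocess, device, frame)
#   txt = get_clip_text_embedding(clip_model, device, "a blurry, blocky video frame")
#   similarity = (img @ txt).item()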
def extract_semantic_embeddings(clip_model, clip_preprocess, device, curr_frame, captions):
    if not isinstance(curr_frame, Image.Image):
        curr_frame = Image.fromarray(curr_frame)

    quality_caption = dedup_keywords(clean_caption_text(captions["curr_frame_quality"]))
    artifact_caption_1 = dedup_keywords(clean_caption_text(captions["frag_residual"]))
    artifact_caption_2 = dedup_keywords(clean_caption_text(captions["frag_frame"]))
    artifact_caption = dedup_keywords(f"{artifact_caption_1}, {artifact_caption_2}")

    image_embed = get_clip_image_embedding(clip_model, clip_preprocess, device, curr_frame)
    quality_embed = get_clip_text_embedding(clip_model, device, quality_caption)
    artifact_embed = get_clip_text_embedding(clip_model, device, artifact_caption)
    return image_embed, quality_embed, artifact_embed
def extract_features_clip_embed(frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device):
    feature_image_embed = []
    feature_quality_embed = []
    feature_artifact_embed = []
    for curr_frame, frag_residual, frag_frame in frames_info:
        curr_frame = tensor_to_pil(curr_frame)
        frag_residual = tensor_to_pil(frag_residual)
        frag_frame = tensor_to_pil(frag_frame)

        captions = extract_semantic_captions(
            blip_processor, blip_model,
            curr_frame, frag_residual, frag_frame, prompts,
            device,
            metadata=metadata,
            use_metadata_prompt=True,
        )
        image_embed, quality_embed, artifact_embed = extract_semantic_embeddings(clip_model, clip_preprocess, device, curr_frame, captions)
        feature_image_embed.append(image_embed)
        feature_quality_embed.append(quality_embed)
        feature_artifact_embed.append(artifact_embed)

    # Stack the per-frame embeddings into (num_frames, embed_dim) tensors.
    image_embedding = torch.stack(feature_image_embed, dim=0)
    quality_embedding = torch.stack(feature_quality_embed, dim=0)
    artifact_embedding = torch.stack(feature_artifact_embed, dim=0)
    return image_embedding, quality_embedding, artifact_embedding
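# End-to-end usage sketch (hedged: prompt texts, the CLIP checkpoint, and the frame
# source are placeholders, not necessarily what the original pipeline uses):
#
#   clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
#   prompts = {
#       "quality_prompt_base": "describe the overall visual quality of this frame:",
#       "residual_prompt": "describe the compression artifacts visible in this residual map:",
#       "frag_prompt": "describe the distortions visible in these fragments:",
#       "quality_hints": quality_hints,  # see the sketch after get_quality_hint_from_metadata
#   }
#   # frames_info: iterable of (current frame, fragment residual, fragmented frame) tensors;
#   # metadata: (mos, width, height, bitrate, bitdepth, framerate) for the whole video.
#   image_emb, quality_emb, artifact_emb = extract_features_clip_embed(
#       frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device)
#   # each output has shape (num_frames, clip_embedding_dim)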