# Source: atlas / media_utils.py
# Author: ANISA09
# Commit fccd019 ("Create media_utils.py", verified)
# media_utils.py
import logging
import re
from io import BytesIO
from typing import Optional, Tuple
from urllib.parse import urlparse
import requests
from PIL import Image, ExifTags
# Optional dependencies: the module must stay importable when OCR or Hugging
# Face tooling is missing; the helpers below then degrade to None / [].
try:
    import pytesseract
except Exception:
    pytesseract = None
try:
    from transformers import pipeline
except Exception:  # transformers missing or failing to import: disable HF features
    pipeline = None
import os

logger = logging.getLogger("media_utils")

# Model ids, overridable via environment variables.
HF_IMAGE_CAPTION = os.getenv("HF_IMAGE_CAPTION", "nlpconnect/vit-gpt2-image-captioning")
HF_IMAGE_CLASSIFIER = os.getenv("HF_IMAGE_CLASSIFIER", "google/vit-base-patch16-224")

# Load models best-effort at import time. On any failure the corresponding
# global stays None and the helpers that use it return None / [].
img_caption = None
image_classifier = None

if pipeline is None:
    logger.warning("transformers unavailable; image captioning/classification disabled")
else:
    try:
        img_caption = pipeline("image-to-text", model=HF_IMAGE_CAPTION)
        logger.info("Loaded image caption pipeline")
    except Exception as primary_err:
        # Keep the primary error visible; previously it was discarded when the
        # fallback task name also failed.
        logger.warning("image-to-text pipeline failed (%s); trying fallback task name", primary_err)
        try:
            img_caption = pipeline("image-captioning", model=HF_IMAGE_CAPTION)
            logger.info("Loaded image caption pipeline (fallback name)")
        except Exception as e:
            logger.warning("Image caption pipeline unavailable: %s", e)
            img_caption = None
    try:
        image_classifier = pipeline("image-classification", model=HF_IMAGE_CLASSIFIER)
        logger.info("Loaded image-classification pipeline")
    except Exception as e:
        logger.warning("Image classifier unavailable: %s", e)
        image_classifier = None
def fetch_image_bytes(url: str, timeout: int = 12) -> Tuple[Optional[Image.Image], Optional[bytes], Optional[str]]:
    """Download ``url`` and decode the payload as an RGB PIL image.

    Args:
        url: URL to fetch with ``requests.get`` (redirects are followed).
        timeout: Timeout passed straight to ``requests.get``.

    Returns:
        An ``(image, raw_bytes, error)`` triple:
          * ``(image, raw_bytes, None)`` on success;
          * ``(None, raw_bytes, "PIL open error: ...")`` when the download
            succeeded but PIL could not decode the bytes;
          * ``(None, None, message)`` when the URL could not be parsed or the
            request failed (connection error, timeout, or an HTTP error status
            surfaced by ``raise_for_status``).
        Failures are logged and reported through ``error``; nothing is raised.
    """
    try:
        # Parse inside the try: urlparse raises ValueError for some malformed
        # URLs (e.g. a bad IPv6 host), which used to escape this function.
        parsed = urlparse(url)
        headers = {
            "User-Agent": "Mozilla/5.0",
            # Origin-only Referer built from the target URL itself.
            "Referer": parsed.scheme + "://" + (parsed.hostname or ""),
        }
        # NOTE(review): the full body is buffered in memory with no size cap and
        # no host allow-list; add limits if ``url`` can come from untrusted input.
        r = requests.get(url, timeout=timeout, headers=headers, allow_redirects=True)
        r.raise_for_status()
        b = r.content
    except Exception as e:
        logger.error("fetch_image_bytes failed for %s: %s", url, e)
        return None, None, str(e)
    try:
        img = Image.open(BytesIO(b)).convert("RGB")
    except Exception as e:
        # Keep the raw bytes so callers can still inspect/store the payload.
        logger.warning("PIL open failed for %s: %s", url, e)
        return None, b, f"PIL open error: {e}"
    return img, b, None
def extract_exif(img: Image.Image) -> dict:
    """Return the image's EXIF metadata as ``{tag_name: value}``.

    Tag ids are mapped to names via ``PIL.ExifTags.TAGS``; unknown ids are kept
    as integers. Tags from the Exif sub-IFD are merged into the top level and
    GPS data (if any) is stored as a nested dict under the GPSInfo tag.

    Uses Pillow's public ``Image.getexif()`` instead of the private
    ``_getexif()``. The private method is only defined on some format plugins
    (e.g. JPEG), so it is missing on plain images such as those produced by
    ``convert("RGB")``. This made this function always return ``{}`` for them.

    Returns an empty dict when the image has no EXIF data or it cannot be read
    (the reason is logged at debug level).
    """
    try:
        exif = img.getexif()
    except Exception:
        logger.debug("extract_exif: could not read EXIF", exc_info=True)
        return {}
    raw = dict(exif)
    # Merge sub-IFDs when this Pillow version supports get_ifd.
    get_ifd = getattr(exif, "get_ifd", None)
    if get_ifd is not None:
        try:
            raw.update(get_ifd(0x8769))  # Exif sub-IFD (camera/exposure tags)
            gps = get_ifd(0x8825)  # GPSInfo sub-IFD
            if gps:
                raw[0x8825] = dict(gps)
        except Exception:
            logger.debug("extract_exif: could not read EXIF sub-IFDs", exc_info=True)
    return {ExifTags.TAGS.get(tag_id, tag_id): val for tag_id, val in raw.items()}
def image_ocr_text(img: Image.Image) -> Optional[str]:
    """Run Tesseract OCR on ``img`` and return the recognized text, stripped.

    Returns None when pytesseract is not installed or OCR raises. The failure
    is logged rather than silently swallowed. May return ``""`` when no text
    is recognized.
    """
    if not pytesseract:
        return None
    try:
        return pytesseract.image_to_string(img).strip()
    except Exception as e:
        # Best-effort: degrade to None, but leave a trace for debugging.
        logger.warning("image_ocr_text failed: %s", e)
        return None
def hf_image_caption(img: Image.Image) -> Optional[str]:
    """Caption ``img`` with the module-level image-to-text pipeline.

    Returns:
        The caption text (``generated_text`` or ``caption`` field of the first
        result, else that result stringified). Returns None when the pipeline
        is unavailable, inference fails (logged), or the pipeline produced no
        result. Previously an empty result yielded the literal string ``"[]"``.
    """
    if not img_caption:
        return None
    try:
        out = img_caption(img)
    except Exception:
        logger.exception("image_captioning failed")
        return None
    if isinstance(out, list):
        if not out:
            return None
        out = out[0]
    if out is None:
        return None
    if isinstance(out, dict):
        return out.get("generated_text") or out.get("caption") or str(out)
    return str(out)
def hf_image_classify(img: Image.Image, top_k: int = 3) -> list:
    """Classify ``img`` with the module-level image-classification pipeline.

    Args:
        img: PIL image to classify.
        top_k: Number of top predictions to request from the pipeline
            (default 3, the previously hard-coded value).

    Returns:
        A list of ``{"label": str, "score": float | None}`` dicts. The list is
        empty when the classifier is unavailable, ``img`` is not a PIL image,
        or inference fails (logged). If processing fails part-way, the entries
        collected so far are returned.
    """
    results = []
    if not image_classifier or not isinstance(img, Image.Image):
        logger.warning("Image classifier unavailable or invalid image")
        return results
    try:
        # NOTE(review): fixed 224x224 matches the default ViT checkpoint; confirm
        # it suits models configured via HF_IMAGE_CLASSIFIER.
        img_resized = img.resize((224, 224))
        out = image_classifier(img_resized, top_k=top_k)
        if isinstance(out, list):
            for r in out:
                if isinstance(r, dict):
                    score = r.get("score", 0)
                    results.append({
                        "label": str(r.get("label", "unknown")),
                        # A None score used to raise TypeError and abort the loop.
                        "score": float(score) if score is not None else None,
                    })
                else:
                    results.append({"label": str(r), "score": None})
    except Exception as e:
        logger.exception("image_classify failed: %s", e)
    return results
# Video helpers (optional, lightweight)
def extract_video_keyframes(video_path: str, max_frames: int = 8) -> list:
    """Sample up to ``max_frames`` evenly spaced frames from a video with OpenCV.

    Note: these are frames taken at a fixed stride through the file, not
    codec keyframes.

    Args:
        video_path: Path/URI passed to ``cv2.VideoCapture``.
        max_frames: Maximum number of frames to return; values <= 0 yield [].

    Returns:
        A list of RGB ``PIL.Image`` frames. The list is empty when OpenCV is not
        installed, the frame count is unavailable, or decoding fails before any
        frame is read. Errors are logged, and frames read so far are kept.
    """
    # Guard before any work: max_frames == 0 used to raise ZeroDivisionError
    # (and leak the capture handle).
    if max_frames <= 0:
        return []
    try:
        import cv2
    except Exception:
        logger.info("opencv not available; skipping video frame extraction")
        return []
    frames = []
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        if total <= 0:
            return frames
        step = max(1, total // max_frames)
        idx = 0
        while len(frames) < max_frames:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            idx += step
    except Exception:
        logger.exception("extract_video_keyframes failed")
    finally:
        # Always release the capture, including on errors mid-loop.
        if cap is not None:
            cap.release()
    return frames