|
|
|
|
|
import logging |
|
|
import traceback |
|
|
from io import BytesIO |
|
|
from typing import Optional, List, Dict, Any |
|
|
from PIL import Image |
|
|
|
|
|
from transformers import pipeline, AutoConfig, AutoModelForImageClassification, AutoImageProcessor |
|
|
|
|
|
import config |
|
|
|
|
|
logger = logging.getLogger("newsorchestra.models") |
|
|
|
|
|
|
|
|
# Lazily-initialized HuggingFace pipeline singletons. Each starts as None and
# is populated on first use by the matching init_* helper below; a value of
# None after an init_* call means the model failed to load (the helpers log
# the failure and degrade gracefully).
_summarizer = None  # "summarization" pipeline (config.HF_SUMMARIZER)


_zero_shot = None  # "zero-shot-classification" pipeline (config.HF_ZERO_SHOT)


_img_caption = None  # image captioning pipeline (config.HF_IMAGE_CAPTION)


_image_classifier = None  # "image-classification" pipeline (config.HF_IMAGE_CLASSIFIER)


_deepfake_detector = None  # deepfake "image-classification" pipeline (config.HF_DEEPFAKE_MODEL)
|
|
|
|
|
|
|
|
|
|
|
def init_summarizer():
    """Lazily build and cache the summarization pipeline.

    Returns the cached pipeline instance, or None if loading failed
    (the failure is logged and the cache stays None so later calls retry).
    """
    global _summarizer
    if _summarizer is not None:
        return _summarizer
    try:
        _summarizer = pipeline(
            "summarization", model=config.HF_SUMMARIZER, truncation=True
        )
        logger.info("Loaded summarizer pipeline")
    except Exception as e:
        logger.exception("Could not load summarizer: %s", e)
        _summarizer = None
    return _summarizer
|
|
|
|
|
def init_zero_shot():
    """Lazily build and cache the zero-shot classification pipeline.

    Returns the cached pipeline, or None when the model could not be
    loaded (error is logged; subsequent calls will attempt again).
    """
    global _zero_shot
    if _zero_shot is not None:
        return _zero_shot
    try:
        _zero_shot = pipeline("zero-shot-classification", model=config.HF_ZERO_SHOT)
        logger.info("Loaded zero-shot pipeline")
    except Exception as e:
        logger.exception("Could not load zero-shot pipeline: %s", e)
        _zero_shot = None
    return _zero_shot
|
|
|
|
|
def init_img_caption():
    """Lazily build and cache the image-captioning pipeline.

    Tries the "image-to-text" task name first, then falls back to the
    "image-captioning" alias for older transformers versions. Returns the
    cached pipeline, or None when both attempts fail (logged).
    """
    global _img_caption
    if _img_caption is not None:
        return _img_caption
    try:
        try:
            _img_caption = pipeline("image-to-text", model=config.HF_IMAGE_CAPTION)
        except Exception:
            # Older transformers releases used a different task alias.
            _img_caption = pipeline("image-captioning", model=config.HF_IMAGE_CAPTION)
        logger.info("Loaded image caption pipeline")
    except Exception as e:
        logger.exception("Image caption pipeline unavailable: %s", e)
        _img_caption = None
    return _img_caption
|
|
|
|
|
def init_image_classifier():
    """Lazily build and cache the generic image-classification pipeline.

    Returns the cached pipeline, or None if the model failed to load
    (the failure is logged and callers are expected to degrade gracefully).
    """
    global _image_classifier
    if _image_classifier is not None:
        return _image_classifier
    try:
        _image_classifier = pipeline(
            "image-classification", model=config.HF_IMAGE_CLASSIFIER
        )
        logger.info(
            "Loaded image-classification pipeline: %s", config.HF_IMAGE_CLASSIFIER
        )
    except Exception as e:
        logger.exception("Image-classifier unavailable at startup: %s", e)
        _image_classifier = None
    return _image_classifier
|
|
|
|
|
def init_deepfake_detector():
    """Lazily build and cache the deepfake-detection pipeline.

    Implemented as an image-classification pipeline over
    config.HF_DEEPFAKE_MODEL. Returns the cached pipeline, or None when
    loading fails (logged).
    """
    global _deepfake_detector
    if _deepfake_detector is not None:
        return _deepfake_detector
    try:
        _deepfake_detector = pipeline(
            "image-classification", model=config.HF_DEEPFAKE_MODEL
        )
        logger.info(
            "Loaded deepfake detector pipeline: %s", config.HF_DEEPFAKE_MODEL
        )
    except Exception as e:
        logger.exception("Deepfake detector pipeline not loaded at startup: %s", e)
        _deepfake_detector = None
    return _deepfake_detector
|
|
|
|
|
|
|
|
def hf_zero_shot(claim: str) -> Dict[str, Any]:
    """Classify *claim* against config.CANDIDATE_LABELS via zero-shot NLI.

    Returns the pipeline result as a plain dict. When the pipeline is
    unavailable or the call raises, returns a dict containing an "error"
    key instead; a blank claim returns all-zero scores with a "note".
    """
    classifier = init_zero_shot()
    if not classifier:
        return {"error": "zero-shot pipeline not available"}
    if not (claim and claim.strip()):
        labels = config.CANDIDATE_LABELS
        return {
            "sequence": "",
            "labels": labels,
            "scores": [0.0] * len(labels),
            "note": "No claim text; skipped",
        }
    try:
        result = classifier(
            claim, candidate_labels=config.CANDIDATE_LABELS, multi_label=False
        )
        return dict(result)
    except Exception as e:
        logger.exception("zero-shot failed: %s", e)
        return {"error": str(e), "trace": traceback.format_exc()}
|
|
|
|
|
def hf_image_caption(img: Image.Image) -> Optional[str]:
    """Generate a text caption for *img* using the caption pipeline.

    Returns the caption string, or None when the pipeline is missing or
    the call fails (failure is logged).
    """
    captioner = init_img_caption()
    if not captioner:
        return None
    try:
        raw = captioner(img)
        if isinstance(raw, list) and raw:
            head = raw[0]
            if not isinstance(head, dict):
                return str(head)
            # Different pipeline versions use different result keys.
            return head.get("generated_text") or head.get("caption") or str(head)
        return str(raw)
    except Exception:
        logger.exception("image_captioning failed")
        return None
|
|
|
|
|
def _ensure_pil(img_input) -> Optional[Image.Image]:
    """Accept PIL.Image, bytes, bytearray, file-like, or path. Returns PIL.Image or None."""
    if img_input is None:
        return None
    try:
        # Already a PIL image: just normalize the mode.
        if isinstance(img_input, Image.Image):
            return img_input.convert("RGB")

        # Raw bytes become a file-like object and fall through below.
        if isinstance(img_input, (bytes, bytearray)):
            img_input = BytesIO(img_input)

        # File-like object or filesystem path string.
        if hasattr(img_input, "read") or isinstance(img_input, str):
            return Image.open(img_input).convert("RGB")
    except Exception as e:
        logger.exception("Could not convert input to PIL.Image: %s", e)
    # Unsupported type or conversion failure.
    return None
|
|
|
|
|
def hf_image_classify(img_input, top_k: int = 3) -> List[dict]:
    """
    Robust wrapper: accepts PIL.Image, bytes, file-like or path.
    Returns list[{"label": str, "score": float}]

    Returns an empty list when the input cannot be converted to an image,
    when the classifier pipeline is unavailable, or on unexpected errors
    (all of which are logged).
    """
    results: List[dict] = []
    try:
        img = _ensure_pil(img_input)
        if img is None:
            logger.warning("hf_image_classify: input could not be made into PIL.Image")
            return results

        classifier = init_image_classifier()
        if classifier is None:
            logger.warning("hf_image_classify: classifier pipeline unavailable")
            return results

        try:
            out = classifier(img, top_k=top_k)
        except TypeError:
            # BUG FIX: the fallback previously repeated the identical call
            # (including top_k), so a pipeline that rejects the top_k kwarg
            # just raised TypeError again. Retry without top_k for older
            # pipeline versions.
            out = classifier(img)

        if isinstance(out, list):
            for r in out:
                if isinstance(r, dict):
                    # Normalize heterogeneous pipeline output to plain types.
                    label = r.get("label")
                    score = float(r.get("score", 0.0))
                    results.append({"label": str(label), "score": score})
                else:
                    results.append({"label": str(r), "score": None})
        else:
            results.append({"label": str(out), "score": None})
    except Exception:
        logger.exception("hf_image_classify unexpected error")
    return results
|
|
|
|
|
def hf_deepfake_check(img_input, top_k: int = 3) -> List[dict]:
    """Run the deepfake detector on *img_input* (PIL.Image, bytes, file-like,
    or path).

    Returns a list of {"label": str, "score": float} dicts; an empty list
    when the input can't be converted, the detector is unavailable, or an
    unexpected error occurs (all logged).
    """
    hits: List[dict] = []
    try:
        pil_img = _ensure_pil(img_input)
        if pil_img is None:
            logger.warning("hf_deepfake_check: input could not be made into PIL.Image")
            return hits

        detector = init_deepfake_detector()
        if detector is None:
            logger.warning("hf_deepfake_check: deepfake pipeline unavailable")
            return hits

        raw = detector(pil_img, top_k=top_k)
        if not isinstance(raw, list):
            hits.append({"label": str(raw), "score": None})
            return hits
        for entry in raw:
            if isinstance(entry, dict):
                hits.append(
                    {
                        "label": str(entry.get("label")),
                        "score": float(entry.get("score", 0.0)),
                    }
                )
            else:
                hits.append({"label": str(entry), "score": None})
    except Exception:
        logger.exception("hf_deepfake_check unexpected error")
    return hits
|
|
|