# atlas/models.py — Hugging Face model pipeline helpers
# (uploaded by ANISA09; renamed from model.py to models.py, commit 8479e9f)
# models.py
import logging
import os
import traceback
from io import BytesIO
from typing import Optional, List, Dict, Any

from PIL import Image
from transformers import pipeline, AutoConfig, AutoModelForImageClassification, AutoImageProcessor

import config
logger = logging.getLogger("newsorchestra.models")
# Pipelines (lazy-initialized)
_summarizer = None
_zero_shot = None
_img_caption = None
_image_classifier = None
_deepfake_detector = None
# locks not used here for brevity, but add threading.Lock if concurrent init is expected
def init_summarizer():
global _summarizer
if _summarizer is None:
try:
_summarizer = pipeline("summarization", model=config.HF_SUMMARIZER, truncation=True)
logger.info("Loaded summarizer pipeline")
except Exception as e:
logger.exception("Could not load summarizer: %s", e)
_summarizer = None
return _summarizer
def init_zero_shot():
global _zero_shot
if _zero_shot is None:
try:
_zero_shot = pipeline("zero-shot-classification", model=config.HF_ZERO_SHOT)
logger.info("Loaded zero-shot pipeline")
except Exception as e:
logger.exception("Could not load zero-shot pipeline: %s", e)
_zero_shot = None
return _zero_shot
def init_img_caption():
global _img_caption
if _img_caption is None:
try:
try:
_img_caption = pipeline("image-to-text", model=config.HF_IMAGE_CAPTION)
except Exception:
_img_caption = pipeline("image-captioning", model=config.HF_IMAGE_CAPTION)
logger.info("Loaded image caption pipeline")
except Exception as e:
logger.exception("Image caption pipeline unavailable: %s", e)
_img_caption = None
return _img_caption
def init_image_classifier():
global _image_classifier
if _image_classifier is None:
try:
_image_classifier = pipeline("image-classification", model=config.HF_IMAGE_CLASSIFIER)
logger.info("Loaded image-classification pipeline: %s", config.HF_IMAGE_CLASSIFIER)
except Exception as e:
logger.exception("Image-classifier unavailable at startup: %s", e)
_image_classifier = None
return _image_classifier
def init_deepfake_detector():
global _deepfake_detector
if _deepfake_detector is None:
try:
_deepfake_detector = pipeline("image-classification", model=config.HF_DEEPFAKE_MODEL)
logger.info("Loaded deepfake detector pipeline: %s", config.HF_DEEPFAKE_MODEL)
except Exception as e:
logger.exception("Deepfake detector pipeline not loaded at startup: %s", e)
_deepfake_detector = None
return _deepfake_detector
# Text helpers
def hf_zero_shot(claim: str) -> Dict[str, Any]:
zs = init_zero_shot()
if not zs:
return {"error": "zero-shot pipeline not available"}
if not claim or not claim.strip():
return {"sequence": "", "labels": config.CANDIDATE_LABELS, "scores": [0.0]*len(config.CANDIDATE_LABELS), "note": "No claim text; skipped"}
try:
res = zs(claim, candidate_labels=config.CANDIDATE_LABELS, multi_label=False)
return dict(res)
except Exception as e:
logger.exception("zero-shot failed: %s", e)
return {"error": str(e), "trace": traceback.format_exc()}
def hf_image_caption(img: Image.Image) -> Optional[str]:
ic = init_img_caption()
if not ic:
return None
try:
out = ic(img)
if isinstance(out, list) and out:
first = out[0]
if isinstance(first, dict):
return first.get("generated_text") or first.get("caption") or str(first)
return str(first)
return str(out)
except Exception:
logger.exception("image_captioning failed")
return None
def _ensure_pil(img_input) -> Optional[Image.Image]:
"""Accept PIL.Image, bytes, bytearray, file-like, or path. Returns PIL.Image or None."""
if img_input is None:
return None
try:
if isinstance(img_input, Image.Image):
return img_input.convert("RGB")
if isinstance(img_input, (bytes, bytearray)):
return Image.open(BytesIO(img_input)).convert("RGB")
# file-like: has read()
if hasattr(img_input, "read"):
return Image.open(img_input).convert("RGB")
# path string β€” let PIL open
if isinstance(img_input, str):
return Image.open(img_input).convert("RGB")
except Exception as e:
logger.exception("Could not convert input to PIL.Image: %s", e)
return None
def hf_image_classify(img_input, top_k: int = 3) -> List[dict]:
"""
Robust wrapper: accepts PIL.Image, bytes, file-like or path.
Returns list[{"label": str, "score": float}]
"""
results = []
try:
img = _ensure_pil(img_input)
if img is None:
logger.warning("hf_image_classify: input could not be made into PIL.Image")
return results
classifier = init_image_classifier()
if classifier is None:
logger.warning("hf_image_classify: classifier pipeline unavailable")
return results
try:
out = classifier(img, top_k=top_k)
except TypeError:
# older/newer pipeline signature differences
out = classifier(img, top_k=top_k) # try again (kept for clarity)
if isinstance(out, list):
for r in out:
if isinstance(r, dict):
label = r.get("label")
score = float(r.get("score", 0.0))
# sometimes labels come as 'LABEL_1' β€” that's fine but attempt to map if model config present
results.append({"label": str(label), "score": score})
else:
results.append({"label": str(r), "score": None})
else:
results.append({"label": str(out), "score": None})
except Exception:
logger.exception("hf_image_classify unexpected error")
return results
def hf_deepfake_check(img_input, top_k: int = 3) -> List[dict]:
results = []
try:
img = _ensure_pil(img_input)
if img is None:
logger.warning("hf_deepfake_check: input could not be made into PIL.Image")
return results
detector = init_deepfake_detector()
if detector is None:
logger.warning("hf_deepfake_check: deepfake pipeline unavailable")
return results
out = detector(img, top_k=top_k)
if isinstance(out, list):
for r in out:
if isinstance(r, dict):
results.append({"label": str(r.get("label")), "score": float(r.get("score", 0.0))})
else:
results.append({"label": str(r), "score": None})
else:
results.append({"label": str(out), "score": None})
except Exception:
logger.exception("hf_deepfake_check unexpected error")
return results