# Hugging Face Space by daniihc16 — uploaded via huggingface_hub (commit c653359, verified)
import io
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import torch
from datasets import DatasetDict, load_dataset
from PIL import Image, ImageDraw
from transformers import (
AutoModel,
AutoProcessor,
BlipForConditionalGeneration,
BlipForImageTextRetrieval,
BlipForQuestionAnswering,
BlipProcessor,
)
# Run on GPU when available; all models and tensors are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# The three document classes this demo distinguishes.
LABELS = ["email", "resume", "scientific paper"]
# Model checkpoints pulled from the Hugging Face Hub.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
BLIP_ITM_CHECKPOINT = "Salesforce/blip-itm-base-coco"
BLIP_VQA_CHECKPOINT = "Salesforce/blip-vqa-base"
BLIP_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"
# Prompt ensemble for CLIP: per-prompt probabilities are averaged per class.
CLIP_PROMPT_BANK = {
    "email": [
        "a scanned email with from to subject headers",
        "a screenshot of an email message in an inbox",
        "an email document with sender recipient and subject lines",
        "a digital mail message with formal header fields",
        "an office email printout",
    ],
    "resume": [
        "a professional resume document",
        "a curriculum vitae with work experience and education",
        "a CV page listing skills and profile",
        "a job application resume in document format",
        "a candidate resume with sections for experience and education",
    ],
    "scientific paper": [
        "a scientific paper with abstract and references",
        "a research article with sections and citations",
        "an academic publication page with dense text",
        "a journal paper with methodology and results",
        "a scholarly scientific document",
    ],
}
# Smaller prompt ensemble used for BLIP image-text-matching scoring.
BLIP_ITM_PROMPT_BANK = {
    "email": [
        "a scanned email document",
        "an email message page with sender recipient and subject",
        "a digital email printout",
    ],
    "resume": [
        "a professional resume document",
        "a curriculum vitae with skills and work experience",
        "a CV page for a job application",
    ],
    "scientific paper": [
        "a scientific paper page",
        "an academic research article with abstract and references",
        "a scholarly journal publication",
    ],
}
# Keywords used to map free-form BLIP VQA/caption text back onto a label
# (substring match after normalization; one hit per matching term).
BLIP_SYNONYMS = {
    "email": [
        "email",
        "e mail",
        "mail",
        "inbox",
        "subject",
        "sender",
        "recipient",
        "message",
    ],
    "resume": [
        "resume",
        "curriculum vitae",
        "curriculum",
        "cv",
        "work experience",
        "skills",
        "education",
        "candidate",
    ],
    "scientific paper": [
        "scientific paper",
        "research paper",
        "journal article",
        "academic paper",
        "scientific publication",
        "abstract",
        "references",
        "methodology",
        "introduction",
    ],
}
# Three phrasings of the same question; answers are majority-voted.
BLIP_VQA_PROMPTS = [
    "Question: Is this document an email, a resume, or a scientific paper? Answer with one option only. Answer:",
    "Question: Choose one label for this document: email, resume, scientific paper. Answer:",
    "Question: Document type? Options: email, resume, scientific paper. Answer with one label:",
]
# Small RVL-CDIP sample used to download a few real example images.
DATASET_ID = "nielsr/rvl_cdip_10_examples_per_class"
PREFERRED_SPLITS = ["test", "validation", "train"]
# Local cache directory for example images shown in the Gradio UI.
EXAMPLE_DIR = Path(__file__).parent / "examples"
VALID_EXT = {".png", ".jpg", ".jpeg", ".bmp", ".webp", ".tif", ".tiff"}
# Stop scanning the dataset after this many rows.
MAX_SCAN = 300
MAX_EXAMPLES_PER_LABEL = 2
# Normalized RVL-CDIP label names that map onto each of our three classes.
TARGET_ALIASES = {
    "email": {"email"},
    "resume": {"resume", "cv"},
    "scientific paper": {"scientificpublication", "scientificreport", "scientificpaper"},
}
def normalize_label(text: str) -> str:
    """Lowercase *text* and drop every non-alphanumeric character."""
    lowered = str(text).lower()
    kept = [ch for ch in lowered if ch.isalnum()]
    return "".join(kept)
def map_to_target_label(raw_value, label_names=None) -> Optional[str]:
    """Map a raw dataset label (int id or string) onto one of our classes.

    When *label_names* is provided, an integer id is first translated through
    it; any failure falls back to using the raw value verbatim. Returns None
    when the normalized label is not an alias of any target class.
    """
    resolved = raw_value
    if label_names is not None:
        try:
            resolved = label_names[int(raw_value)]
        except Exception:
            resolved = raw_value
    key = normalize_label(resolved)
    for target, alias_set in TARGET_ALIASES.items():
        if key in alias_set:
            return target
    return None
def ensure_placeholder_examples() -> List[str]:
    """Create one synthetic placeholder image per label; return their paths.

    Existing files are kept as-is, so this is cheap on subsequent launches.
    """
    EXAMPLE_DIR.mkdir(parents=True, exist_ok=True)
    paths: List[str] = []
    for idx, label in enumerate(LABELS):
        target = EXAMPLE_DIR / f"placeholder_{idx}_{label.replace(' ', '_')}.png"
        if not target.exists():
            # Light-grey page with a border and two lines of explanatory text.
            canvas = Image.new("RGB", (900, 1200), color=(245, 245, 245))
            pen = ImageDraw.Draw(canvas)
            pen.rectangle([(40, 40), (860, 1160)], outline=(90, 90, 90), width=4)
            pen.text((80, 120), f"Example placeholder: {label}", fill=(20, 20, 20))
            pen.text((80, 220), "Upload a real document image for best results.", fill=(40, 40, 40))
            canvas.save(target)
        paths.append(str(target))
    return paths
def ensure_real_examples() -> List[str]:
    """Return paths to example images, downloading a few real ones if needed.

    Tries to populate EXAMPLE_DIR with up to MAX_EXAMPLES_PER_LABEL real
    document images per class from the RVL-CDIP sample dataset. Any failure
    (network, schema, decoding) falls back to synthetic placeholders.
    """
    EXAMPLE_DIR.mkdir(parents=True, exist_ok=True)
    # Reuse whatever is already cached on disk when there is enough of it.
    existing = sorted([p for p in EXAMPLE_DIR.iterdir() if p.suffix.lower() in VALID_EXT])
    if len(existing) >= 3:
        return [str(p) for p in existing]
    counts = {label: 0 for label in LABELS}
    try:
        ds_obj = load_dataset(DATASET_ID)
        # Pick a split by preference order; otherwise take the first one.
        if isinstance(ds_obj, DatasetDict):
            split_name = next((s for s in PREFERRED_SPLITS if s in ds_obj), None)
            if split_name is None:
                split_name = next(iter(ds_obj.keys()))
            selected_ds = ds_obj[split_name]
        else:
            selected_ds = ds_obj
        # Fixed seed so the example set is stable across launches.
        if hasattr(selected_ds, "shuffle"):
            selected_ds = selected_ds.shuffle(seed=42)
        # label_names (int id -> string) when the dataset uses ClassLabel.
        label_feature = selected_ds.features.get("label") if hasattr(selected_ds, "features") else None
        label_names = getattr(label_feature, "names", None)
        for i, row in enumerate(selected_ds):
            if i >= MAX_SCAN:
                break
            if all(v >= MAX_EXAMPLES_PER_LABEL for v in counts.values()):
                break
            # Map the row's label onto one of our classes, trying the
            # canonical "label" column first, then common alternatives.
            target_label = None
            if "label" in row:
                target_label = map_to_target_label(row["label"], label_names)
            if target_label is None:
                for field in ["ground_truth", "category", "label_name", "class"]:
                    if field in row:
                        target_label = map_to_target_label(row[field], None)
                        if target_label is not None:
                            break
            if target_label is None or counts[target_label] >= MAX_EXAMPLES_PER_LABEL:
                continue
            # The image column may be a PIL image, a dict of bytes/path,
            # or a plain path string depending on how datasets decodes it.
            img_obj = row.get("image")
            img = None
            try:
                if isinstance(img_obj, Image.Image):
                    img = img_obj.convert("RGB")
                elif isinstance(img_obj, dict) and img_obj.get("bytes") is not None:
                    img = Image.open(io.BytesIO(img_obj["bytes"])).convert("RGB")
                elif isinstance(img_obj, dict) and img_obj.get("path") is not None:
                    img = Image.open(img_obj["path"]).convert("RGB")
                elif isinstance(img_obj, str):
                    img = Image.open(img_obj).convert("RGB")
            except Exception:
                img = None
            if img is None:
                continue
            out_name = f"example_{target_label.replace(' ', '_')}_{counts[target_label]:02d}.png"
            out_path = EXAMPLE_DIR / out_name
            img.save(out_path)
            counts[target_label] += 1
        downloaded = sorted([p for p in EXAMPLE_DIR.iterdir() if p.suffix.lower() in VALID_EXT])
        if len(downloaded) >= 3:
            return [str(p) for p in downloaded]
    except Exception:
        # Deliberate best-effort: any dataset problem degrades to placeholders.
        pass
    return ensure_placeholder_examples()
# ---- Model loading (runs once at import time) -------------------------------
print(f"Loading CLIP on {DEVICE}...")
clip_processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT, use_fast=False)
clip_model = AutoModel.from_pretrained(CLIP_CHECKPOINT).to(DEVICE)
clip_model.eval()
print("Loading BLIP...")
# BLIP is loaded with a three-stage fallback chain: ITM -> VQA -> captioning.
# blip_mode records which head succeeded and drives classify_blip's dispatch.
blip_mode = "none"
blip_processor = None
blip_model = None
blip_load_message = "BLIP not available"
try:
    blip_processor = BlipProcessor.from_pretrained(BLIP_ITM_CHECKPOINT)
    blip_model = BlipForImageTextRetrieval.from_pretrained(BLIP_ITM_CHECKPOINT).to(DEVICE)
    blip_model.eval()
    blip_mode = "blip-itm"
    blip_load_message = f"BLIP ITM loaded: {BLIP_ITM_CHECKPOINT}"
except Exception as e_itm:
    try:
        # First fallback: visual question answering head.
        blip_processor = BlipProcessor.from_pretrained(BLIP_VQA_CHECKPOINT)
        blip_model = BlipForQuestionAnswering.from_pretrained(BLIP_VQA_CHECKPOINT).to(DEVICE)
        blip_model.eval()
        blip_mode = "blip-vqa"
        blip_load_message = (
            f"BLIP ITM unavailable ({e_itm}). Fallback VQA loaded: {BLIP_VQA_CHECKPOINT}"
        )
    except Exception as e_vqa:
        try:
            # Second fallback: image captioning + keyword mapping.
            blip_processor = BlipProcessor.from_pretrained(BLIP_CAPTION_CHECKPOINT)
            blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CAPTION_CHECKPOINT).to(DEVICE)
            blip_model.eval()
            blip_mode = "blip-caption"
            blip_load_message = (
                f"BLIP ITM/VQA unavailable ({e_itm} | {e_vqa}). Fallback caption loaded: {BLIP_CAPTION_CHECKPOINT}"
            )
        except Exception as e_caption:
            # Final fallback: CLIP-only mode; the UI reports this status.
            blip_mode = "none"
            blip_load_message = (
                f"BLIP unavailable ({e_itm} | {e_vqa} | {e_caption}). Running only CLIP."
            )
print(blip_load_message)
def normalize_text_for_match(text: str) -> str:
    """Lowercase, replace non-alphanumerics with spaces, collapse whitespace."""
    replaced = [ch if ch.isalnum() else " " for ch in str(text).lower()]
    return " ".join("".join(replaced).split())
def map_text_to_blip_label(text: str) -> str:
    """Map free-form BLIP output onto one of LABELS via keyword hits.

    Counts how many synonyms of each label occur as substrings of the
    normalized text; returns "unknown" when nothing matches.
    """
    haystack = normalize_text_for_match(text)
    if not haystack:
        return "unknown"
    hits = {label: 0 for label in LABELS}
    for label, synonyms in BLIP_SYNONYMS.items():
        for synonym in synonyms:
            needle = normalize_text_for_match(synonym)
            if needle and needle in haystack:
                hits[label] += 1
    winner = max(LABELS, key=lambda candidate: hits[candidate])
    if hits[winner] > 0:
        return winner
    return "unknown"
def classify_clip(image: Image.Image) -> Tuple[str, float, float, Dict[str, float]]:
    """Zero-shot classify *image* with CLIP over the prompt ensemble.

    Per-prompt softmax probabilities are averaged within each class, then
    renormalized. Returns (label, confidence, top1-top2 margin, per-class probs).
    """
    prompts: List[str] = []
    owners: List[int] = []
    for class_idx, name in enumerate(LABELS):
        bank = CLIP_PROMPT_BANK[name]
        prompts.extend(bank)
        owners.extend([class_idx] * len(bank))
    batch = clip_processor(
        text=prompts,
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = clip_model(**batch)
    if not hasattr(outputs, "logits_per_image"):
        raise RuntimeError("CLIP output does not contain logits_per_image.")
    per_prompt = outputs.logits_per_image.softmax(dim=1)[0].detach().cpu()
    per_class = torch.zeros(len(LABELS), dtype=per_prompt.dtype)
    for class_idx in range(len(LABELS)):
        members = [i for i, owner in enumerate(owners) if owner == class_idx]
        per_class[class_idx] = per_prompt[members].mean()
    per_class = per_class / per_class.sum()
    top_vals, top_idx = torch.topk(per_class, k=2)
    best = int(top_idx[0].item())
    pred = LABELS[best]
    confidence = float(top_vals[0].item())
    margin = float((top_vals[0] - top_vals[1]).item())
    probs = {LABELS[i]: float(per_class[i].item()) for i in range(len(LABELS))}
    return pred, confidence, margin, probs
def classify_blip_with_itm(image: Image.Image) -> Tuple[str, float, str, Dict[str, float]]:
    """Score *image* against every ITM prompt and average match scores per class.

    Returns (label, confidence, per-class evidence string, per-class probs).
    """
    prompts: List[str] = []
    owners: List[int] = []
    for class_idx, name in enumerate(LABELS):
        bank = BLIP_ITM_PROMPT_BANK[name]
        prompts.extend(bank)
        owners.extend([class_idx] * len(bank))
    batch = blip_processor(
        images=[image] * len(prompts),
        text=prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = blip_model(**batch, use_itm_head=True)
    # "Match" probability is column 1 of the two-way ITM head; output attribute
    # name differs across transformers versions, so probe both shapes.
    if hasattr(outputs, "itm_score"):
        match = torch.softmax(outputs.itm_score, dim=1)[:, 1].detach().cpu()
    elif hasattr(outputs, "logits") and outputs.logits.ndim == 2 and outputs.logits.shape[1] == 2:
        match = torch.softmax(outputs.logits, dim=1)[:, 1].detach().cpu()
    else:
        raise RuntimeError("Unexpected BLIP ITM output format.")
    per_class = torch.zeros(len(LABELS), dtype=match.dtype)
    for class_idx in range(len(LABELS)):
        members = [i for i, owner in enumerate(owners) if owner == class_idx]
        per_class[class_idx] = match[members].mean()
    # Guard against a degenerate all-zero score vector before normalizing.
    if float(per_class.sum()) <= 0:
        label_probs = torch.ones_like(per_class) / len(LABELS)
    else:
        label_probs = per_class / per_class.sum()
    best = int(torch.argmax(label_probs).item())
    pred = LABELS[best]
    conf = float(label_probs[best].item())
    evidence = " | ".join([f"{LABELS[i]}={label_probs[i].item():.3f}" for i in range(len(LABELS))])
    probs = {LABELS[i]: float(label_probs[i].item()) for i in range(len(LABELS))}
    return pred, conf, evidence, probs
def classify_blip_with_vqa(image: Image.Image) -> Tuple[str, float, str, Dict[str, float]]:
    """Ask each VQA prompt, map the answers to labels, and majority-vote.

    Confidence is the winning label's vote share over all prompts. Returns
    ("unknown", 0.0, ...) when no answer maps to a known label.
    """
    raw_answers: List[str] = []
    mapped: List[str] = []
    for question in BLIP_VQA_PROMPTS:
        batch = blip_processor(images=image, text=question, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            out_ids = blip_model.generate(**batch, max_new_tokens=8, num_beams=4)
        answer = blip_processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip().lower()
        raw_answers.append(answer)
        mapped.append(map_text_to_blip_label(answer))
    usable = [label for label in mapped if label != "unknown"]
    if not usable:
        return "unknown", 0.0, " | ".join(raw_answers), {label: 0.0 for label in LABELS}
    tally = {label: usable.count(label) for label in LABELS}
    pred = max(tally, key=tally.get)
    confidence = tally[pred] / len(BLIP_VQA_PROMPTS)
    probs = {label: tally[label] / len(BLIP_VQA_PROMPTS) for label in LABELS}
    return pred, float(confidence), " | ".join(raw_answers), probs
def classify_blip_with_caption(image: Image.Image) -> Tuple[str, float, str, Dict[str, float]]:
    """Caption *image* and keyword-map the caption to one of LABELS.

    Uses a fixed 0.50 confidence when a label is found (captioning gives no
    calibrated score); 0.0 and all-zero probs otherwise.
    """
    batch = blip_processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        out_ids = blip_model.generate(**batch, max_new_tokens=40, num_beams=4)
    caption = blip_processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip().lower()
    pred = map_text_to_blip_label(caption)
    confidence = 0.0 if pred == "unknown" else 0.50
    probs = {label: 0.0 for label in LABELS}
    if pred in probs:
        probs[pred] = confidence
    return pred, float(confidence), caption, probs
def classify_blip(image: Image.Image) -> Tuple[str, float, str, Dict[str, float], str]:
    """Dispatch to whichever BLIP head was loaded at startup (see blip_mode).

    Returns (label, confidence, evidence text, per-class probs, active mode).
    """
    if blip_model is None or blip_processor is None:
        return "unavailable", 0.0, "BLIP model not loaded", {label: 0.0 for label in LABELS}, "none"
    dispatch = {
        "blip-itm": classify_blip_with_itm,
        "blip-vqa": classify_blip_with_vqa,
    }
    # Any other loaded mode falls through to the captioning strategy.
    handler = dispatch.get(blip_mode, classify_blip_with_caption)
    pred, conf, evidence, probs = handler(image)
    return pred, conf, evidence, probs, blip_mode
def build_scores_plot(clip_probs: Dict[str, float], blip_probs: Dict[str, float]):
    """Render a grouped bar chart comparing per-class CLIP and BLIP scores."""
    clip_series = [clip_probs.get(name, 0.0) for name in LABELS]
    blip_series = [blip_probs.get(name, 0.0) for name in LABELS]
    positions = list(range(len(LABELS)))
    bar_width = 0.36
    fig, ax = plt.subplots(figsize=(8, 4.5))
    # Offset the two series half a bar left/right of each class tick.
    ax.bar([p - bar_width / 2 for p in positions], clip_series, width=bar_width, label="CLIP", color="#1f77b4")
    ax.bar([p + bar_width / 2 for p in positions], blip_series, width=bar_width, label="BLIP", color="#ff7f0e")
    ax.set_xticks(positions)
    ax.set_xticklabels(LABELS)
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel("Score")
    ax.set_title("Score comparison by class")
    ax.grid(axis="y", linestyle="--", alpha=0.3)
    ax.legend()
    fig.tight_layout()
    return fig
def run_inference(image: Optional[Image.Image]):
    """Run CLIP and BLIP on *image* and assemble the four Gradio outputs.

    Returns (markdown summary, scores DataFrame, matplotlib figure, BLIP
    evidence text). Missing input or a CLIP failure short-circuits with a
    message and an empty table.
    """
    blank = pd.DataFrame(columns=["class", "clip_score", "blip_score"])
    if image is None:
        return (
            "Please upload an image.",
            blank,
            None,
            "No BLIP evidence yet.",
        )
    image = image.convert("RGB")
    try:
        clip_pred, clip_conf, clip_margin, clip_probs = classify_clip(image)
    except Exception as e_clip:
        return (
            f"CLIP inference failed: {e_clip}",
            blank,
            None,
            "BLIP skipped because CLIP failed.",
        )
    blip_pred, blip_conf, blip_evidence, blip_probs, active_blip_mode = classify_blip(image)
    results_df = pd.DataFrame(
        {
            "class": LABELS,
            "clip_score": [clip_probs[name] for name in LABELS],
            "blip_score": [blip_probs.get(name, 0.0) for name in LABELS],
        }
    )
    fig = build_scores_plot(clip_probs, blip_probs)
    summary = (
        f"### Prediction summary\n"
        f"- **Device:** {DEVICE}\n"
        f"- **CLIP:** {clip_pred} (confidence={clip_conf:.3f}, margin={clip_margin:.3f})\n"
        f"- **BLIP mode:** {active_blip_mode}\n"
        f"- **BLIP:** {blip_pred} (confidence={blip_conf:.3f})"
    )
    return summary, results_df, fig, blip_evidence
# ---- Gradio UI --------------------------------------------------------------
# Example thumbnails: real dataset images when available, placeholders otherwise.
examples = [[path] for path in ensure_real_examples()]
DESCRIPTION = (
    "Upload a document image to compare zero-shot predictions from CLIP and BLIP. "
    "The app uses the same document classes as your notebook: email, resume, scientific paper. "
    "First launch can take a few minutes while models are downloaded."
)
with gr.Blocks(title="Zero-Shot Document Classification with CLIP and BLIP") as demo:
    gr.Markdown("# Zero-Shot Document Classification (CLIP + BLIP)")
    gr.Markdown(DESCRIPTION)
    # Surface which BLIP head (if any) actually loaded at startup.
    gr.Markdown(f"**BLIP status:** {blip_load_message}")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Input document image")
    with gr.Row():
        run_btn = gr.Button("Run inference", variant="primary")
    # Four outputs populated by run_inference in this order.
    summary_output = gr.Markdown(label="Summary")
    table_output = gr.Dataframe(label="Scores by class", interactive=False)
    plot_output = gr.Plot(label="Score visualization")
    evidence_output = gr.Textbox(label="BLIP evidence", lines=3)
    run_btn.click(
        fn=run_inference,
        inputs=[image_input],
        outputs=[summary_output, table_output, plot_output, evidence_output],
    )
    gr.Examples(
        examples=examples,
        inputs=[image_input],
        outputs=[summary_output, table_output, plot_output, evidence_output],
        fn=run_inference,
        cache_examples=False,
    )
if __name__ == "__main__":
    # PORT is set by some hosting platforms; default to Gradio's usual 7860.
    server_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=server_port)