# Hugging Face Space: zero-shot document classification demo (CLIP + BLIP).
| import io | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import torch | |
| from datasets import DatasetDict, load_dataset | |
| from PIL import Image, ImageDraw | |
| from transformers import ( | |
| AutoModel, | |
| AutoProcessor, | |
| BlipForConditionalGeneration, | |
| BlipForImageTextRetrieval, | |
| BlipForQuestionAnswering, | |
| BlipProcessor, | |
| ) | |
# Inference device: prefer GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Target document classes for zero-shot classification.
LABELS = ["email", "resume", "scientific paper"]

# Hugging Face model checkpoints. BLIP is tried in order ITM -> VQA -> caption.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
BLIP_ITM_CHECKPOINT = "Salesforce/blip-itm-base-coco"
BLIP_VQA_CHECKPOINT = "Salesforce/blip-vqa-base"
BLIP_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"

# Prompt ensemble for CLIP: several paraphrases per class; per-prompt
# probabilities are averaged within each class before the final argmax.
CLIP_PROMPT_BANK = {
    "email": [
        "a scanned email with from to subject headers",
        "a screenshot of an email message in an inbox",
        "an email document with sender recipient and subject lines",
        "a digital mail message with formal header fields",
        "an office email printout",
    ],
    "resume": [
        "a professional resume document",
        "a curriculum vitae with work experience and education",
        "a CV page listing skills and profile",
        "a job application resume in document format",
        "a candidate resume with sections for experience and education",
    ],
    "scientific paper": [
        "a scientific paper with abstract and references",
        "a research article with sections and citations",
        "an academic publication page with dense text",
        "a journal paper with methodology and results",
        "a scholarly scientific document",
    ],
}

# Smaller prompt ensemble used with the BLIP image-text-matching head.
BLIP_ITM_PROMPT_BANK = {
    "email": [
        "a scanned email document",
        "an email message page with sender recipient and subject",
        "a digital email printout",
    ],
    "resume": [
        "a professional resume document",
        "a curriculum vitae with skills and work experience",
        "a CV page for a job application",
    ],
    "scientific paper": [
        "a scientific paper page",
        "an academic research article with abstract and references",
        "a scholarly journal publication",
    ],
}

# Keyword bank used to map free-form BLIP text (VQA answers / captions) back
# onto one of LABELS via substring matching on normalized text.
BLIP_SYNONYMS = {
    "email": [
        "email",
        "e mail",
        "mail",
        "inbox",
        "subject",
        "sender",
        "recipient",
        "message",
    ],
    "resume": [
        "resume",
        "curriculum vitae",
        "curriculum",
        "cv",
        "work experience",
        "skills",
        "education",
        "candidate",
    ],
    "scientific paper": [
        "scientific paper",
        "research paper",
        "journal article",
        "academic paper",
        "scientific publication",
        "abstract",
        "references",
        "methodology",
        "introduction",
    ],
}

# Three question phrasings for the BLIP-VQA fallback; answers are majority-voted.
BLIP_VQA_PROMPTS = [
    "Question: Is this document an email, a resume, or a scientific paper? Answer with one option only. Answer:",
    "Question: Choose one label for this document: email, resume, scientific paper. Answer:",
    "Question: Document type? Options: email, resume, scientific paper. Answer with one label:",
]

# Small RVL-CDIP subset used to populate the Gradio examples gallery.
DATASET_ID = "nielsr/rvl_cdip_10_examples_per_class"
PREFERRED_SPLITS = ["test", "validation", "train"]
EXAMPLE_DIR = Path(__file__).parent / "examples"
VALID_EXT = {".png", ".jpg", ".jpeg", ".bmp", ".webp", ".tif", ".tiff"}
MAX_SCAN = 300  # max dataset rows scanned while collecting examples
MAX_EXAMPLES_PER_LABEL = 2

# Normalized RVL-CDIP label names that map onto each target class.
TARGET_ALIASES = {
    "email": {"email"},
    "resume": {"resume", "cv"},
    "scientific paper": {"scientificpublication", "scientificreport", "scientificpaper"},
}
def normalize_label(text: str) -> str:
    """Lowercase *text* and keep only its alphanumeric characters."""
    lowered = str(text).lower()
    kept = [char for char in lowered if char.isalnum()]
    return "".join(kept)
def map_to_target_label(raw_value, label_names=None) -> Optional[str]:
    """Resolve a raw dataset label (class index or string) to one of LABELS.

    Returns None when the label does not correspond to any target class.
    """
    resolved = raw_value
    if label_names is not None:
        try:
            resolved = label_names[int(raw_value)]
        except Exception:
            # Not a usable integer index: fall back to the raw value as-is.
            resolved = raw_value
    key = normalize_label(resolved)
    return next(
        (target for target, aliases in TARGET_ALIASES.items() if key in aliases),
        None,
    )
def ensure_placeholder_examples() -> List[str]:
    """Create one synthetic placeholder PNG per label and return all paths."""
    EXAMPLE_DIR.mkdir(parents=True, exist_ok=True)
    paths: List[str] = []
    for idx, label in enumerate(LABELS):
        target = EXAMPLE_DIR / f"placeholder_{idx}_{label.replace(' ', '_')}.png"
        if not target.exists():
            # Plain light-gray page with a border and two lines of text.
            canvas = Image.new("RGB", (900, 1200), color=(245, 245, 245))
            pen = ImageDraw.Draw(canvas)
            pen.rectangle([(40, 40), (860, 1160)], outline=(90, 90, 90), width=4)
            pen.text((80, 120), f"Example placeholder: {label}", fill=(20, 20, 20))
            pen.text((80, 220), "Upload a real document image for best results.", fill=(40, 40, 40))
            canvas.save(target)
        paths.append(str(target))
    return paths
def ensure_real_examples() -> List[str]:
    """Populate EXAMPLE_DIR with real RVL-CDIP samples, or placeholders on failure.

    Returns paths of the example images. If at least 3 valid images already
    exist in EXAMPLE_DIR the download is skipped entirely; if the dataset
    cannot be fetched or yields fewer than 3 images, synthetic placeholders
    are generated instead.
    """
    EXAMPLE_DIR.mkdir(parents=True, exist_ok=True)
    existing = sorted([p for p in EXAMPLE_DIR.iterdir() if p.suffix.lower() in VALID_EXT])
    if len(existing) >= 3:
        return [str(p) for p in existing]
    # Number of examples saved so far for each target label.
    counts = {label: 0 for label in LABELS}
    try:
        ds_obj = load_dataset(DATASET_ID)
        if isinstance(ds_obj, DatasetDict):
            # Prefer test/validation/train in that order; otherwise any split.
            split_name = next((s for s in PREFERRED_SPLITS if s in ds_obj), None)
            if split_name is None:
                split_name = next(iter(ds_obj.keys()))
            selected_ds = ds_obj[split_name]
        else:
            selected_ds = ds_obj
        if hasattr(selected_ds, "shuffle"):
            # Fixed seed so repeated launches pick the same example images.
            selected_ds = selected_ds.shuffle(seed=42)
        # Integer class labels need the feature's name table to decode them.
        label_feature = selected_ds.features.get("label") if hasattr(selected_ds, "features") else None
        label_names = getattr(label_feature, "names", None)
        for i, row in enumerate(selected_ds):
            if i >= MAX_SCAN:
                break
            if all(v >= MAX_EXAMPLES_PER_LABEL for v in counts.values()):
                break
            target_label = None
            if "label" in row:
                target_label = map_to_target_label(row["label"], label_names)
            if target_label is None:
                # Fall back to metadata fields some dataset variants use.
                for field in ["ground_truth", "category", "label_name", "class"]:
                    if field in row:
                        target_label = map_to_target_label(row[field], None)
                        if target_label is not None:
                            break
            if target_label is None or counts[target_label] >= MAX_EXAMPLES_PER_LABEL:
                continue
            # The image column may hold a PIL image, encoded bytes, or a path.
            img_obj = row.get("image")
            img = None
            try:
                if isinstance(img_obj, Image.Image):
                    img = img_obj.convert("RGB")
                elif isinstance(img_obj, dict) and img_obj.get("bytes") is not None:
                    img = Image.open(io.BytesIO(img_obj["bytes"])).convert("RGB")
                elif isinstance(img_obj, dict) and img_obj.get("path") is not None:
                    img = Image.open(img_obj["path"]).convert("RGB")
                elif isinstance(img_obj, str):
                    img = Image.open(img_obj).convert("RGB")
            except Exception:
                # Unreadable image rows are skipped, not fatal.
                img = None
            if img is None:
                continue
            out_name = f"example_{target_label.replace(' ', '_')}_{counts[target_label]:02d}.png"
            out_path = EXAMPLE_DIR / out_name
            img.save(out_path)
            counts[target_label] += 1
        downloaded = sorted([p for p in EXAMPLE_DIR.iterdir() if p.suffix.lower() in VALID_EXT])
        if len(downloaded) >= 3:
            return [str(p) for p in downloaded]
    except Exception:
        # Best-effort: any download/parse failure falls through to placeholders.
        pass
    return ensure_placeholder_examples()
# --- Model loading (runs once at import time) ---
print(f"Loading CLIP on {DEVICE}...")
clip_processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT, use_fast=False)
clip_model = AutoModel.from_pretrained(CLIP_CHECKPOINT).to(DEVICE)
clip_model.eval()

print("Loading BLIP...")
# BLIP loads with a three-level fallback: ITM -> VQA -> captioning.
# blip_mode records which head actually loaded ("none" when all three fail);
# blip_load_message is surfaced in the UI so users can see what is running.
blip_mode = "none"
blip_processor = None
blip_model = None
blip_load_message = "BLIP not available"
try:
    blip_processor = BlipProcessor.from_pretrained(BLIP_ITM_CHECKPOINT)
    blip_model = BlipForImageTextRetrieval.from_pretrained(BLIP_ITM_CHECKPOINT).to(DEVICE)
    blip_model.eval()
    blip_mode = "blip-itm"
    blip_load_message = f"BLIP ITM loaded: {BLIP_ITM_CHECKPOINT}"
except Exception as e_itm:
    try:
        # First fallback: visual question answering head.
        blip_processor = BlipProcessor.from_pretrained(BLIP_VQA_CHECKPOINT)
        blip_model = BlipForQuestionAnswering.from_pretrained(BLIP_VQA_CHECKPOINT).to(DEVICE)
        blip_model.eval()
        blip_mode = "blip-vqa"
        blip_load_message = (
            f"BLIP ITM unavailable ({e_itm}). Fallback VQA loaded: {BLIP_VQA_CHECKPOINT}"
        )
    except Exception as e_vqa:
        try:
            # Second fallback: caption generation + keyword matching.
            blip_processor = BlipProcessor.from_pretrained(BLIP_CAPTION_CHECKPOINT)
            blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CAPTION_CHECKPOINT).to(DEVICE)
            blip_model.eval()
            blip_mode = "blip-caption"
            blip_load_message = (
                f"BLIP ITM/VQA unavailable ({e_itm} | {e_vqa}). Fallback caption loaded: {BLIP_CAPTION_CHECKPOINT}"
            )
        except Exception as e_caption:
            # All BLIP variants failed: the app runs with CLIP only.
            blip_mode = "none"
            blip_load_message = (
                f"BLIP unavailable ({e_itm} | {e_vqa} | {e_caption}). Running only CLIP."
            )
print(blip_load_message)
def normalize_text_for_match(text: str) -> str:
    """Lowercase *text*, turn non-alphanumerics into spaces, collapse runs."""
    chars = []
    for ch in str(text).lower():
        chars.append(ch if ch.isalnum() else " ")
    return " ".join("".join(chars).split())
def map_text_to_blip_label(text: str) -> str:
    """Score *text* against the synonym bank and return the best-matching label.

    Returns "unknown" when the text is empty or contains no synonym for any
    label. Ties are broken by the order of LABELS.
    """
    haystack = normalize_text_for_match(text)
    if not haystack:
        return "unknown"
    hit_counts = {label: 0 for label in LABELS}
    for label, terms in BLIP_SYNONYMS.items():
        for term in terms:
            needle = normalize_text_for_match(term)
            if needle and needle in haystack:
                hit_counts[label] += 1
    winner = max(LABELS, key=lambda lbl: hit_counts[lbl])
    if hit_counts[winner] == 0:
        return "unknown"
    return winner
def classify_clip(image: Image.Image) -> Tuple[str, float, float, Dict[str, float]]:
    """Zero-shot classify *image* with CLIP using the prompt ensemble.

    Returns (predicted label, confidence, top1-top2 margin, per-class probs).
    """
    prompts: List[str] = []
    owner_idx: List[int] = []
    for class_idx, class_name in enumerate(LABELS):
        bank = CLIP_PROMPT_BANK[class_name]
        prompts += bank
        owner_idx += [class_idx] * len(bank)
    inputs = clip_processor(
        text=prompts,
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    if not hasattr(outputs, "logits_per_image"):
        raise RuntimeError("CLIP output does not contain logits_per_image.")
    per_prompt = outputs.logits_per_image.softmax(dim=1)[0].detach().cpu()
    # Average the prompt probabilities belonging to each class, then renormalize.
    class_probs = torch.zeros(len(LABELS), dtype=per_prompt.dtype)
    for class_idx in range(len(LABELS)):
        member_rows = [i for i, owner in enumerate(owner_idx) if owner == class_idx]
        class_probs[class_idx] = per_prompt[member_rows].mean()
    class_probs = class_probs / class_probs.sum()
    top_vals, top_idx = torch.topk(class_probs, k=2)
    winner = LABELS[int(top_idx[0].item())]
    confidence = float(top_vals[0].item())
    margin = float((top_vals[0] - top_vals[1]).item())
    probs = {name: float(class_probs[i].item()) for i, name in enumerate(LABELS)}
    return winner, confidence, margin, probs
def classify_blip_with_itm(image: Image.Image) -> Tuple[str, float, str, Dict[str, float]]:
    """Score image-text match for every ITM prompt and aggregate per class.

    Returns (predicted label, confidence, evidence string, per-class probs).
    """
    prompts: List[str] = []
    owner_idx: List[int] = []
    for class_idx, class_name in enumerate(LABELS):
        bank = BLIP_ITM_PROMPT_BANK[class_name]
        prompts += bank
        owner_idx += [class_idx] * len(bank)
    batch = blip_processor(
        images=[image] * len(prompts),
        text=prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = blip_model(**batch, use_itm_head=True)
    # Probability of "match" (class index 1) for each (image, prompt) pair.
    if hasattr(outputs, "itm_score"):
        pair_match = torch.softmax(outputs.itm_score, dim=1)[:, 1].detach().cpu()
    elif hasattr(outputs, "logits") and outputs.logits.ndim == 2 and outputs.logits.shape[1] == 2:
        pair_match = torch.softmax(outputs.logits, dim=1)[:, 1].detach().cpu()
    else:
        raise RuntimeError("Unexpected BLIP ITM output format.")
    label_scores = torch.zeros(len(LABELS), dtype=pair_match.dtype)
    for class_idx in range(len(LABELS)):
        member_rows = [i for i, owner in enumerate(owner_idx) if owner == class_idx]
        label_scores[class_idx] = pair_match[member_rows].mean()
    if float(label_scores.sum()) <= 0:
        # Degenerate case: fall back to a uniform distribution.
        label_probs = torch.ones_like(label_scores) / len(LABELS)
    else:
        label_probs = label_scores / label_scores.sum()
    best = int(torch.argmax(label_probs).item())
    evidence = " | ".join(f"{LABELS[i]}={label_probs[i].item():.3f}" for i in range(len(LABELS)))
    probs = {LABELS[i]: float(label_probs[i].item()) for i in range(len(LABELS))}
    return LABELS[best], float(label_probs[best].item()), evidence, probs
def classify_blip_with_vqa(image: Image.Image) -> Tuple[str, float, str, Dict[str, float]]:
    """Ask BLIP-VQA each question, map answers to labels, and majority-vote.

    Confidence is votes-for-winner over total questions asked.
    """
    raw_answers: List[str] = []
    mapped: List[str] = []
    for question in BLIP_VQA_PROMPTS:
        batch = blip_processor(images=image, text=question, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            generated_ids = blip_model.generate(**batch, max_new_tokens=8, num_beams=4)
        decoded = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip().lower()
        raw_answers.append(decoded)
        mapped.append(map_text_to_blip_label(decoded))
    votes = {label: sum(1 for m in mapped if m == label) for label in LABELS}
    if not any(votes.values()):
        # No answer mapped onto a known label.
        return "unknown", 0.0, " | ".join(raw_answers), {label: 0.0 for label in LABELS}
    pred = max(votes, key=votes.get)
    confidence = votes[pred] / len(BLIP_VQA_PROMPTS)
    probs = {label: votes[label] / len(BLIP_VQA_PROMPTS) for label in LABELS}
    return pred, float(confidence), " | ".join(raw_answers), probs
def classify_blip_with_caption(image: Image.Image) -> Tuple[str, float, str, Dict[str, float]]:
    """Caption the image and map keywords in the caption onto a label."""
    batch = blip_processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated_ids = blip_model.generate(**batch, max_new_tokens=40, num_beams=4)
    caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip().lower()
    pred = map_text_to_blip_label(caption)
    # Keyword matching on a free-form caption is weak evidence: fixed 0.50 score.
    confidence = 0.0 if pred == "unknown" else 0.50
    probs = dict.fromkeys(LABELS, 0.0)
    if pred in probs:
        probs[pred] = confidence
    return pred, float(confidence), caption, probs
def classify_blip(image: Image.Image) -> Tuple[str, float, str, Dict[str, float], str]:
    """Dispatch to whichever BLIP head loaded at startup (ITM, VQA, or caption).

    Returns (label, confidence, evidence, per-class probs, active mode).
    """
    if blip_model is None or blip_processor is None:
        zero_probs = {label: 0.0 for label in LABELS}
        return "unavailable", 0.0, "BLIP model not loaded", zero_probs, "none"
    dispatch = {
        "blip-itm": classify_blip_with_itm,
        "blip-vqa": classify_blip_with_vqa,
    }
    handler = dispatch.get(blip_mode, classify_blip_with_caption)
    pred, conf, evidence, probs = handler(image)
    return pred, conf, evidence, probs, blip_mode
def build_scores_plot(clip_probs: Dict[str, float], blip_probs: Dict[str, float]):
    """Render a grouped bar chart comparing CLIP and BLIP scores per class."""
    bar_width = 0.36
    positions = range(len(LABELS))
    clip_heights = [clip_probs.get(name, 0.0) for name in LABELS]
    blip_heights = [blip_probs.get(name, 0.0) for name in LABELS]
    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.bar([p - bar_width / 2 for p in positions], clip_heights, width=bar_width, label="CLIP", color="#1f77b4")
    ax.bar([p + bar_width / 2 for p in positions], blip_heights, width=bar_width, label="BLIP", color="#ff7f0e")
    ax.set_xticks(list(positions))
    ax.set_xticklabels(LABELS)
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel("Score")
    ax.set_title("Score comparison by class")
    ax.grid(axis="y", linestyle="--", alpha=0.3)
    ax.legend()
    fig.tight_layout()
    return fig
def run_inference(image: Optional[Image.Image]):
    """Run CLIP and BLIP on *image*; return (summary, table, figure, evidence)."""

    def _blank_table() -> pd.DataFrame:
        # Empty results table shown when inference cannot run.
        return pd.DataFrame(columns=["class", "clip_score", "blip_score"])

    if image is None:
        return "Please upload an image.", _blank_table(), None, "No BLIP evidence yet."
    image = image.convert("RGB")
    try:
        clip_pred, clip_conf, clip_margin, clip_probs = classify_clip(image)
    except Exception as e_clip:
        return f"CLIP inference failed: {e_clip}", _blank_table(), None, "BLIP skipped because CLIP failed."
    blip_pred, blip_conf, blip_evidence, blip_probs, active_blip_mode = classify_blip(image)
    results_df = pd.DataFrame(
        {
            "class": LABELS,
            "clip_score": [clip_probs[lbl] for lbl in LABELS],
            "blip_score": [blip_probs.get(lbl, 0.0) for lbl in LABELS],
        }
    )
    fig = build_scores_plot(clip_probs, blip_probs)
    summary = (
        f"### Prediction summary\n"
        f"- **Device:** {DEVICE}\n"
        f"- **CLIP:** {clip_pred} (confidence={clip_conf:.3f}, margin={clip_margin:.3f})\n"
        f"- **BLIP mode:** {active_blip_mode}\n"
        f"- **BLIP:** {blip_pred} (confidence={blip_conf:.3f})"
    )
    return summary, results_df, fig, blip_evidence
# --- Gradio UI (built at import time) ---
# Download or synthesize example images before building the Examples widget.
examples = [[path] for path in ensure_real_examples()]
DESCRIPTION = (
    "Upload a document image to compare zero-shot predictions from CLIP and BLIP. "
    "The app uses the same document classes as your notebook: email, resume, scientific paper. "
    "First launch can take a few minutes while models are downloaded."
)
with gr.Blocks(title="Zero-Shot Document Classification with CLIP and BLIP") as demo:
    gr.Markdown("# Zero-Shot Document Classification (CLIP + BLIP)")
    gr.Markdown(DESCRIPTION)
    # Surface which BLIP head (if any) loaded during startup.
    gr.Markdown(f"**BLIP status:** {blip_load_message}")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Input document image")
    with gr.Row():
        run_btn = gr.Button("Run inference", variant="primary")
    summary_output = gr.Markdown(label="Summary")
    table_output = gr.Dataframe(label="Scores by class", interactive=False)
    plot_output = gr.Plot(label="Score visualization")
    evidence_output = gr.Textbox(label="BLIP evidence", lines=3)
    run_btn.click(
        fn=run_inference,
        inputs=[image_input],
        outputs=[summary_output, table_output, plot_output, evidence_output],
    )
    # cache_examples=False: examples run live instead of precomputing outputs.
    gr.Examples(
        examples=examples,
        inputs=[image_input],
        outputs=[summary_output, table_output, plot_output, evidence_output],
        fn=run_inference,
        cache_examples=False,
    )

if __name__ == "__main__":
    # PORT is set by the hosting environment; default to Gradio's 7860.
    server_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=server_port)