# indic-clip / app.py
# Hugging Face Space by numb3r33 — commit 16cb310 ("change model embed dim.")
"""
Gradio Application for Indic-CLIP Demo
@description
This script launches an interactive Gradio interface to demonstrate the capabilities
of the trained Indic-CLIP model. It allows users to perform:
- Image-to-Text Retrieval: Find relevant text descriptions for an input image.
- Text-to-Image Retrieval: Find relevant images for an input text query (Hindi/Sanskrit).
- Zero-Shot Classification: Classify an image based on user-provided text labels without
specific training on those labels.
The application loads a pre-trained Indic-CLIP model checkpoint and its associated tokenizer.
It uses a small, pre-defined gallery of text descriptions and sample images for the retrieval demos.
@dependencies
- gradio: For building the web interface.
- torch: Core PyTorch library.
- PIL (Pillow): For image processing.
- indic_clip library: Contains the core model, data handling, and inference logic.
(Needs to be installed in the environment where this app runs).
- transformers, timm, sentencepiece, fastai: Dependencies of indic_clip.
@environment_variables (Optional - for Hugging Face Spaces)
- CHECKPOINT_FILENAME: Name of the checkpoint file (e.g., 'best_valid_loss.pth'). Default: 'best_valid_loss.pth'.
- CHECKPOINT_DIR: Path to the directory containing checkpoints, relative to app root. Default: 'models/checkpoints'.
- TOKENIZER_DIR: Path to the directory containing the tokenizer files, relative to app root. Default: 'models/tokenizer'.
- SAMPLE_DIR: Path to the directory containing sample images, relative to app root. Default: 'data/samples'.
- IMAGE_SIZE: Image size expected by the model. Default: 224.
- TOP_K: Number of retrieval results to display. Default: 5.
- MODEL_VISION_BACKBONE: Name of the vision backbone used during training. Default: 'resnet50'.
- MODEL_TEXT_BACKBONE: Name of the text backbone used during training. Default: 'ai4bharat/indic-bert'.
- MODEL_EMBED_DIM: Embedding dimension of the trained model. Default: 512.
@notes
- Ensure the specified checkpoint file, tokenizer files, and sample images exist in the expected paths
within the Hugging Face Space repository or are downloaded/mounted correctly.
- Model configuration (embed_dim, backbones) MUST match the loaded checkpoint.
- Gallery features are pre-encoded once at startup so each request only encodes its own query. If pre-encoding fails, the corresponding retrieval tab returns an error message instead of results.
"""
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import random
import logging
from pathlib import Path
import sys
from typing import List, Tuple, Dict, Optional
# --- Project Imports ---
# The demo needs the installed `indic_clip` package. When it is missing, the
# module stays importable: placeholder names are bound below, every placeholder
# that would do real work raises RuntimeError, and PROJECT_MODULES_LOADED tells
# the rest of the app to skip loading and show an error banner instead.
try:
    from indic_clip.inference import (
        load_indic_clip_model,
        extract_image_features,
        extract_text_features,
        compute_similarity,
    )
    from indic_clip.core import (
        get_logger,
        setup_logging,
        CHECKPOINT_PATH as DEFAULT_CHECKPOINT_PATH,
        TOKENIZER_PATH as DEFAULT_TOKENIZER_PATH,
        PRETRAINED_TOKENIZER_NAME,
        DEFAULT_EMBED_DIM,
        DEFAULT_IMAGE_SIZE,
        PROJECT_ROOT as DEFAULT_PROJECT_ROOT,
    )
    from indic_clip.data.tokenization import IndicBERTTokenizer
    from indic_clip.model.clip import IndicCLIP
    from indic_clip.evaluation.benchmarks import (
        DEFAULT_PROMPT_TEMPLATES_HI,
        DEFAULT_PROMPT_TEMPLATES_SA,
        DEFAULT_PROMPT_TEMPLATES_EN,
    )
    PROJECT_MODULES_LOADED = True
except ImportError as import_err:
    PROJECT_MODULES_LOADED = False
    print(f"Error importing indic_clip modules: {import_err}")
    print("Falling back to dummy definitions or exiting if critical components missing.")

    # Plain stdlib logging stand-ins for indic_clip.core's helpers.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("indic_clip_app_fallback")

    def get_logger(name):
        return logging.getLogger(name)

    def setup_logging():
        pass

    # Fallback defaults (these match the defaults listed in the module docstring).
    DEFAULT_IMAGE_SIZE = 224
    DEFAULT_EMBED_DIM = 512
    PRETRAINED_TOKENIZER_NAME = "ai4bharat/indic-bert"
    DEFAULT_CHECKPOINT_PATH = Path("./models/checkpoints")
    DEFAULT_TOKENIZER_PATH = Path("./models/tokenizer")
    DEFAULT_PROJECT_ROOT = Path(".")
    DEFAULT_PROMPT_TEMPLATES_EN = ["a photo of a {}"]

    # Placeholder types so later references such as the `IndicCLIP`
    # annotation on the model variable still resolve.
    class IndicCLIP(torch.nn.Module):
        pass

    class IndicBERTTokenizer:
        pass

    def _missing_dependency(message):
        """Build a stub that raises RuntimeError(message) whenever it is called."""
        def _stub(*args, **kwargs):
            raise RuntimeError(message)
        return _stub

    load_indic_clip_model = _missing_dependency("Model loading failed - indic_clip not found")
    extract_image_features = _missing_dependency("Inference failed - indic_clip not found")
    extract_text_features = _missing_dependency("Inference failed - indic_clip not found")
    compute_similarity = _missing_dependency("Inference failed - indic_clip not found")
# --- Configuration ---
setup_logging()
logger = get_logger("indic_clip_app")


def _env_int(name: str, default: int, minimum: int = 1) -> int:
    """Return integer environment variable ``name``, or ``default`` when unset.

    An unparsable value logs a warning and falls back to ``default`` instead of
    crashing the app at import time. Values below ``minimum`` are clamped (with
    a warning), so settings such as ``TOP_K`` can never request a non-positive
    ``k`` from ``torch.topk`` or build zero-sized Gradio components.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or invalid.
        minimum: Smallest accepted value.

    Returns:
        The parsed, clamped integer.
    """
    raw = os.getenv(name)
    if raw is None:
        value = int(default)
    else:
        try:
            value = int(raw)
        except ValueError:
            logger.warning(f"Invalid integer for {name}={raw!r}; using default {default}.")
            value = int(default)
    if value < minimum:
        logger.warning(f"{name}={value} is below the minimum of {minimum}; using {minimum}.")
        value = minimum
    return value


CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", 'best_valid_loss.pth')
# Relative paths resolve against the current working directory (the repo root
# on HF Spaces). Absolute values are used as-is, because `Path(".") / absolute`
# discards the left-hand side.
HF_SPACE_ROOT = Path(".")
CHECKPOINT_DIR = HF_SPACE_ROOT / os.getenv("CHECKPOINT_DIR", DEFAULT_CHECKPOINT_PATH)
TOKENIZER_DIR = HF_SPACE_ROOT / os.getenv("TOKENIZER_DIR", DEFAULT_TOKENIZER_PATH)
SAMPLE_DIR = HF_SPACE_ROOT / os.getenv("SAMPLE_DIR", DEFAULT_PROJECT_ROOT / "data/samples")  # Use project root default if not set
CHECKPOINT_FILE_PATH = CHECKPOINT_DIR / CHECKPOINT_FILENAME
TOKENIZER_DIR_PATH = TOKENIZER_DIR
SAMPLE_IMAGE_DIR = SAMPLE_DIR
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = _env_int("IMAGE_SIZE", DEFAULT_IMAGE_SIZE)
TOP_K = _env_int("TOP_K", 5)
MODEL_VISION_BACKBONE = os.getenv("MODEL_VISION_BACKBONE", "resnet50")
MODEL_TEXT_BACKBONE = os.getenv("MODEL_TEXT_BACKBONE", PRETRAINED_TOKENIZER_NAME)
# Deliberately the literal 512 (the documented default), NOT core's
# DEFAULT_EMBED_DIM: this value must match the checkpoint being loaded, so set
# MODEL_EMBED_DIM explicitly when deploying a checkpoint with another size.
MODEL_EMBED_DIM = _env_int("MODEL_EMBED_DIM", 512)
logger.info("--- App Configuration ---")
logger.info(f"Device: {DEVICE}")
logger.info(f"Checkpoint Path: {CHECKPOINT_FILE_PATH}")
logger.info(f"Tokenizer Path: {TOKENIZER_DIR_PATH}")
logger.info(f"Sample Image Dir: {SAMPLE_IMAGE_DIR}")
logger.info(f"Image Size: {IMAGE_SIZE}")
logger.info(f"Vision Backbone: {MODEL_VISION_BACKBONE}")
logger.info(f"Text Backbone: {MODEL_TEXT_BACKBONE}")
logger.info(f"Embedding Dim: {MODEL_EMBED_DIM}")
logger.info(f"Top K Retrieval: {TOP_K}")
logger.info(f"Project Modules Loaded: {PROJECT_MODULES_LOADED}")
logger.info("-------------------------")
# --- Load Tokenizer ---
# `tokenizer` stays None unless the project modules imported AND the tokenizer
# directory loads cleanly; later code checks for None instead of crashing.
tokenizer = None
if PROJECT_MODULES_LOADED:
    if not TOKENIZER_DIR_PATH.exists():
        logger.error(f"Tokenizer directory not found at {TOKENIZER_DIR_PATH}. Cannot start app.")
    else:
        try:
            tokenizer = IndicBERTTokenizer.load_tokenizer(TOKENIZER_DIR_PATH)
            logger.info(f"Tokenizer loaded successfully from {TOKENIZER_DIR_PATH}")
        except Exception as load_err:
            logger.error(f"Error loading tokenizer from {TOKENIZER_DIR_PATH}: {load_err}", exc_info=True)
# --- Load Model ---
# Loading is attempted only when the tokenizer is available, since it is passed
# through in the model configuration. The backbone names and embed_dim must
# match the checkpoint; the *_pretrained flags are off because (per the
# original note) they do not affect loading the checkpoint's state_dict.
model: Optional[IndicCLIP] = None
if not (PROJECT_MODULES_LOADED and tokenizer is not None):
    logger.error("Tokenizer not loaded or project modules missing, skipping model load.")
else:
    MODEL_CONFIGURATION = dict(
        embed_dim=MODEL_EMBED_DIM,
        vision_model_name=MODEL_VISION_BACKBONE,
        vision_pretrained=False,
        text_model_name=MODEL_TEXT_BACKBONE,
        text_pretrained=False,
        tokenizer=tokenizer,
    )
    if not CHECKPOINT_FILE_PATH.is_file():
        logger.error(f"Checkpoint file not found: {CHECKPOINT_FILE_PATH}")
    else:
        try:
            model = load_indic_clip_model(
                checkpoint_path=CHECKPOINT_FILE_PATH,
                model_config=MODEL_CONFIGURATION,
                device=DEVICE,
            )
            logger.info(f"Model loaded successfully from {CHECKPOINT_FILE_PATH} to device {DEVICE}.")
        except Exception as load_err:
            logger.error(f"Error loading model from {CHECKPOINT_FILE_PATH}: {load_err}", exc_info=True)
if model is None:
    logger.critical("Model could not be loaded. The application cannot function properly.")
# --- Sample Gallery Data ---
# Fixed Hindi captions forming the retrieval corpus for the Image-to-Text tab.
TEXT_GALLERY = [
    "एक लड़का फुटबॉल खेल रहा है।",
    "समुद्र तट पर सूर्यास्त।",
    "एक बिल्ली सोफे पर सो रही है।",
    "पारंपरिक साड़ी पहने एक महिला।",
    "एक मंदिर का प्रवेश द्वार।",
    "देवता गणेश की मूर्ति।",
    "एक व्यस्त भारतीय बाज़ार।",
    "लैपटॉप पर काम करता हुआ व्यक्ति।",
    "मेज़ पर रखी किताबों का ढेर।",
    "एक लाल रंग की स्पोर्ट्स कार।"
]
# Image filenames looked up inside SAMPLE_IMAGE_DIR; missing files are skipped with a warning.
IMAGE_GALLERY_FILENAMES = [
    'cat.jpg',
    'dog_park.jpg',
    'sunset_beach.jpg',
    'woman_saree.jpg',
    'temple.jpg',
    'ganesh.jpg',
    'market.jpg',
    'laptop.jpg',
    'books.jpg',
    'car.jpg'
]
# Verify image existence and build two parallel lists (same order, same length).
IMAGE_GALLERY_DISPLAY: List[str] = []  # str paths for Gradio components
valid_image_gallery_files: List[Path] = []  # Path objects for feature extraction
try:
    # Best effort only: on a read-only or restricted filesystem the app must
    # still start; images that cannot be found are simply skipped below.
    SAMPLE_IMAGE_DIR.mkdir(parents=True, exist_ok=True)
except OSError as mkdir_err:
    logger.warning(f"Could not create sample image directory {SAMPLE_IMAGE_DIR}: {mkdir_err}")
for filename in IMAGE_GALLERY_FILENAMES:
    full_img_path = SAMPLE_IMAGE_DIR / filename
    if full_img_path.is_file():
        # Path as built from SAMPLE_IMAGE_DIR: relative to the working directory
        # unless the configured sample directory is absolute.
        IMAGE_GALLERY_DISPLAY.append(str(full_img_path))
        valid_image_gallery_files.append(full_img_path)
    else:
        logger.warning(f"Sample image not found: {full_img_path}")
logger.info(f"Found {len(valid_image_gallery_files)} valid sample images in {SAMPLE_IMAGE_DIR}")
# --- Pre-encode Gallery Features ---
# Encode both fixed galleries once at startup so each request only encodes its
# own query. A tensor stays None when its encoding is skipped or fails; the
# request handlers check for None and return an error message in that case.
text_gallery_features: Optional[torch.Tensor] = None
image_gallery_features: Optional[torch.Tensor] = None
if model is None or tokenizer is None:
    logger.warning("Model or tokenizer not loaded, cannot pre-encode gallery features.")
else:
    logger.info("Attempting to pre-encode gallery features...")
    # The two galleries are encoded independently: a failure in one does not
    # prevent the other from becoming available.
    try:
        text_gallery_features = extract_text_features(model, tokenizer, TEXT_GALLERY, device=DEVICE)
        logger.info(
            f"Encoded {len(TEXT_GALLERY)} text gallery items. "
            f"Shape: {'None' if text_gallery_features is None else text_gallery_features.shape}"
        )
    except Exception as text_err:
        logger.error(f"Failed to pre-encode text gallery: {text_err}", exc_info=True)
    if valid_image_gallery_files:
        try:
            image_gallery_features = extract_image_features(
                model, valid_image_gallery_files, img_size=IMAGE_SIZE, device=DEVICE
            )
            logger.info(
                f"Encoded {len(valid_image_gallery_files)} image gallery items. "
                f"Shape: {'None' if image_gallery_features is None else image_gallery_features.shape}"
            )
        except Exception as image_err:
            logger.error(f"Failed to pre-encode image gallery: {image_err}", exc_info=True)
# --- Gradio Interface Functions ---
def predict_text_from_image(image_input: Image.Image) -> str:
    """Gradio interface function for Image-to-Text retrieval.

    Ranks the pre-encoded TEXT_GALLERY captions against the uploaded image.

    Args:
        image_input: PIL image from the Gradio image component, or None when empty.

    Returns:
        Up to TOP_K newline-separated lines ``"<score>: <caption>"`` (score with
        4 decimals, in ``torch.topk`` order), or a human-readable error string.
    """
    not_ready = model is None or tokenizer is None or text_gallery_features is None
    if not_ready:
        return "Error: Model, tokenizer, or text gallery features not loaded. Cannot perform retrieval."
    if image_input is None:
        return "Error: Please upload an image."
    logger.info("Performing Image-to-Text retrieval...")
    try:
        img_feat = extract_image_features(model, image_input, img_size=IMAGE_SIZE, device=DEVICE)
        if img_feat is None or img_feat.nelement() == 0:
            return "Error: Could not extract features from the image."
        # Image first, gallery second; squeeze(0) drops the single-query axis.
        similarity = compute_similarity(model, img_feat, text_gallery_features)
        k = min(TOP_K, len(TEXT_GALLERY))
        top_scores, top_indices = torch.topk(similarity.squeeze(0), k=k, dim=-1)
        results = "\n".join(
            f"{score:.4f}: {TEXT_GALLERY[idx]}"
            for score, idx in zip(top_scores.tolist(), top_indices.tolist())
        )
        logger.info("Image-to-Text retrieval successful.")
        return results
    except Exception as e:
        logger.error(f"Error in predict_text_from_image: {e}", exc_info=True)
        return f"An error occurred during text retrieval: {e}"
def predict_image_from_text(text_input: str) -> List[Tuple[str, str]]:
    """Gradio interface function for Text-to-Image retrieval.

    Ranks the pre-encoded sample images against a free-text query.

    Args:
        text_input: Query string from the textbox.

    Returns:
        A list of ``(image_path, caption)`` pairs for ``gr.Gallery``, where the
        caption is ``"<score>: <file name>"``. Errors are reported as a single
        pair whose image is a red placeholder URL and whose caption is the message.
    """
    # Placeholder image shown alongside error captions.
    error_img_placeholder = "https://dummyimage.com/150x150/ff0000/ffffff.png&text=Error"

    def _fail(message: str) -> List[Tuple[str, str]]:
        return [(error_img_placeholder, message)]

    gallery_ready = (
        model is not None
        and tokenizer is not None
        and image_gallery_features is not None
        and bool(IMAGE_GALLERY_DISPLAY)
    )
    if not gallery_ready:
        return _fail("Error: Model, tokenizer, or image gallery features not loaded.")
    if not (text_input and text_input.strip()):
        return _fail("Error: Please enter text query.")
    logger.info(f"Performing Text-to-Image retrieval for query: '{text_input}'")
    try:
        txt_feat = extract_text_features(model, tokenizer, text_input, device=DEVICE)
        if txt_feat is None or txt_feat.nelement() == 0:
            return _fail("Error: Could not extract features from the text.")
        # Gallery first, query second (original note: [N_img, N_txt=1]);
        # squeeze(-1) drops the single-query axis so topk runs over images.
        similarity = compute_similarity(model, image_gallery_features, txt_feat)
        k = min(TOP_K, len(IMAGE_GALLERY_DISPLAY))
        top_scores, top_indices = torch.topk(similarity.squeeze(-1), k=k, dim=0)
        results = []
        for score, img_index in zip(top_scores.tolist(), top_indices.tolist()):
            # IMAGE_GALLERY_DISPLAY is index-aligned with the encoded gallery.
            display_path = IMAGE_GALLERY_DISPLAY[img_index]
            results.append((display_path, f"{score:.4f}: {Path(display_path).name}"))
        logger.info("Text-to-Image retrieval successful.")
        return results
    except Exception as e:
        logger.error(f"Error in predict_image_from_text: {e}", exc_info=True)
        return _fail(f"An error occurred during image retrieval: {e}")
def predict_zero_shot(image_input: Image.Image, candidate_labels_text: str) -> Dict[str, float]:
    """Gradio interface function for Zero-Shot Classification.

    Each candidate label is expanded with DEFAULT_PROMPT_TEMPLATES_EN. The
    prompt embeddings are averaged per label when there are several templates,
    L2-normalised, and compared with the image embedding. Softmax then turns
    the similarities into probabilities.

    Args:
        image_input: PIL image from the Gradio image component, or None when empty.
        candidate_labels_text: Comma-separated class names typed by the user.

    Returns:
        Mapping of each (de-duplicated) label to its probability, as consumed
        by ``gr.Label``. On failure, a single-entry mapping
        ``{"Error: <message>": 1.0}``; ``gr.Label`` requires numeric values,
        so the message is carried in the key.
    """
    def _error(message: str) -> Dict[str, float]:
        # gr.Label sorts entries by confidence, so every value must be numeric.
        return {f"Error: {message}": 1.0}

    if model is None or tokenizer is None:
        return _error("Model or tokenizer not loaded.")
    if image_input is None:
        return _error("Please upload an image.")
    if not candidate_labels_text or not candidate_labels_text.strip():
        return _error("Please enter candidate labels (comma-separated).")
    logger.info(f"Performing Zero-Shot classification for labels: '{candidate_labels_text}'")
    try:
        # Order-preserving de-duplication: repeated labels would otherwise split
        # probability mass between identical prompts and collapse to one dict key.
        class_names = list(dict.fromkeys(
            label.strip() for label in candidate_labels_text.split(',') if label.strip()
        ))
        if not class_names:
            return _error("Invalid label format. Enter comma-separated labels.")
        # Using English templates as default for broader applicability
        templates = DEFAULT_PROMPT_TEMPLATES_EN
        img_feat = extract_image_features(model, image_input, img_size=IMAGE_SIZE, device=DEVICE)
        if img_feat is None or img_feat.nelement() == 0:
            return _error("Could not extract features from image.")
        # Template-major order: prompt index = t * num_classes + c, which the
        # view(len(templates), num_classes, -1) below relies on.
        all_prompts = [tmpl.format(cn) for tmpl in templates for cn in class_names]
        text_embeddings = extract_text_features(model, tokenizer, all_prompts, device=DEVICE)
        if text_embeddings is None or text_embeddings.nelement() == 0:
            return _error("Could not extract features from text labels.")
        num_classes = len(class_names)
        if len(templates) > 1:
            # Average the per-template embeddings of each class.
            text_embeddings = text_embeddings.view(len(templates), num_classes, -1).mean(dim=0)
        # Ensure final normalization (averaging breaks unit norm).
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
        # reshape(-1), not squeeze(): with a single label squeeze() yields a
        # 0-dim tensor that cannot be indexed per class.
        similarity = compute_similarity(model, img_feat, text_embeddings).reshape(-1)
        # Temperature scaling, if any, is expected inside compute_similarity.
        probs = F.softmax(similarity, dim=-1)
        results = dict(zip(class_names, probs.tolist()))
        logger.info(f"Zero-Shot classification successful. Results: {results}")
        return results
    except Exception as e:
        logger.error(f"Error in predict_zero_shot: {e}", exc_info=True)
        return _error(f"An error occurred: {e}")
# --- Gradio Interface Definition ---
# Custom styling for the Blocks app. `.gradio-container` is declared once
# (font + width cap); the dark-mode slider accent uses a valid hex color.
css = """
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; max-width: 1140px !important; }
.gr-button { color: white; border-color: black; background: black; }
input[type='range'] { accent-color: black; }
.dark input[type='range'] { accent-color: #dfdfdf; }
.container { max-width: 1100px; margin: auto; padding-top: 1.5rem; }
#gallery { min-height: 22rem; margin-bottom: 15px; margin-left: auto; margin-right: auto; }
#gallery>div>.h-full { min-height: 20rem; }
.details:hover { text-decoration: underline; }
.feedback { font-size: 0.8rem; margin-bottom: 5px; }
.feedback textarea { font-size: 0.8rem; }
.feedback button { margin: 0; }
footer {visibility: hidden}
"""
# Zero-shot example rows: (keyword in the sample's file name, candidate labels).
_ZERO_SHOT_EXAMPLE_LABELS = [
    ("cat", "बिल्ली, कुत्ता, पक्षी"),
    ("saree", "साड़ी, कुर्ता, पोशाक"),
    ("temple", "मंदिर, मस्जिद, चर्च"),
    ("car", "कार, बस, मोटरबाइक"),
]
# Match on the file name only. Substring-matching the full path would also hit
# directory names (e.g. "cat" in ".../education/..."), attaching every image.
zero_shot_examples = [
    [img_path, labels]
    for keyword, labels in _ZERO_SHOT_EXAMPLE_LABELS
    for img_path in IMAGE_GALLERY_DISPLAY
    if keyword in Path(img_path).name
]

block = gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue"))
with block:
    gr.Markdown(
        """
<div style="text-align: center; max-width: 1000px; margin: 20px auto;">
<h1 style="font-weight: 900; font-size: 3rem;">
Indic-CLIP <span style="font-size: 1.5rem">🖼️<->📝</span>
</h1>
<p style="margin-bottom: 10px; font-size: 94%">
Multimodal Vision-Language Model for Indic Languages (Hindi/Sanskrit)
</p>
<p>Provide an image or text to retrieve corresponding matches, or perform zero-shot classification.</p>
<p><strong>Note:</strong> This demo uses a small, fixed gallery for retrieval. Model checkpoint: <code>{}</code></p>
</div>
""".format(CHECKPOINT_FILENAME)
    )
    with gr.Tabs():
        with gr.TabItem("🖼️ Image-to-Text Retrieval"):
            with gr.Row(equal_height=True):
                with gr.Column():
                    input_image_i2t = gr.Image(type="pil", label="Input Image")
                    submit_i2t = gr.Button("Retrieve Text", variant="primary")
                with gr.Column():
                    output_text_i2t = gr.Textbox(lines=TOP_K, label=f"Top {TOP_K} Text Matches (Score: Text)", interactive=False)
            # Skip the examples table entirely when no sample images were found.
            if IMAGE_GALLERY_DISPLAY:
                gr.Examples(
                    examples=IMAGE_GALLERY_DISPLAY[:5],  # verified paths only
                    inputs=input_image_i2t,
                    label="Sample Images (Click to Load)"
                )
        with gr.TabItem("📝 Text-to-Image Retrieval"):
            with gr.Row(equal_height=True):
                with gr.Column():
                    input_text_t2i = gr.Textbox(label="Input Text (e.g., Hindi, Sanskrit, English)", placeholder="उदाहरण: एक बिल्ली सोफे पर सो रही है।")
                    submit_t2i = gr.Button("Retrieve Images", variant="primary")
                with gr.Column():
                    # Layout options are passed directly to the constructor.
                    output_gallery_t2i = gr.Gallery(
                        label=f"Top {TOP_K} Image Matches (Score: Filename)",
                        show_label=True,
                        columns=TOP_K,
                        height="auto",
                        object_fit="contain"
                    )
            gr.Examples(
                examples=TEXT_GALLERY[:5],
                inputs=input_text_t2i,
                label="Sample Text Queries (Click to Load)"
            )
        with gr.TabItem("🏷️ Zero-Shot Classification"):
            with gr.Row(equal_height=True):
                with gr.Column():
                    input_image_zs = gr.Image(type="pil", label="Input Image")
                with gr.Column():
                    candidate_labels_zs = gr.Textbox(label="Candidate Labels (Comma-separated)", placeholder="e.g., बिल्ली, कुत्ता, पक्षी, कार, मंदिर")
                    submit_zs = gr.Button("Classify Image", variant="primary")
                    # No num_top_classes, so every candidate label is shown. A
                    # value derived from the textbox here would be evaluated once
                    # at build time, while the textbox is still empty.
                    output_labels_zs = gr.Label(label="Classification Probabilities")
            if zero_shot_examples:
                gr.Examples(
                    examples=zero_shot_examples,
                    inputs=[input_image_zs, candidate_labels_zs],
                    outputs=output_labels_zs,
                    label="Sample Images and Labels (Click to Load)",
                    cache_examples=False  # Avoid caching issues with file paths
                )
    # Define button click actions
    submit_i2t.click(
        predict_text_from_image,
        inputs=[input_image_i2t],
        outputs=[output_text_i2t],
        api_name="image_to_text_retrieval"
    )
    submit_t2i.click(
        predict_image_from_text,
        inputs=[input_text_t2i],
        outputs=[output_gallery_t2i],
        api_name="text_to_image_retrieval"
    )
    submit_zs.click(
        predict_zero_shot,
        inputs=[input_image_zs, candidate_labels_zs],
        outputs=[output_labels_zs],
        api_name="zero_shot_classification"
    )
# --- Launch ---
if __name__ == "__main__":
    # The UI is always served. When the project code or the model is missing,
    # an error banner is appended to the existing layout so the failure is
    # visible in the browser; the request handlers already return error
    # messages when the model is None. `block` is always defined here because
    # it is built unconditionally at module level.
    error_banner: Optional[str] = None
    if not PROJECT_MODULES_LOADED:
        print("\nERROR: Indic-CLIP project modules could not be loaded.")
        print("Please ensure the library is installed correctly ('pip install -e .')")
        print("Launching Gradio app with an error banner; functionality is disabled.\n")
        error_banner = "<h2 style='color:red;'>ERROR: Indic-CLIP project code not found. Cannot launch application.</h2>"
    elif model is None:
        logger.error("Model is None. Launching Gradio app with an error banner; functionality is disabled.")
        error_banner = "<h2 style='color:red;'>ERROR: Model failed to load. Application cannot start. Please check logs and ensure checkpoint/tokenizer exist.</h2>"
    else:
        logger.info("Launching Gradio interface...")
    if error_banner is not None:
        with block:
            gr.Markdown(error_banner)
    block.launch()  # pass share=True for a public link