"""RAG-based image captioning.

Pipeline: CLIP image embedding -> FAISS nearest-neighbor caption retrieval
-> T5 generation conditioned on the retrieved captions (with a location hint
extracted via spaCy NER).
"""

import json
import time

import faiss
import numpy as np
import spacy
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import (
    CLIPModel,
    CLIPProcessor,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Load models and resources once at import time.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")

# spaCy NER model: download on first run if not installed.
# spacy.load raises OSError when the model package is missing — catch only
# that, not a bare except that would also swallow KeyboardInterrupt etc.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    import spacy.cli
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Load the FAISS index and the caption corpus it was built from.
# captions[i] must correspond to the i-th vector in the index.
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)


def extract_image_features(image):
    """Extract an L2-normalized CLIP image embedding.

    Args:
        image: PIL Image, NumPy array (HxWxC, uint8-compatible), or an
            image file path (str).

    Returns:
        float32 numpy array of shape (embedding_dim,), or None on failure.
    """
    try:
        # Accept NumPy arrays, file paths, or PIL Images.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8")).convert("RGB")
        elif isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        # L2-normalize so inner-product search behaves like cosine similarity.
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None


def retrieve_similar_captions(image_embedding, k=5):
    """Retrieve the k most similar captions via the FAISS index.

    Args:
        image_embedding: numpy array of shape (dim,) or (1, dim).
        k: number of nearest neighbors to request.

    Returns:
        List of caption strings (may be shorter than k if the index holds
        fewer than k vectors).
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    _, indices = faiss_index.search(image_embedding, k)
    # FAISS pads missing results with -1; skip them instead of letting
    # Python's negative indexing silently return the last caption.
    return [captions[i] for i in indices[0] if i >= 0]


def extract_location_names(texts):
    """Extract unique location-like entity names from captions via spaCy NER.

    Args:
        texts: iterable of caption strings.

    Returns:
        List of unique entity texts labeled GPE, LOC, or FAC, in
        first-seen order (deterministic, unlike list(set(...))).
    """
    names = []
    seen = set()
    for text in texts:
        for ent in nlp(text).ents:
            if ent.label_ in ("GPE", "LOC", "FAC") and ent.text not in seen:
                seen.add(ent.text)
                names.append(ent.text)
    return names


def generate_caption_from_retrieved(retrieved_captions):
    """Generate a caption from retrieved captions using T5.

    A location hint built from NER entities is prepended to bias the
    model toward naming the landmark.

    Args:
        retrieved_captions: list of caption strings.

    Returns:
        Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = (
        f"The place might be: {', '.join(locations)}. " if locations else ""
    )
    prompt = (
        location_hint
        + " ".join(retrieved_captions)
        + " Generate a caption with the landmark name:"
    )
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def generate_rag_caption(image):
    """End-to-end RAG captioning: embed image -> retrieve -> generate.

    Args:
        image: PIL Image, NumPy array, or image file path (str).

    Returns:
        Caption string on success, or a human-readable error message string
        on failure (this function never raises).
    """
    try:
        start = time.time()
        embedding = extract_image_features(image)
        print("⏱️ CLIP feature extraction:", round(time.time() - start, 2), "s")
        if embedding is None:
            return "Failed to process image."

        start = time.time()
        retrieved = retrieve_similar_captions(embedding, k=5)
        print("⏱️ Caption retrieval:", round(time.time() - start, 2), "s")
        if not retrieved:
            return "No similar captions found."

        start = time.time()
        caption = generate_caption_from_retrieved(retrieved)
        print("⏱️ Caption generation:", round(time.time() - start, 2), "s")
        return caption
    except Exception as e:
        print(f"Error in RAG captioning: {e}")
        return "Something went wrong during caption generation."