from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy
import time
# Load models and resources
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')  # not used below; presumably kept for offline caption embedding
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Model data missing; download it once, then retry the load.
    import spacy.cli
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
# Load FAISS index and captions
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)
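
# A minimal offline sketch (an assumption, not part of the deployed app) of how
# faiss_index.idx and captions.json could have been produced: embed each image
# with CLIP (see extract_image_features below), keep the L2-normalized vectors,
# and store them in an inner-product index so search() scores by cosine similarity.
def build_index(image_paths, caption_texts):
    # Error handling omitted for brevity; extract_image_features returns float32 vectors.
    vecs = np.stack([extract_image_features(p) for p in image_paths])
    index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on unit vectors
    index.add(vecs)
    faiss.write_index(index, "./faiss_index.idx")
    with open("./captions.json", "w", encoding="utf-8") as f:
        json.dump(caption_texts, f)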
def extract_image_features(image):
    """
    Extract image features using CLIP.
    Input: PIL Image, NumPy array, or image path (str).
    Output: L2-normalized image embedding (numpy float32 array), or None on error.
    """
    try:
        # Convert NumPy array to PIL if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8")).convert("RGB")
        elif isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        # L2-normalize so inner-product search behaves like cosine similarity
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None
def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve the k most similar captions via the FAISS index.
    Input: Image embedding (numpy array).
    Output: List of captions.
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    D, I = faiss_index.search(image_embedding, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors; skip those.
    return [captions[i] for i in I[0] if i >= 0]
def extract_location_names(texts):
    """
    Extract location names from captions using spaCy.
    Input: List of captions.
    Output: List of unique location names.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            # GPE: countries/cities, LOC: non-GPE locations, FAC: buildings/landmarks
            if ent.label_ in ["GPE", "LOC", "FAC"]:
                names.append(ent.text)
    return list(set(names))
def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a caption from retrieved captions using T5.
    Input: List of retrieved captions.
    Output: Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
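
# Illustrative shape of the prompt built above (example values, not app output):
#   "The place might be: Paris. Eiffel Tower at dusk. The tower seen from the Seine.
#    ... Generate a caption with the landmark name:"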
def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.
    Input: PIL Image, NumPy array, or image path (str).
    Output: Caption (str).
    """
    try:
        start = time.time()
        embedding = extract_image_features(image)
        print("⏱️ CLIP feature extraction:", round(time.time() - start, 2), "s")
        if embedding is None:
            return "Failed to process image."
        start = time.time()
        retrieved = retrieve_similar_captions(embedding, k=5)
        print("⏱️ Caption retrieval:", round(time.time() - start, 2), "s")
        if not retrieved:
            return "No similar captions found."
        start = time.time()
        caption = generate_caption_from_retrieved(retrieved)
        print("⏱️ Caption generation:", round(time.time() - start, 2), "s")
        return caption
    except Exception as e:
        print(f"Error in RAG captioning: {e}")
        return "Something went wrong during caption generation."