from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy
import time
# Load models and resources
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
# Load the spaCy English model, downloading it on first run if it is missing
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    import spacy.cli
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
# Load FAISS index and captions
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)
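
# Illustrative sketch (assumption, not called anywhere in this app): how an index
# such as faiss_index.idx and its companion captions.json could be built. The
# helper name and the (image_path, caption) pairing are hypothetical; the only
# contract the app relies on is that row i of the index matches captions[i] and
# that vectors are L2-normalized exactly as in extract_image_features below.
def build_caption_index(image_caption_pairs, index_path="./faiss_index.idx", captions_path="./captions.json"):
    vectors, caption_list = [], []
    for image_path, caption in image_caption_pairs:
        img = Image.open(image_path).convert("RGB")
        inputs = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs)
        feat = torch.nn.functional.normalize(feat, p=2, dim=-1)
        vectors.append(feat.squeeze(0).cpu().numpy().astype("float32"))
        caption_list.append(caption)
    matrix = np.stack(vectors)
    index = faiss.IndexFlatIP(matrix.shape[1])  # inner product equals cosine for normalized vectors
    index.add(matrix)
    faiss.write_index(index, index_path)
    with open(captions_path, "w", encoding="utf-8") as out:
        json.dump(caption_list, out, ensure_ascii=False)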
def extract_image_features(image):
    """
    Extract image features using the CLIP model.
    Input: PIL Image, image path (str), or NumPy array.
    Output: L2-normalized image embedding (float32 NumPy array), or None on failure.
    """
    try:
        # Convert NumPy arrays or file paths to a PIL RGB image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8")).convert("RGB")
        elif isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        # L2-normalize so FAISS similarity search behaves like cosine similarity
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None
def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve the k most similar captions using the FAISS index.
    Input: Image embedding (NumPy array).
    Output: List of caption strings.
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    D, I = faiss_index.search(image_embedding, k)
    # FAISS pads results with -1 when fewer than k neighbors are available
    return [captions[i] for i in I[0] if i != -1]
def extract_location_names(texts):
    """
    Extract location names from captions using spaCy NER.
    Input: List of captions.
    Output: List of unique location names.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            # GPE = countries/cities, LOC = non-GPE locations, FAC = buildings/landmarks
            if ent.label_ in ["GPE", "LOC", "FAC"]:
                names.append(ent.text)
    return list(set(names))
def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a caption from retrieved captions using T5.
    Input: List of retrieved captions.
    Output: Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.
    Input: PIL Image, image path (str), or NumPy array.
    Output: Caption (str).
    """
    try:
        start = time.time()
        embedding = extract_image_features(image)
        print("⏱️ CLIP feature extraction:", round(time.time() - start, 2), "s")
        if embedding is None:
            return "Failed to process image."

        start = time.time()
        retrieved = retrieve_similar_captions(embedding, k=5)
        print("⏱️ Caption retrieval:", round(time.time() - start, 2), "s")
        if not retrieved:
            return "No similar captions found."

        start = time.time()
        caption = generate_caption_from_retrieved(retrieved)
        print("⏱️ Caption generation:", round(time.time() - start, 2), "s")
        return caption
    except Exception as e:
        print(f"Error in RAG captioning: {e}")
        return "Something went wrong during caption generation."