File size: 4,952 Bytes
2d10e05
 
 
 
 
 
 
 
4e2f70b
2d10e05
 
 
 
 
 
 
0d05b51
 
 
 
 
 
2d10e05
 
 
 
 
 
 
 
 
 
 
 
350853d
 
 
 
 
 
 
 
 
 
 
 
 
2800504
350853d
 
 
 
2d10e05
 
 
350853d
2d10e05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e2f70b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy
import time

# Load models and resources once at import time (shared by all functions below).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spacy.load raises OSError when the model package is not installed;
    # download it on first run instead of failing. A bare `except:` here
    # would also swallow KeyboardInterrupt/SystemExit.
    import spacy.cli
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Load FAISS index and the caption corpus it was built from.
# NOTE(review): assumes captions.json is a list aligned with the index rows — verify.
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)

def extract_image_features(image):
    """
    Extract an L2-normalized image embedding using the CLIP model.

    Args:
        image: A PIL Image, an image file path (str), or an HxWxC numpy
            array (values castable to uint8).

    Returns:
        A float32 numpy array (CLIP embedding, unit L2 norm), or None if
        extraction fails for any reason.
    """
    try:
        # Normalize all accepted input forms to an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8")).convert("RGB")
        elif isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")

        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        # L2-normalize so inner-product search is equivalent to cosine similarity.
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        # Best-effort: callers treat None as "could not process image".
        print(f"Error extracting features: {e}")
        return None

def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve the k most similar captions via the FAISS index.

    Args:
        image_embedding: 1-D or (1, d) numpy array (float32 CLIP embedding).
        k: Number of neighbors to retrieve.

    Returns:
        List of caption strings (may be shorter than k if the index holds
        fewer than k vectors).
    """
    # FAISS requires a 2-D float32 query matrix.
    query = np.ascontiguousarray(image_embedding, dtype="float32")
    if query.ndim == 1:
        query = query.reshape(1, -1)
    _, indices = faiss_index.search(query, k)
    # FAISS pads missing neighbors with -1; filtering avoids wrongly
    # returning captions[-1] (the last caption) for padded slots.
    return [captions[i] for i in indices[0] if i != -1]

def extract_location_names(texts):
    """
    Extract location-like named entities from captions using spaCy.

    Args:
        texts: Iterable of caption strings.

    Returns:
        List of unique entity strings (GPE/LOC/FAC), in first-seen order.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            # Geo-political entities, locations, and facilities.
            if ent.label_ in {"GPE", "LOC", "FAC"}:
                names.append(ent.text)
    # dict.fromkeys dedupes while preserving order; list(set(...)) would
    # return a nondeterministic order, making downstream prompts unstable.
    return list(dict.fromkeys(names))

def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a single caption from a set of retrieved captions using T5.

    Args:
        retrieved_captions: List of caption strings from the FAISS lookup.

    Returns:
        The generated caption as a string.
    """
    # Prepend a hint listing any place names found in the retrieved text.
    found_locations = extract_location_names(retrieved_captions)
    if found_locations:
        location_hint = f"The place might be: {', '.join(found_locations)}. "
    else:
        location_hint = ""

    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    generated_ids = generator.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.

    Pipeline: CLIP embedding -> FAISS caption retrieval -> T5 generation,
    with per-stage timing printed to stdout.

    Args:
        image: PIL Image or image path (str).

    Returns:
        The generated caption, or a human-readable failure message.
    """
    try:
        t0 = time.time()
        embedding = extract_image_features(image)
        print("⏱️ CLIP feature extraction:", round(time.time() - t0, 2), "s")
        # Guard clause: feature extraction signals failure with None.
        if embedding is None:
            return "Failed to process image."

        t0 = time.time()
        neighbors = retrieve_similar_captions(embedding, k=5)
        print("⏱️ Caption retrieval:", round(time.time() - t0, 2), "s")
        if not neighbors:
            return "No similar captions found."

        t0 = time.time()
        result = generate_caption_from_retrieved(neighbors)
        print("⏱️ Caption generation:", round(time.time() - t0, 2), "s")
        return result
    except Exception as e:
        # Top-level boundary: report and return a message rather than crash.
        print(f"Error in RAG captioning: {e}")
        return "Something went wrong during caption generation."