supib4132 committed on
Commit
0d05b51
·
verified ·
1 Parent(s): 6c4bc71

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +7 -13
inference.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from PIL import Image
3
  import torch
4
  from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
@@ -14,7 +13,12 @@ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
14
  text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
15
  tokenizer = T5Tokenizer.from_pretrained("t5-small")
16
  generator = T5ForConditionalGeneration.from_pretrained("t5-small")
17
- nlp = spacy.load("en_core_web_sm")
 
 
 
 
 
18
 
19
  # Load FAISS index and captions
20
  faiss_index = faiss.read_index("./faiss_index.idx")
@@ -28,7 +32,6 @@ def extract_image_features(image):
28
  Output: Normalized image embedding (numpy array).
29
  """
30
  try:
31
- # Handle both PIL Image and file path
32
  if isinstance(image, str):
33
  image = Image.open(image).convert("RGB")
34
  else:
@@ -98,13 +101,4 @@ def generate_rag_caption(image):
98
  retrieved = retrieve_similar_captions(embedding, k=5)
99
  if not retrieved:
100
  return "No similar captions found."
101
- return generate_caption_from_retrieved(retrieved)
102
-
103
- def predict(image):
104
- """
105
- API-compatible function for inference.
106
- Input: PIL Image or image file path.
107
- Output: Dictionary with caption.
108
- """
109
- caption = generate_rag_caption(image)
110
- return {"caption": caption}
 
 
1
  from PIL import Image
2
  import torch
3
  from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
 
13
  text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
14
  tokenizer = T5Tokenizer.from_pretrained("t5-small")
15
  generator = T5ForConditionalGeneration.from_pretrained("t5-small")
16
# Load the spaCy English pipeline, downloading the model on first run.
# spacy.load raises OSError when the model package is not installed, so we
# catch exactly that (a bare `except:` would also swallow KeyboardInterrupt
# and SystemExit, masking real failures).
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Model not present in this environment — fetch it once, then retry.
    import spacy.cli
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
22
 
23
  # Load FAISS index and captions
24
  faiss_index = faiss.read_index("./faiss_index.idx")
 
32
  Output: Normalized image embedding (numpy array).
33
  """
34
  try:
 
35
  if isinstance(image, str):
36
  image = Image.open(image).convert("RGB")
37
  else:
 
101
  retrieved = retrieve_similar_captions(embedding, k=5)
102
  if not retrieved:
103
  return "No similar captions found."
104
+ return generate_caption_from_retrieved(retrieved)