SaniaE commited on
Commit
50e497f
·
verified ·
1 Parent(s): 9fef689

optimized code

Browse files
Files changed (1) hide show
  1. app.py +35 -69
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
- import asyncio
3
  import torch
4
  import random
 
5
  from PIL import Image
6
  from fastapi import FastAPI, UploadFile, File, Query
7
  from fastapi.middleware.cors import CORSMiddleware
@@ -10,7 +10,6 @@ from transformers import (
10
  BlipProcessor, BlipForConditionalGeneration,
11
  ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
12
  )
13
- from sentence_transformers import SentenceTransformer, util
14
 
15
  app = FastAPI()
16
 
@@ -18,9 +17,8 @@ app = FastAPI()
18
  REPO_ID = "SaniaE/Image_Captioning_Ensemble"
19
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
  MODELS = {}
21
- SEARCH_MODEL = None
22
 
23
- # We'll map your local folder names to the specific config
24
  MODEL_SETTINGS = {
25
  "blip": {
26
  "subfolder": "blip",
@@ -33,72 +31,54 @@ MODEL_SETTINGS = {
33
  "processor": [ViTImageProcessor, AutoProcessor],
34
  "pretrained_path": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"],
35
  "inference_model": AutoModelForCausalLM
36
- },
37
- "git": {
38
- "subfolder": "git",
39
- "processor": AutoProcessor,
40
- "pretrained_path": "microsoft/git-base",
41
- "inference_model": AutoModelForCausalLM
42
  }
43
  }
44
 
45
  @app.on_event("startup")
46
  async def startup_event():
47
- global MODELS, SEARCH_MODEL
 
 
48
 
49
- # 1. Authenticate and Download from Private Repo
50
- token = os.getenv("HF_Token")
51
- if token:
52
- login(token=token)
53
-
54
- print(f"Downloading ensemble models from {REPO_ID}...")
55
- # This downloads the whole repo into a local 'weights' directory
56
  local_dir = snapshot_download(repo_id=REPO_ID, token=token, local_dir="weights")
57
 
58
- # 2. Load Models from the downloaded folders
59
  for name, cfg in MODEL_SETTINGS.items():
60
  ckpt_path = os.path.join(local_dir, cfg["subfolder"])
61
- inf_model = cfg["inference_model"]
62
- pretrained = cfg["pretrained_path"]
63
- proc_class = cfg["processor"]
64
-
65
  print(f"Loading {name} from {ckpt_path}...")
66
- # from_pretrained handles .safetensors automatically
67
- model = inf_model.from_pretrained(ckpt_path).to(DEVICE)
68
 
 
 
 
 
69
  if name == "vit":
70
- i_proc = proc_class[0].from_pretrained(pretrained[0])
71
- t_proc = proc_class[1].from_pretrained(pretrained[1])
72
  processor = (i_proc, t_proc)
73
  else:
74
- processor = proc_class.from_pretrained(pretrained)
75
 
76
  MODELS[name] = {"model": model, "processor": processor}
77
-
78
- SEARCH_MODEL = SentenceTransformer('clip-ViT-B-32')
79
- print("Ensemble is live!")
80
-
81
- async def run_inference(m_name, image, temp, top_k, top_p):
82
- # This runs in a separate thread to avoid blocking the event loop
83
- return await asyncio.to_thread(_generate_sync, m_name, image, temp, top_k, top_p)
84
 
 
85
  def _generate_sync(m_name, image, temp, top_k, top_p):
86
  m_data = MODELS[m_name]
87
  model = m_data["model"]
88
 
89
  if m_name == "vit":
90
  i_proc, t_proc = m_data["processor"]
91
- pixel_values = i_proc(images=image, return_tensors="pt").pixel_values.to(DEVICE)
92
  gen_ids = model.generate(
93
- pixel_values=pixel_values, max_length=300, do_sample=True,
94
  temperature=temp, top_k=top_k, top_p=top_p
95
  )
96
  return t_proc.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
97
  else:
98
  proc = m_data["processor"]
99
- pixel_values = proc(images=image, return_tensors="pt").pixel_values.to(DEVICE)
100
  gen_ids = model.generate(
101
- pixel_values=pixel_values, max_length=300, do_sample=True,
102
  temperature=temp, top_k=top_k, top_p=top_p
103
  )
104
  return proc.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
@@ -111,47 +91,33 @@ async def generate_endpoint(
111
  top_p: float = Query(0.9)
112
  ):
113
  image = Image.open(file.file).convert("RGB")
114
- available = list(MODELS.keys())
 
115
  model_selection = random.choices(available, k=5)
116
 
117
- # Create tasks for parallel execution
118
- tasks = [run_inference(m, image, temp, top_k, top_p) for m in model_selection]
119
  captions = await asyncio.gather(*tasks)
120
 
121
  return {"captions": captions, "mix": model_selection}
122
 
123
-
124
  @app.post("/ui-tester")
125
  async def ui_tester(file: UploadFile = File(...), description: str = Query(...)):
126
- """Matches a user description against an image using CLIP embeddings."""
127
  image = Image.open(file.file).convert("RGB")
 
128
 
129
- img_emb = SEARCH_MODEL.encode(image)
130
- txt_emb = SEARCH_MODEL.encode(description)
131
 
132
- # Calculate cosine similarity
133
- score = util.cos_sim(img_emb, txt_emb).item()
 
 
 
 
 
134
 
135
  return {
136
  "match_score": round(score, 4),
137
- "is_match": score > 0.25, # Threshold can be adjusted
138
- "status": "High correlation" if score > 0.3 else "Low correlation"
139
- }
140
-
141
- @app.get("/ui-search")
142
- async def ui_search(description: str = Query(...)):
143
- """Returns top image matches from a gallery based on a text description."""
144
- if not IMAGE_GALLERY_EMBEDDINGS:
145
- return {"error": "Gallery not initialized"}
146
-
147
- query_emb = SEARCH_MODEL.encode(description)
148
- hits = util.semantic_search(query_emb, IMAGE_GALLERY_EMBEDDINGS, top_k=3)
149
-
150
- results = []
151
- for hit in hits[0]:
152
- results.append({
153
- "image_path": IMAGE_PATHS[hit['corpus_id']],
154
- "score": round(hit['score'], 4)
155
- })
156
-
157
- return {"results": results}
 
1
  import os
 
2
  import torch
3
  import random
4
+ import asyncio
5
  from PIL import Image
6
  from fastapi import FastAPI, UploadFile, File, Query
7
  from fastapi.middleware.cors import CORSMiddleware
 
10
  BlipProcessor, BlipForConditionalGeneration,
11
  ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
12
  )
 
13
 
14
  app = FastAPI()
15
 
 
17
  REPO_ID = "SaniaE/Image_Captioning_Ensemble"
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
  MODELS = {}
 
20
 
21
+ # Removed GIT, kept BLIP and ViT
22
  MODEL_SETTINGS = {
23
  "blip": {
24
  "subfolder": "blip",
 
31
  "processor": [ViTImageProcessor, AutoProcessor],
32
  "pretrained_path": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"],
33
  "inference_model": AutoModelForCausalLM
 
 
 
 
 
 
34
  }
35
  }
36
 
37
  @app.on_event("startup")
38
  async def startup_event():
39
+ global MODELS
40
+ token = os.getenv("HF_TOKEN")
41
+ if token: login(token=token)
42
 
43
+ print(f"Downloading models from {REPO_ID}...")
 
 
 
 
 
 
44
  local_dir = snapshot_download(repo_id=REPO_ID, token=token, local_dir="weights")
45
 
 
46
  for name, cfg in MODEL_SETTINGS.items():
47
  ckpt_path = os.path.join(local_dir, cfg["subfolder"])
 
 
 
 
48
  print(f"Loading {name} from {ckpt_path}...")
 
 
49
 
50
+ # Load Model
51
+ model = cfg["inference_model"].from_pretrained(ckpt_path).to(DEVICE)
52
+
53
+ # Load Processor
54
  if name == "vit":
55
+ i_proc = cfg["processor"][0].from_pretrained(cfg["pretrained_path"][0])
56
+ t_proc = cfg["processor"][1].from_pretrained(cfg["pretrained_path"][1])
57
  processor = (i_proc, t_proc)
58
  else:
59
+ processor = cfg["processor"].from_pretrained(cfg["pretrained_path"])
60
 
61
  MODELS[name] = {"model": model, "processor": processor}
62
+ print("Optimization Complete: GIT and Search removed. Ensemble is live!")
 
 
 
 
 
 
63
 
64
+ # --- Helper for Parallel Inference ---
65
  def _generate_sync(m_name, image, temp, top_k, top_p):
66
  m_data = MODELS[m_name]
67
  model = m_data["model"]
68
 
69
  if m_name == "vit":
70
  i_proc, t_proc = m_data["processor"]
71
+ inputs = i_proc(images=image, return_tensors="pt").to(DEVICE)
72
  gen_ids = model.generate(
73
+ **inputs, max_length=300, do_sample=True,
74
  temperature=temp, top_k=top_k, top_p=top_p
75
  )
76
  return t_proc.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
77
  else:
78
  proc = m_data["processor"]
79
+ inputs = proc(images=image, return_tensors="pt").to(DEVICE)
80
  gen_ids = model.generate(
81
+ **inputs, max_length=300, do_sample=True,
82
  temperature=temp, top_k=top_k, top_p=top_p
83
  )
84
  return proc.batch_decode(gen_ids, skip_special_tokens=True)[0].strip()
 
91
  top_p: float = Query(0.9)
92
  ):
93
  image = Image.open(file.file).convert("RGB")
94
+ available = list(MODELS.keys()) # Only blip and vit
95
+ # Create 5 slots from the 2 remaining models
96
  model_selection = random.choices(available, k=5)
97
 
98
+ tasks = [asyncio.to_thread(_generate_sync, m, image, temp, top_k, top_p) for m in model_selection]
 
99
  captions = await asyncio.gather(*tasks)
100
 
101
  return {"captions": captions, "mix": model_selection}
102
 
 
103
  @app.post("/ui-tester")
104
  async def ui_tester(file: UploadFile = File(...), description: str = Query(...)):
105
+ """Uses BLIP's native capability to score the match between image and text."""
106
  image = Image.open(file.file).convert("RGB")
107
+ blip_data = MODELS["blip"]
108
 
109
+ # We use the processor to prepare both image and text for the model
110
+ inputs = blip_data["processor"](images=image, text=description, return_tensors="pt").to(DEVICE)
111
 
112
+ with torch.no_grad():
113
+ # BLIP models have a built-in vision/text matching logic
114
+ # For simple captioning models, we can use the model's loss or log-likelihood
115
+ outputs = blip_data["model"](**inputs, labels=inputs["input_ids"])
116
+ # We convert the loss to a pseudo-similarity score (lower loss = higher match)
117
+ loss = outputs.loss.item()
118
+ score = 1 / (1 + loss) # Normalized 0 to 1
119
 
120
  return {
121
  "match_score": round(score, 4),
122
+ "status": "High match" if score > 0.4 else "Low match"
123
+ }