omar100abdelaal commited on
Commit
9db5a93
·
verified ·
1 Parent(s): 859c5c8

Upload ai_service.py

Browse files
Files changed (1) hide show
  1. ai_service.py +111 -0
ai_service.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
+ from typing import Optional
3
+ from sentence_transformers import SentenceTransformer
4
+ from transformers import CLIPProcessor, CLIPModel
5
+ from PIL import Image
6
+ import torch
7
+ import io
8
+
9
+ app = FastAPI(title="AI Embedding Service")
10
+
11
+ class ModelLoader:
12
+ def __init__(self):
13
+ self._text_model = None
14
+ self._clip_model = None
15
+ self._clip_processor = None
16
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ @property
19
+ def text_model(self):
20
+ if self._text_model is None:
21
+ print("Loading text model (lazy initialization)...")
22
+ self._text_model = SentenceTransformer("BAAI/bge-large-en")
23
+ return self._text_model
24
+
25
+ @property
26
+ def clip_model(self):
27
+ if self._clip_model is None:
28
+ print("Loading image model (lazy initialization)...")
29
+ # Load in fp16 to save memory, especially for Hugging Face Spaces
30
+ self._clip_model = CLIPModel.from_pretrained(
31
+ "openai/clip-vit-large-patch14",
32
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
33
+ ).to(self.device)
34
+ return self._clip_model
35
+
36
+ @property
37
+ def clip_processor(self):
38
+ if self._clip_processor is None:
39
+ self._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
40
+ return self._clip_processor
41
+
42
+ models = ModelLoader()
43
+
44
+ @app.get("/health")
45
+ async def health():
46
+ return {
47
+ "status": "healthy",
48
+ "text_model_loaded": models._text_model is not None,
49
+ "image_model_loaded": models._clip_model is not None,
50
+ "device": models.device
51
+ }
52
+
53
+ @app.post("/embed")
54
+ async def embed(
55
+ property_name: Optional[str] = Form(None),
56
+ description: Optional[str] = Form(None),
57
+ images: Optional[list[UploadFile]] = File(None)
58
+ ):
59
+ response_data = {}
60
+
61
+ # Process Property Name
62
+ if property_name and property_name.strip():
63
+ vec_name = models.text_model.encode(property_name, normalize_embeddings=True)
64
+ response_data["property_name_vector"] = vec_name.tolist()
65
+
66
+ # Process Description
67
+ if description and description.strip():
68
+ vec_desc = models.text_model.encode(description, normalize_embeddings=True)
69
+ response_data["description_vector"] = vec_desc.tolist()
70
+
71
+ # Process Multiple Images
72
+ if images:
73
+ image_vectors = []
74
+ for image in images:
75
+ if not image.filename:
76
+ continue
77
+
78
+ contents = await image.read()
79
+ img = Image.open(io.BytesIO(contents)).convert("RGB")
80
+
81
+ inputs = models.clip_processor(images=img, return_tensors="pt").to(models.device)
82
+
83
+ with torch.no_grad():
84
+ outputs = models.clip_model.get_image_features(**inputs)
85
+
86
+ # Extract tensor depending on transformers output format
87
+ if isinstance(outputs, torch.Tensor):
88
+ features = outputs
89
+ elif hasattr(outputs, "image_embeds") and outputs.image_embeds is not None:
90
+ features = outputs.image_embeds
91
+ elif hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
92
+ features = outputs.pooler_output
93
+ else:
94
+ features = outputs[0]
95
+
96
+ # Apply L2 Normalization for Cosine Similarity
97
+ normalized_features = torch.nn.functional.normalize(features, p=2, dim=-1)
98
+ vec_img = normalized_features.squeeze().tolist()
99
+ image_vectors.append(vec_img)
100
+
101
+ if image_vectors:
102
+ response_data["image_vectors"] = image_vectors
103
+
104
+ if not response_data:
105
+ raise HTTPException(status_code=400, detail="Must provide at least one of property_name, description, or images")
106
+
107
+ return response_data
108
+
109
+ if __name__ == "__main__":
110
+ import uvicorn
111
+ uvicorn.run(app, host="0.0.0.0", port=8000)