aliSaac510 committed on
Commit
d22cebd
·
verified ·
1 Parent(s): af05991

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +55 -0
main.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from sentence_transformers import SentenceTransformer
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import uvicorn
7
+
8
# The service exposes an image endpoint and a text endpoint, so the OpenAPI
# title names both backing models (the image model below is DINOv2, not CLIP).
app = FastAPI(title="Image and Text Embedding API (DINOv2 + gte-Qwen2)")

# Both models are loaded once, at import time, so request handlers can reuse
# the already-initialised objects.
print("Loading Models... please wait.")

# 1. Image model: DINOv2 base.
# NOTE(review): the original author documents a 768-dim output; confirm that
# SentenceTransformer can load this vision checkpoint directly.
img_model_name = 'facebook/dinov2-base'
img_model = SentenceTransformer(img_model_name)

# 2. Text model: gte-Qwen2 instruct, 1.5B variant (the author chose it over E5
# and notes that a smaller variant could be swapped in for speed/memory).
# trust_remote_code=True allows code shipped with the checkpoint to run at
# load time, so only point this at a repository you trust.
text_model_name = 'Alibaba-NLP/gte-Qwen2-1.5b-instruct'
text_model = SentenceTransformer(text_model_name, trust_remote_code=True)

print("All models loaded successfully.")
23
+
24
@app.get("/")
def home():
    """Report that the service is up and which models it has loaded.

    Returns:
        A dict with a ``status`` string and a ``models`` mapping holding the
        configured image and text model names.
    """
    loaded_models = {"image": img_model_name, "text": text_model_name}
    return {"status": "online", "models": loaded_models}
33
+
34
@app.post("/embed/image")
async def embed_image(image_url: str):
    """Download an image from a URL and return its embedding.

    Args:
        image_url: URL of the image to fetch and embed.

    Returns:
        A dict with ``success`` (True), ``dimension`` (length of the
        embedding) and ``embedding`` (the vector as a plain list).

    Raises:
        HTTPException: status 400 carrying the underlying error message when
            the download, the image decoding, or the encoding fails.
    """
    import asyncio  # local import keeps this block self-contained

    # SECURITY NOTE(review): image_url is caller-controlled and fetched
    # server-side, so this endpoint can be used to reach internal hosts
    # (SSRF). Consider restricting schemes/hosts before exposing it publicly.

    def _fetch_and_encode():
        # Runs in a worker thread: both the HTTP download and the model
        # forward pass are blocking calls.
        response = requests.get(image_url, timeout=10)
        # Fail with a clear HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as raw_image:
            # convert() reads the pixel data and returns a new image, so the
            # source image can be closed as soon as the block exits.
            img = raw_image.convert("RGB")
        return img_model.encode(img).tolist()

    try:
        # Offload the blocking work so the event loop keeps serving requests.
        loop = asyncio.get_running_loop()
        embedding = await loop.run_in_executor(None, _fetch_and_encode)
    except Exception as e:
        # Every failure is still reported to the client as a 400, as before.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"success": True, "dimension": len(embedding), "embedding": embedding}
43
+
44
@app.post("/embed/text")
async def embed_text(text: str):
    """Return the embedding of a piece of text.

    Args:
        text: The text to embed.

    Returns:
        A dict with ``success`` (True), ``dimension`` (length of the
        embedding) and ``embedding`` (the vector as a plain list).

    Raises:
        HTTPException: status 400 carrying the underlying error message when
            encoding fails.
    """
    import asyncio  # local import keeps this block self-contained

    # NOTE(review): the 'query: ' prefix is the E5 convention, but the text
    # model configured in this file is gte-Qwen2. The prefix is kept so the
    # produced embeddings do not change; confirm whether the model's own
    # query prompt should be used instead.
    processed_text = f"query: {text}"
    try:
        # Model inference is blocking; run it in a worker thread so the event
        # loop keeps serving other requests.
        loop = asyncio.get_running_loop()
        vector = await loop.run_in_executor(None, text_model.encode, processed_text)
        embedding = vector.tolist()
    except Exception as e:
        # Every failure is still reported to the client as a 400, as before.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"success": True, "dimension": len(embedding), "embedding": embedding}
53
+
54
# Script entry point: start the ASGI server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # host "0.0.0.0" listens on all network interfaces.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )