aliSaac510 commited on
Commit
c5cba2b
·
verified ·
1 Parent(s): 11b1db9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +22 -10
main.py CHANGED
@@ -1,21 +1,24 @@
1
  from fastapi import FastAPI, HTTPException
2
  from sentence_transformers import SentenceTransformer
 
 
3
  from PIL import Image
4
  import requests
5
  from io import BytesIO
6
  import uvicorn
7
 
8
- app = FastAPI(title="Image Embedding API (CLIP)")
9
 
10
  # Load Models
11
  print("Loading Models... please wait.")
12
 
13
- # 1. Image Model: DINOv2 (768 dim)
14
- img_model_name = 'facebook/dinov2-base'
15
- img_model = SentenceTransformer(img_model_name)
 
 
16
 
17
- # 2. Text Model: Qwen (Choice: 1.5B or 0.6B for speed/memory)
18
- # Much stronger than E5, works great on CPU
19
  text_model_name = 'Alibaba-NLP/gte-Qwen2-1.5b-instruct'
20
  text_model = SentenceTransformer(text_model_name, trust_remote_code=True)
21
 
@@ -26,7 +29,7 @@ def home():
26
  return {
27
  "status": "online",
28
  "models": {
29
- "image": img_model_name,
30
  "text": text_model_name
31
  }
32
  }
@@ -36,7 +39,16 @@ async def embed_image(image_url: str):
36
  try:
37
  response = requests.get(image_url, timeout=10)
38
  img = Image.open(BytesIO(response.content)).convert("RGB")
39
- embedding = img_model.encode(img).tolist()
 
 
 
 
 
 
 
 
 
40
  return {"success": True, "dimension": len(embedding), "embedding": embedding}
41
  except Exception as e:
42
  raise HTTPException(status_code=400, detail=str(e))
@@ -44,7 +56,7 @@ async def embed_image(image_url: str):
44
  @app.post("/embed/text")
45
  async def embed_text(text: str):
46
  try:
47
- # E5 model requires 'query: ' prefix for similarity tasks
48
  processed_text = f"query: {text}"
49
  embedding = text_model.encode(processed_text).tolist()
50
  return {"success": True, "dimension": len(embedding), "embedding": embedding}
@@ -52,4 +64,4 @@ async def embed_text(text: str):
52
  raise HTTPException(status_code=400, detail=str(e))
53
 
54
  if __name__ == "__main__":
55
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  from fastapi import FastAPI, HTTPException
2
  from sentence_transformers import SentenceTransformer
3
+ from transformers import AutoImageProcessor, AutoModel
4
+ import torch
5
  from PIL import Image
6
  import requests
7
  from io import BytesIO
8
  import uvicorn
9
 
10
+ app = FastAPI(title="Movie Linker AI API")
11
 
12
  # Load Models
13
  print("Loading Models... please wait.")
14
 
15
+ # 1. Image Model: DINOv2 (using transformers directly for stability)
16
+ img_model_id = 'facebook/dinov2-base'
17
+ img_processor = AutoImageProcessor.from_pretrained(img_model_id)
18
+ img_model = AutoModel.from_pretrained(img_model_id)
19
+ img_model.eval() # Set to evaluation mode
20
 
21
+ # 2. Text Model: Qwen (Choice: 1.5B or 0.6B)
 
22
  text_model_name = 'Alibaba-NLP/gte-Qwen2-1.5b-instruct'
23
  text_model = SentenceTransformer(text_model_name, trust_remote_code=True)
24
 
 
29
  return {
30
  "status": "online",
31
  "models": {
32
+ "image": img_model_id,
33
  "text": text_model_name
34
  }
35
  }
 
39
  try:
40
  response = requests.get(image_url, timeout=10)
41
  img = Image.open(BytesIO(response.content)).convert("RGB")
42
+
43
+ # Process image for DINOv2
44
+ inputs = img_processor(images=img, return_tensors="pt")
45
+
46
+ with torch.no_grad():
47
+ outputs = img_model(**inputs)
48
+ # DINOv2 uses the CLS token (first token) for the global representation
49
+ # This is available in last_hidden_state[:, 0, :]
50
+ embedding = outputs.last_hidden_state[:, 0, :].squeeze().tolist()
51
+
52
  return {"success": True, "dimension": len(embedding), "embedding": embedding}
53
  except Exception as e:
54
  raise HTTPException(status_code=400, detail=str(e))
 
56
  @app.post("/embed/text")
57
  async def embed_text(text: str):
58
  try:
59
+ # Instruction-tuned models like Qwen work best with prompts
60
  processed_text = f"query: {text}"
61
  embedding = text_model.encode(processed_text).tolist()
62
  return {"success": True, "dimension": len(embedding), "embedding": embedding}
 
64
  raise HTTPException(status_code=400, detail=str(e))
65
 
66
  if __name__ == "__main__":
67
+ uvicorn.run(app, host="0.0.0.0", port=7860)