sohamnk commited on
Commit
276d537
·
verified ·
1 Parent(s): cb8fd55
Files changed (1) hide show
  1. pipeline/__init__.py +51 -0
pipeline/__init__.py CHANGED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from flask import Flask
4
+
5
+ FEATURE_WEIGHTS = {"shape": 0.4, "color": 0.5, "texture": 0.1}
6
+ FINAL_SCORE_THRESHOLD = 0.5
7
+
8
+ # create flask app
9
+ app = Flask(__name__)
10
+
11
+ # load models
12
+ print("="*50)
13
+ print("🚀 Initializing application and loading models...")
14
+ device_name = os.environ.get("device", "cpu")
15
+ device = torch.device('cuda' if 'cuda' in device_name and torch.cuda.is_available() else 'cpu')
16
+ print(f"🧠 Using device: {device}")
17
+
18
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, AutoTokenizer, AutoModel
19
+ from segment_anything import SamPredictor, sam_model_registry
20
+
21
+ print("...Loading Grounding DINO model...")
22
+ gnd_model_id = "IDEA-Research/grounding-dino-tiny"
23
+ processor_gnd = AutoProcessor.from_pretrained(gnd_model_id)
24
+ model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
25
+
26
+ print("...Loading Segment Anything (SAM) model...")
27
+ # IMPORTANT: The path is now relative to the root of the project
28
+ sam_checkpoint = "sam_vit_b_01ec64.pth"
29
+ sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
30
+ predictor = SamPredictor(sam_model)
31
+
32
+ print("...Loading BGE model for text embeddings...")
33
+ bge_model_id = "BAAI/bge-small-en-v1.5"
34
+ tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
35
+ model_text = AutoModel.from_pretrained(bge_model_id).to(device)
36
+
37
+ # Store models in a dictionary to pass to logic functions
38
+ models = {
39
+ "processor_gnd": processor_gnd,
40
+ "model_gnd": model_gnd,
41
+ "predictor": predictor,
42
+ "tokenizer_text": tokenizer_text,
43
+ "model_text": model_text,
44
+ "device": device
45
+ }
46
+
47
+ print("✅ All models loaded successfully.")
48
+ print("="*50)
49
+
50
+ # Import routes after app and models are defined to avoid circular imports
51
+ from pipeline import routes