Sanket17 commited on
Commit
426b1bc
·
verified ·
1 Parent(s): 1f282bf

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +41 -27
models.py CHANGED
@@ -1,27 +1,41 @@
1
- from transformers import AutoProcessor, AutoModelForCausalLM
2
- import torch
3
- from ultralytics import YOLO
4
-
5
- def load_models(device='cpu'):
6
- """Initialize and load all required models."""
7
- # Set default dtype for torch
8
- torch.set_default_dtype(torch.float32)
9
-
10
- yolo_model = YOLO('best.pt').to(device)
11
-
12
- processor = AutoProcessor.from_pretrained(
13
- "microsoft/Florence-2-base",
14
- trust_remote_code=True
15
- )
16
-
17
- caption_model = AutoModelForCausalLM.from_pretrained(
18
- "microsoft/OmniParser/icon_caption_florence",
19
- torch_dtype=torch.float32, # Changed from float16 to float32
20
- trust_remote_code=True
21
- ).to(device)
22
-
23
- return {
24
- 'yolo_model': yolo_model,
25
- 'processor': processor,
26
- 'caption_model': caption_model
27
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ from ultralytics import YOLO
4
+ import gdown
5
+ import os
6
+ from safetensors.torch import load_file # Safetensors loading method
7
+
8
def download_model_from_drive(file_id, destination_path):
    """Fetch a single file from Google Drive and save it locally.

    Args:
        file_id: Google Drive file identifier (the ``id=`` query value).
        destination_path: local filesystem path the download is written to.
    """
    drive_url = "https://drive.google.com/uc?id=" + file_id
    gdown.download(drive_url, destination_path, quiet=False)
12
+
13
def load_models(device='cpu'):
    """Initialize and load all required models.

    Downloads the fine-tuned caption weights from Google Drive on first run,
    then loads the YOLO detector, the Florence-2 processor, and the Florence-2
    caption model with those weights applied.

    Args:
        device: torch device string (e.g. ``'cpu'`` or ``'cuda'``) that the
            models are moved to.

    Returns:
        dict with keys ``'yolo_model'``, ``'processor'``, ``'caption_model'``.
    """
    # Set default dtype for torch
    torch.set_default_dtype(torch.float32)

    # Download the fine-tuned caption weights (if not already present locally)
    model_file_path = 'model.safetensors'
    if not os.path.exists(model_file_path):
        file_id = "1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k"
        download_model_from_drive(file_id, model_file_path)

    # Load the YOLO model
    yolo_model = YOLO('best.pt').to(device)

    # Load processor for Florence-2
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base",
        trust_remote_code=True
    )

    # BUGFIX: safetensors.torch.load_file has no ``framework=`` kwarg (that
    # belongs to safe_open), so the old call raised TypeError; and load_file
    # returns a plain state dict (tensor mapping), NOT a model object, so it
    # must not be returned as the caption model directly. Instead, build the
    # Florence-2 architecture and overlay the downloaded fine-tuned weights.
    caption_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base",
        torch_dtype=torch.float32,
        trust_remote_code=True
    )
    state_dict = load_file(model_file_path, device=str(device))
    # strict=False tolerates harmless key mismatches (e.g. tied/buffer keys)
    # between the checkpoint and the freshly built architecture.
    caption_model.load_state_dict(state_dict, strict=False)
    caption_model = caption_model.to(device)

    return {
        'yolo_model': yolo_model,
        'processor': processor,
        'caption_model': caption_model
    }