Sanket17 commited on
Commit
f33b49d
·
verified ·
1 Parent(s): 114e5cf

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +12 -7
models.py CHANGED
@@ -15,24 +15,29 @@ def load_models(device='cpu'):
15
  # Set default dtype for torch
16
  torch.set_default_dtype(torch.float32)
17
 
18
- # Download the model from Google Drive (if not already present)
19
- model_file_path = 'model.safetensors' # Use the correct model file name
20
  if not os.path.exists(model_file_path):
21
  file_id = "1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k" # Replace with your Google Drive file ID
22
  download_model_from_drive(file_id, model_file_path)
23
-
24
  # Load the YOLO model
25
  yolo_model = YOLO('best.pt').to(device)
26
 
27
- # Load processor and caption model
28
  processor = AutoProcessor.from_pretrained(
29
  "microsoft/Florence-2-base",
30
  trust_remote_code=True
31
  )
32
 
33
- # Load the caption model from the downloaded .safetensors file
34
- # Use safetensors library to load the model
35
- caption_model = load_file(model_file_path, framework="pt", device=device)
 
 
 
 
 
36
 
37
  return {
38
  'yolo_model': yolo_model,
 
15
  # Set default dtype for torch
16
  torch.set_default_dtype(torch.float32)
17
 
18
+ # Download the caption model if not present
19
+ model_file_path = 'model.safetensors' # Adjust this name as needed
20
  if not os.path.exists(model_file_path):
21
  file_id = "1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k" # Replace with your Google Drive file ID
22
  download_model_from_drive(file_id, model_file_path)
23
+
24
  # Load the YOLO model
25
  yolo_model = YOLO('best.pt').to(device)
26
 
27
+ # Load processor for the caption model
28
  processor = AutoProcessor.from_pretrained(
29
  "microsoft/Florence-2-base",
30
  trust_remote_code=True
31
  )
32
 
33
+ # Load the caption model
34
+ model_state_dict = load_file(model_file_path) # Load tensors from .safetensors
35
+ caption_model = AutoModelForCausalLM.from_pretrained(
36
+ "microsoft/Florence-2-base",
37
+ trust_remote_code=True
38
+ )
39
+ caption_model.load_state_dict(model_state_dict) # Map tensors to the model
40
+ caption_model.to(device) # Move the model to the correct device
41
 
42
  return {
43
  'yolo_model': yolo_model,