Update models.py
Browse files
models.py
CHANGED
|
@@ -15,24 +15,29 @@ def load_models(device='cpu'):
|
|
| 15 |
# Set default dtype for torch
|
| 16 |
torch.set_default_dtype(torch.float32)
|
| 17 |
|
| 18 |
-
# Download the model
|
| 19 |
-
model_file_path = 'model.safetensors' #
|
| 20 |
if not os.path.exists(model_file_path):
|
| 21 |
file_id = "1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k" # Replace with your Google Drive file ID
|
| 22 |
download_model_from_drive(file_id, model_file_path)
|
| 23 |
-
|
| 24 |
# Load the YOLO model
|
| 25 |
yolo_model = YOLO('best.pt').to(device)
|
| 26 |
|
| 27 |
-
# Load processor
|
| 28 |
processor = AutoProcessor.from_pretrained(
|
| 29 |
"microsoft/Florence-2-base",
|
| 30 |
trust_remote_code=True
|
| 31 |
)
|
| 32 |
|
| 33 |
-
# Load the caption model
|
| 34 |
-
|
| 35 |
-
caption_model =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
return {
|
| 38 |
'yolo_model': yolo_model,
|
|
|
|
| 15 |
# Set default dtype for torch
|
| 16 |
torch.set_default_dtype(torch.float32)
|
| 17 |
|
| 18 |
+
# Download the caption model if not present
|
| 19 |
+
model_file_path = 'model.safetensors' # Adjust this name as needed
|
| 20 |
if not os.path.exists(model_file_path):
|
| 21 |
file_id = "1hUCqZ3X8mcM-KcwWFjcsFg7PA0hUvE3k" # Replace with your Google Drive file ID
|
| 22 |
download_model_from_drive(file_id, model_file_path)
|
| 23 |
+
|
| 24 |
# Load the YOLO model
|
| 25 |
yolo_model = YOLO('best.pt').to(device)
|
| 26 |
|
| 27 |
+
# Load processor for the caption model
|
| 28 |
processor = AutoProcessor.from_pretrained(
|
| 29 |
"microsoft/Florence-2-base",
|
| 30 |
trust_remote_code=True
|
| 31 |
)
|
| 32 |
|
| 33 |
+
# Load the caption model
|
| 34 |
+
model_state_dict = load_file(model_file_path) # Load tensors from .safetensors
|
| 35 |
+
caption_model = AutoModelForCausalLM.from_pretrained(
|
| 36 |
+
"microsoft/Florence-2-base",
|
| 37 |
+
trust_remote_code=True
|
| 38 |
+
)
|
| 39 |
+
caption_model.load_state_dict(model_state_dict) # Map tensors to the model
|
| 40 |
+
caption_model.to(device) # Move the model to the correct device
|
| 41 |
|
| 42 |
return {
|
| 43 |
'yolo_model': yolo_model,
|