Spaces:
Runtime error
Runtime error
yahiab
commited on
Commit
Β·
fb8456d
1
Parent(s):
f311e6e
fix
Browse files
app.py
CHANGED
|
@@ -15,22 +15,23 @@ MODEL_LIST = {
|
|
| 15 |
# Global variables
|
| 16 |
current_model = None
|
| 17 |
current_preprocessor = None
|
|
|
|
| 18 |
|
| 19 |
# Load model and preprocessor
|
| 20 |
def load_model_and_preprocessor(model_name):
|
| 21 |
"""Load model and preprocessor for a given model name."""
|
| 22 |
global current_model, current_preprocessor
|
| 23 |
-
print(f"Loading model and preprocessor for: {model_name}")
|
| 24 |
-
current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).
|
| 25 |
current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
|
| 26 |
-
return f"Model {model_name} loaded successfully."
|
| 27 |
|
| 28 |
# Predict function
|
| 29 |
def predict(image, model, preprocessor):
|
| 30 |
"""Make a prediction on the given image patch using the loaded model."""
|
| 31 |
if model is None or preprocessor is None:
|
| 32 |
raise ValueError("Model and preprocessor are not loaded.")
|
| 33 |
-
inputs = preprocessor(images=image, return_tensors="pt").to(
|
| 34 |
with torch.no_grad():
|
| 35 |
outputs = model(**inputs)
|
| 36 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
|
|
|
| 15 |
# Global variables
|
| 16 |
current_model = None
|
| 17 |
current_preprocessor = None
|
| 18 |
+
device = "cuda" if torch.cuda.is_available() else "cpu" # Dynamically set device
|
| 19 |
|
| 20 |
# Load model and preprocessor
|
| 21 |
def load_model_and_preprocessor(model_name):
|
| 22 |
"""Load model and preprocessor for a given model name."""
|
| 23 |
global current_model, current_preprocessor
|
| 24 |
+
print(f"Loading model and preprocessor for: {model_name} on {device}")
|
| 25 |
+
current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).to(device).eval()
|
| 26 |
current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
|
| 27 |
+
return f"Model {model_name} loaded successfully on {device}."
|
| 28 |
|
| 29 |
# Predict function
|
| 30 |
def predict(image, model, preprocessor):
|
| 31 |
"""Make a prediction on the given image patch using the loaded model."""
|
| 32 |
if model is None or preprocessor is None:
|
| 33 |
raise ValueError("Model and preprocessor are not loaded.")
|
| 34 |
+
inputs = preprocessor(images=image, return_tensors="pt").to(device)
|
| 35 |
with torch.no_grad():
|
| 36 |
outputs = model(**inputs)
|
| 37 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|