fix: adicionar validação antes de chamar classify
Browse files- Verificação do modelo selecionado antes de executar classificação
- Evita erro quando categorias são selecionadas acidentalmente
- Adicionado logo SVG original para o projeto
---
fix: add validation before calling classify
- Check selected model before running classification
- Prevents error when categories are accidentally selected
- Added original SVG logo for the project
- app.py +13 -8
- image/logo.svg +43 -0
app.py
CHANGED
|
@@ -94,6 +94,9 @@ MODELS = [
|
|
| 94 |
]
|
| 95 |
|
| 96 |
def classify(image, model):
|
|
|
|
|
|
|
|
|
|
| 97 |
model_name = model.replace(" << new >>", "")
|
| 98 |
classifier = pipeline("image-classification", model=model_name)
|
| 99 |
result= classifier(image)
|
|
@@ -150,14 +153,16 @@ def main():
|
|
| 150 |
image_to_classify = Image.open(input_image)
|
| 151 |
st.image(image_to_classify, caption="Uploaded Image")
|
| 152 |
if st.button("Classify"):
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
| 161 |
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|
|
|
|
| 94 |
]
|
| 95 |
|
| 96 |
def classify(image, model):
|
| 97 |
+
if model.startswith("--") or model.startswith("#"):
|
| 98 |
+
st.warning("Please select a valid model from the list")
|
| 99 |
+
return []
|
| 100 |
model_name = model.replace(" << new >>", "")
|
| 101 |
classifier = pipeline("image-classification", model=model_name)
|
| 102 |
result= classifier(image)
|
|
|
|
| 153 |
image_to_classify = Image.open(input_image)
|
| 154 |
st.image(image_to_classify, caption="Uploaded Image")
|
| 155 |
if st.button("Classify"):
|
| 156 |
+
if shosen_model.startswith("--") or shosen_model.startswith("#"):
|
| 157 |
+
st.warning("Please select a valid model from the list")
|
| 158 |
+
else:
|
| 159 |
+
image_to_classify = Image.open(input_image)
|
| 160 |
+
classification_obj1=[]
|
| 161 |
+
|
| 162 |
+
classification_result = classify(image_to_classify, shosen_model)
|
| 163 |
+
classification_obj1.append(classification_result)
|
| 164 |
+
print_result(classification_result)
|
| 165 |
+
save_result(classification_result)
|
| 166 |
|
| 167 |
|
| 168 |
if __name__ == "__main__":
|
image/logo.svg
ADDED
|
|