| import tempfile | |
| import requests | |
| from src.bioclip.predict import TreeOfLifeClassifier, Rank | |
| import logging | |
| class PredictService: | |
| def __init__(self): | |
| self.classifier = TreeOfLifeClassifier() | |
| log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" | |
| logging.basicConfig(level=logging.INFO, format=log_format) | |
| self.logger = logging.getLogger() | |
| def download_image(self, url): | |
| self.logger.info(f'download_image({url})') | |
| response = requests.get(url) | |
| # Vérifier si la requête a réussi | |
| if response.status_code == 200: | |
| # Créer un fichier temporaire | |
| temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') | |
| # Écrire le contenu de l'image dans le fichier temporaire | |
| temp_file.write(response.content) | |
| temp_file.close() | |
| # Retourner le chemin du fichier temporaire | |
| return temp_file.name | |
| else: | |
| raise Exception("Error while downloading image. Status: {}".format(response.status_code)) | |
| def predict(self, image_url=None): | |
| if image_url is None: | |
| raise Exception("expect image url") | |
| image_path = self.download_image(image_url) | |
| predictions = self.classifier.predict(image_path, Rank.SPECIES) | |
| for prediction in predictions: | |
| if 'file_name' in prediction: | |
| del prediction['file_name'] | |
| return predictions | |