Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from pathlib import Path | |
| from ultralytics import YOLO | |
| import io | |
| import base64 | |
| import uuid | |
| import glob | |
| from tensorflow import keras | |
| from flask import Flask, jsonify, request, render_template, send_file | |
| import torch | |
| from collections import Counter | |
| import psutil | |
| from gradio_client import Client, handle_file | |
| from io import BytesIO | |
# Silence TensorFlow's C++ logging.
# NOTE(review): tensorflow is imported above this line, so this setting likely
# comes too late to suppress messages TensorFlow logs while importing; set it
# before `from tensorflow import keras` for it to take full effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# How the YOLO weights are obtained: 'local', 'remote_hub_download' or
# 'remote_hub_from_pretrained'.
load_type = 'local'
MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
MODEL_DIR = "./artifacts/models"
YOLO_DIR = "./artifacts/yolo"  # YOLO crop images are written under this folder
# Base URL of the remote CLIP Gradio app. The default is a gradio.live share
# link, which Gradio only keeps alive temporarily, so it can be replaced via the
# GRADIO_URL environment variable without editing the code.
GRADIO_URL = os.environ.get("GRADIO_URL", "https://a0c594662477a008f4.gradio.live/")
# Load the YOLO detector into memory once, at import time; the request handlers
# use the module-level `model`.
# NOTE(review): REPO_ID was referenced but never defined in this file, so the Hub
# branches now read it from the REPO_ID environment variable — confirm.
def _hub_repo_id():
    """Return the Hugging Face Hub repo id from $REPO_ID, raising ValueError if unset."""
    repo_id = os.environ.get('REPO_ID')
    if not repo_id:
        raise ValueError(
            "Set the REPO_ID environment variable to load the model from the Hugging Face Hub"
        )
    return repo_id


if load_type == 'local':
    # Local model path
    model_path = f'{MODEL_DIR}/{MODEL_YOLO}'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = YOLO(model_path)
    print("***** FLASK API---LOAD YOLO MODEL DONE *****")
elif load_type == 'remote_hub_download':
    from huggingface_hub import hf_hub_download
    # Download the weights file from the Hugging Face Hub (default HF cache).
    model_path = hf_hub_download(repo_id=_hub_repo_id(), filename=MODEL_YOLO)
    # Wrap in YOLO exactly like the local branch: torch.load would only return the
    # raw checkpoint, which cannot be called for inference and has no `.names`.
    model = YOLO(model_path)
elif load_type == 'remote_hub_from_pretrained':
    # huggingface_hub exposes no top-level `from_pretrained` function (that import
    # always failed); download the weights into MODEL_DIR and wrap them in YOLO.
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(
        repo_id=_hub_repo_id(), filename=MODEL_YOLO, cache_dir=MODEL_DIR
    )
    model = YOLO(model_path)
else:
    raise ValueError(f"Unknown load_type: {load_type!r}")
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes as a Base64 text string.

    The whole file is read in binary mode; the Base64 output is decoded to a
    ``str`` so it can be placed directly into a JSON response.
    """
    with open(image_path, "rb") as fh:
        payload = fh.read()
    return base64.b64encode(payload).decode('utf-8')
def get_jpg_files(path):
    """Return the paths of all ``*.jpg`` files directly inside *path*.

    Args:
        path: Directory to search (not recursive).

    Returns:
        A sorted list of matching file paths, each prefixed with *path*; an
        empty list when the directory is missing or contains no ``.jpg`` files.

    Example:
        jpg_files = get_jpg_files('/content/drive/MyDrive/chiikawa')
    """
    # Escape the directory part so characters such as '[', '*' or '?' in it are
    # matched literally rather than interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(os.fspath(path)), "*.jpg")
    # glob's order is filesystem-dependent; sort so callers get a stable order.
    return sorted(glob.glob(pattern))
# Gradio clients keyed by URL. Constructing a Client connects to the server to
# fetch its API description, so one client per URL is reused instead of
# reconnecting on every call (predict() calls clip_model once per crop).
_gradio_clients = {}


def _get_gradio_client(url):
    """Return the cached gradio ``Client`` for *url*, creating it on first use."""
    client = _gradio_clients.get(url)
    if client is None:
        client = Client(url)
        _gradio_clients[url] = client
    return client


def clip_model(choice="find_similar_words", image=None, word=None):
    """Call the remote CLIP Gradio app's ``/run_function`` endpoint.

    Args:
        choice: Operation name forwarded to the remote app.
        image: Optional local image path; uploaded via ``handle_file`` when given.
        word: Optional text forwarded unchanged.

    Returns:
        The value returned by ``Client.predict`` (callers in this file index it
        as a sequence). If ``predict`` raises, a ``str`` beginning with
        "Error occurred while processing the request:" is returned instead, so
        callers must check for a ``str`` before indexing. Errors while creating
        the client are not caught and propagate.
    """
    client = _get_gradio_client(GRADIO_URL)
    # Only upload a file when an image path was actually supplied.
    image_input = handle_file(image) if image is not None else None
    try:
        clip_result = client.predict(
            choice=choice,
            image=image_input,
            word=word,
            top_k=3,
            api_name="/run_function",
        )
    except Exception as e:
        return f"Error occurred while processing the request: {e}"
    return clip_result
def check_memory_usage():
    """Print total, available and used system memory in MB, plus the usage percent."""
    vm = psutil.virtual_memory()
    bytes_per_mb = 1024 * 1024
    total_mb, avail_mb, used_mb = (
        amount / bytes_per_mb for amount in (vm.total, vm.available, vm.used)
    )
    report = (
        f"^^^^^^ Total Memory: {total_mb:.2f} MB ^^^^^^",
        f"^^^^^^ Available Memory: {avail_mb:.2f} MB ^^^^^^",
        f"^^^^^^ Used Memory: {used_mb:.2f} MB ^^^^^^",
        f"^^^^^^ Memory Usage (%): {vm.percent}% ^^^^^^",
    )
    for line in report:
        print(line)


# Print one memory snapshot at import time.
check_memory_usage()
# The Flask application object; passing this module's name lets Flask resolve
# the templates/ and static/ folders relative to this file.
app = Flask(import_name=__name__)
# API route for prediction (YOLO detection + CLIP description of each crop).
# NOTE(review): the view was never registered; path and method are inferred from
# the "/predict" log tags and the use of request.files/request.form — confirm.
@app.route('/predict', methods=['POST'])
def predict():
    """Detect objects in an uploaded image and describe each crop via CLIP.

    Expects a multipart/form-data POST with an ``image`` file and optional
    ``message_id``, ``choice`` and ``word`` form fields. ``choice`` and ``word``
    are forwarded to ``clip_model`` as-is; ``message_id`` is echoed back.

    Returns:
        ``(json, 200)`` with ``{"message_id": ..., "objects": [{"element": label,
        "images": {"encoded_image": <base64 crop>, "description_list": ...}}]}``;
        ``(json, 400)`` if the image is missing/unreadable or YOLO yields no
        results; ``(json, 502)`` if the CLIP service call fails.
    """
    import shutil

    # Check before indexing: request.files['image'] raises BadRequestKeyError for
    # a missing part, which made this JSON error response unreachable.
    if 'image' not in request.files:
        return jsonify({"error": "No image part"}), 400
    file = request.files['image']
    message_id = request.form.get('message_id')
    choice = request.form.get('choice')
    word = request.form.get('word')

    # Read the image. Image.open is lazy, so decode now: a corrupt upload is then
    # reported as a 400 instead of failing later inside the model call.
    try:
        image_data = Image.open(file)
        image_data.load()
    except Exception as e:
        return jsonify({'error': str(e)}), 400

    print("***** FLASK API---/predict Start YOLO predict *****")
    results = model(image_data)
    print("===== FLASK API---/predict YOLO predict result:", results, "=====")
    print("***** FLASK API---/predict YOLO predict DONE *****")
    check_memory_usage()

    # Make sure YOLO returned something usable.
    if results is None or len(results) == 0:
        return jsonify({'error': 'No results from YOLO model'}), 400

    # Crops go into a fresh folder per request. message_id is client-supplied
    # (a value like "../x" would escape YOLO_DIR), may be missing (all such
    # requests shared ".../None") and may repeat, so concurrent requests could
    # otherwise pick up each other's crops.
    work_dir = os.path.join(YOLO_DIR, uuid.uuid4().hex)

    # Accumulated over all results; they were previously reset per result, so
    # only the last result ever reached the response.
    element_list = []
    encoded_images = []
    top_k_words = []
    try:
        for result in results:
            # Crops are expected under <work_dir>/<label name>/ as JPG files.
            result.save_crop(work_dir)
            label_names = [model.names[int(label)] for label in result.boxes.cls]
            print(f"====== FLASK API---/predict 3. YOLO label_names: {label_names}======")
            # Visit each distinct label once, in first-seen order.
            for element in dict.fromkeys(label_names):
                yolo_path = os.path.join(work_dir, element)
                yolo_files = get_jpg_files(yolo_path)
                print(f"***** FLASK API---/predict 處理:{yolo_path} *****")
                if not yolo_files:
                    print(f" FLASK API---/predict 警告:{element} 沒有找到相關的 JPG 檔案")
                    continue
                for yolo_img in yolo_files:  # one cropped image per iteration
                    print("***** FLASK API---/predict 4. START CLIP *****")
                    clip_result = clip_model(choice, yolo_img, word)
                    # clip_model returns an error *string* on failure; indexing it
                    # would report the message's first character as a result.
                    if isinstance(clip_result, str):
                        return jsonify({'error': clip_result}), 502
                    # First element of the CLIP result (its top_k=3 predictions).
                    top_k_words.append(clip_result[0])
                    encoded_images.append(image_to_base64(yolo_img))
                    element_list.append(element)
                    print(f"===== FLASK API---/predict CLIP RESULT:{top_k_words} =====\n")
                    # Delete the crop now so a later result saving into the same
                    # label folder does not cause it to be processed twice.
                    print(f"===== FLASK API---/predict DELETE yolo_img:{yolo_img} =====\n")
                    os.remove(yolo_img)
    finally:
        # Remove the per-request folder even on errors or early returns.
        shutil.rmtree(work_dir, ignore_errors=True)

    # Build the response payload.
    response_data = {
        'message_id': message_id,
        'objects': [
            {
                'element': element,
                'images': {
                    'encoded_image': encoded_image,
                    'description_list': description_list,
                },
            }
            for element, encoded_image, description_list in zip(
                element_list, encoded_images, top_k_words
            )
        ],
    }
    return jsonify(response_data), 200
# API route for text-to-image lookup via the CLIP service.
# NOTE(review): the view was never registered; path and method are inferred from
# the "/text2img" log tag and the use of request.form — confirm.
@app.route('/text2img', methods=['POST'])
def text2img():
    """Ask the CLIP service for an image matching a text query.

    Reads the ``message_id``, ``choice`` and ``word`` form fields; ``choice`` and
    ``word`` are forwarded to ``clip_model`` unchanged.

    Returns:
        ``(json, 200)`` with ``message_id`` echoed back, ``encoded_image`` taken
        from element 2 of the CLIP result and ``description`` from element 0;
        ``(json, 502)`` with ``{"error": ...}`` when the CLIP call fails.
    """
    message_id = request.form.get('message_id')
    choice = request.form.get('choice')
    word = request.form.get('word')
    clip_result = clip_model(choice, None, word)
    print(f"===== FLASK API---/text2img 文字轉圖片result:{clip_result} =====")
    # clip_model returns an error *string* on failure; indexing it would place
    # single characters of the message into the response.
    if isinstance(clip_result, str):
        return jsonify({'error': clip_result}), 502
    # Build the response payload.
    response_data = {
        'message_id': message_id,
        # Expected to arrive already Base64-encoded from the CLIP service.
        'encoded_image': clip_result[2],
        'description': clip_result[0],
    }
    return jsonify(response_data), 200
# API route for the application version (previously never registered, so the
# documented curl call returned 404).
@app.route('/version')
def version():
    """Return the application version string.

    Demo usage: ``curl http://127.0.0.1:5000/version``
    """
    return '1.0'
# Landing page.
# NOTE(review): the view was never registered; the "/" path is assumed — confirm.
@app.route('/')
def hello_world():
    """Render the ``index.html`` template from Flask's template folder."""
    return render_template("index.html")
# Start Flask's built-in development server when run as a script.
if __name__ == '__main__':
    # Debug mode (reloader + interactive Werkzeug debugger) stays on by default
    # as before, but FLASK_DEBUG=0/false/no now turns it off. The debugger allows
    # arbitrary code execution from the browser, so disable it whenever the
    # server is reachable by others.
    debug = os.environ.get('FLASK_DEBUG', '1').lower() not in ('0', 'false', 'no')
    app.run(debug=debug)