Spaces:
Sleeping
Sleeping
Update api_server.py
Browse files- api_server.py +47 -19
api_server.py
CHANGED
|
@@ -15,6 +15,7 @@ import torch
|
|
| 15 |
from collections import Counter
|
| 16 |
import psutil
|
| 17 |
from gradio_client import Client, handle_file
|
|
|
|
| 18 |
|
| 19 |
# Disable tensorflow warnings
|
| 20 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
|
@@ -24,7 +25,8 @@ load_type = 'local'
|
|
| 24 |
MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
|
| 25 |
MODEL_DIR = "./artifacts/models"
|
| 26 |
YOLO_DIR = "./artifacts/yolo"
|
| 27 |
-
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
# Load the saved YOLO model into memory
|
|
@@ -36,7 +38,7 @@ if load_type == 'local':
|
|
| 36 |
|
| 37 |
model = YOLO(model_path)
|
| 38 |
|
| 39 |
-
print("*****
|
| 40 |
#model.eval() # 設定模型為推理模式
|
| 41 |
elif load_type == 'remote_hub_download':
|
| 42 |
from huggingface_hub import hf_hub_download
|
|
@@ -62,6 +64,18 @@ def image_to_base64(image_path):
|
|
| 62 |
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
| 63 |
return encoded_string
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# 抓取指定路徑下的所有 JPG 檔案
|
| 67 |
def get_jpg_files(path):
|
|
@@ -117,11 +131,11 @@ def predict():
|
|
| 117 |
except Exception as e:
|
| 118 |
return jsonify({'error': str(e)}), 400
|
| 119 |
|
| 120 |
-
print("*****
|
| 121 |
# Make a prediction using YOLO
|
| 122 |
results = model(image_data)
|
| 123 |
-
print ("===== YOLO predict result:",results,"=====")
|
| 124 |
-
print("***** YOLO predict DONE *****")
|
| 125 |
|
| 126 |
check_memory_usage()
|
| 127 |
|
|
@@ -145,7 +159,7 @@ def predict():
|
|
| 145 |
labels = result.boxes.cls # Get predicted label IDs
|
| 146 |
label_names = [model.names[int(label)] for label in labels] # Convert to names
|
| 147 |
|
| 148 |
-
print(f"====== 3. YOLO label_names: {label_names}======")
|
| 149 |
|
| 150 |
element_counts = Counter(label_names)
|
| 151 |
|
|
@@ -154,15 +168,15 @@ def predict():
|
|
| 154 |
yolo_path = f"{YOLO_DIR}/{message_id}/{element}"
|
| 155 |
yolo_file = get_jpg_files(yolo_path)
|
| 156 |
|
| 157 |
-
print(f"***** 處理:{yolo_path} *****")
|
| 158 |
|
| 159 |
if len(yolo_file) == 0:
|
| 160 |
-
print(f"警告:{element} 沒有找到相關的 JPG 檔案")
|
| 161 |
continue
|
| 162 |
|
| 163 |
for yolo_img in yolo_file: # 每張切圖yolo_img
|
| 164 |
-
print("***** 4. START CLIP *****")
|
| 165 |
-
client = Client(
|
| 166 |
clip_result = client.predict(
|
| 167 |
image=handle_file(yolo_img),
|
| 168 |
top_k=3,
|
|
@@ -171,7 +185,7 @@ def predict():
|
|
| 171 |
top_k_words.append(clip_result) # CLIP預測3個結果(top_k_words)
|
| 172 |
encoded_images.append(image_to_base64(yolo_img))
|
| 173 |
element_list.append(element)
|
| 174 |
-
print(f"===== CLIP result:{top_k_words} =====\n")
|
| 175 |
|
| 176 |
# 建立回應資料
|
| 177 |
response_data = {
|
|
@@ -193,14 +207,28 @@ def predict():
|
|
| 193 |
|
| 194 |
|
| 195 |
# API route for health check
|
| 196 |
-
@app.route('/
|
| 197 |
-
def
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
|
| 206 |
# API route for version
|
|
|
|
| 15 |
from collections import Counter
|
| 16 |
import psutil
|
| 17 |
from gradio_client import Client, handle_file
|
| 18 |
+
from io import BytesIO
|
| 19 |
|
| 20 |
# Disable tensorflow warnings
|
| 21 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
|
|
|
| 25 |
MODEL_YOLO = "yolo11_detect_best_241024_1.pt"
|
| 26 |
MODEL_DIR = "./artifacts/models"
|
| 27 |
YOLO_DIR = "./artifacts/yolo"
|
| 28 |
+
IMG2TEXT_URL = "https://fd39e54bcb191a37bf.gradio.live/"
|
| 29 |
+
TEXT2IMG_URL = "https://91698ded8ba92d3bb0.gradio.live/"
|
| 30 |
|
| 31 |
|
| 32 |
# Load the saved YOLO model into memory
|
|
|
|
| 38 |
|
| 39 |
model = YOLO(model_path)
|
| 40 |
|
| 41 |
+
print("***** FLASK API---LOAD YOLO MODEL DONE *****")
|
| 42 |
#model.eval() # 設定模型為推理模式
|
| 43 |
elif load_type == 'remote_hub_download':
|
| 44 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 64 |
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
| 65 |
return encoded_string
|
| 66 |
|
| 67 |
+
def convert_webp_to_base64(webp_path):
|
| 68 |
+
# 開啟 .webp 圖片檔
|
| 69 |
+
with Image.open(webp_path) as img:
|
| 70 |
+
# 將圖片存到 BytesIO 物件中,以便轉換為 base64
|
| 71 |
+
buffered = BytesIO()
|
| 72 |
+
img.save(buffered, format="WEBP")
|
| 73 |
+
|
| 74 |
+
# 取得 base64 編碼的字串
|
| 75 |
+
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 76 |
+
|
| 77 |
+
return img_base64
|
| 78 |
+
|
| 79 |
|
| 80 |
# 抓取指定路徑下的所有 JPG 檔案
|
| 81 |
def get_jpg_files(path):
|
|
|
|
| 131 |
except Exception as e:
|
| 132 |
return jsonify({'error': str(e)}), 400
|
| 133 |
|
| 134 |
+
print("***** FLASK API---/predict Start YOLO predict *****")
|
| 135 |
# Make a prediction using YOLO
|
| 136 |
results = model(image_data)
|
| 137 |
+
print ("===== FLASK API---/predict YOLO predict result:",results,"=====")
|
| 138 |
+
print("***** FLASK API---/predict YOLO predict DONE *****")
|
| 139 |
|
| 140 |
check_memory_usage()
|
| 141 |
|
|
|
|
| 159 |
labels = result.boxes.cls # Get predicted label IDs
|
| 160 |
label_names = [model.names[int(label)] for label in labels] # Convert to names
|
| 161 |
|
| 162 |
+
print(f"====== FLASK API---/predict 3. YOLO label_names: {label_names}======")
|
| 163 |
|
| 164 |
element_counts = Counter(label_names)
|
| 165 |
|
|
|
|
| 168 |
yolo_path = f"{YOLO_DIR}/{message_id}/{element}"
|
| 169 |
yolo_file = get_jpg_files(yolo_path)
|
| 170 |
|
| 171 |
+
print(f"***** FLASK API---/predict 處理:{yolo_path} *****")
|
| 172 |
|
| 173 |
if len(yolo_file) == 0:
|
| 174 |
+
print(f" FLASK API---/predict 警告:{element} 沒有找到相關的 JPG 檔案")
|
| 175 |
continue
|
| 176 |
|
| 177 |
for yolo_img in yolo_file: # 每張切圖yolo_img
|
| 178 |
+
print("***** FLASK API---/predict 4. START CLIP *****")
|
| 179 |
+
client = Client(IMG2TEXT_URL)
|
| 180 |
clip_result = client.predict(
|
| 181 |
image=handle_file(yolo_img),
|
| 182 |
top_k=3,
|
|
|
|
| 185 |
top_k_words.append(clip_result) # CLIP預測3個結果(top_k_words)
|
| 186 |
encoded_images.append(image_to_base64(yolo_img))
|
| 187 |
element_list.append(element)
|
| 188 |
+
print(f"===== FLASK API---/predict CLIP result:{top_k_words} =====\n")
|
| 189 |
|
| 190 |
# 建立回應資料
|
| 191 |
response_data = {
|
|
|
|
| 207 |
|
| 208 |
|
| 209 |
# API route for health check
|
| 210 |
+
@app.route('/text2img', methods=['POST'])
|
| 211 |
+
def text2img():
|
| 212 |
+
text_message = request.form.get('text_message')
|
| 213 |
+
message_id = request.form.get('message_id')
|
| 214 |
+
|
| 215 |
+
client = Client(TEXT2IMG_URL)
|
| 216 |
+
result = client.predict(
|
| 217 |
+
word= text_message,
|
| 218 |
+
api_name="/predict"
|
| 219 |
+
)
|
| 220 |
+
print(f"===== FLASK API---/text2img 文字轉圖片result[0]:{result[0]} =====")
|
| 221 |
+
result_img = convert_webp_to_base64(result[0])
|
| 222 |
+
print(f"===== FLASK API---/text2img 文字轉圖片轉base64:{result_img} =====")
|
| 223 |
+
|
| 224 |
+
# 建立回應資料
|
| 225 |
+
response_data = {
|
| 226 |
+
'message_id': message_id,
|
| 227 |
+
'encoded_image': result_img,
|
| 228 |
+
'description': result[1]
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
return jsonify(response_data), 200
|
| 232 |
|
| 233 |
|
| 234 |
# API route for version
|