{{Your Name}} committed on
Commit
d25d814
·
1 Parent(s): 685ba82
Files changed (5) hide show
  1. .gitignore +4 -1
  2. app.py +167 -40
  3. requirements.txt +1 -1
  4. templates/index.html +43 -0
  5. test_api.py +102 -0
.gitignore CHANGED
@@ -49,4 +49,7 @@ Thumbs.db
49
 
50
  # --- Log Files ---
51
  # Ignore log files, which can become large and are specific to a run.
52
- *.log
 
 
 
 
49
 
50
  # --- Log Files ---
51
  # Ignore log files, which can become large and are specific to a run.
52
+ *.log
53
+ *.jpg
54
+ *.jpeg
55
+ *.png
app.py CHANGED
@@ -1,57 +1,184 @@
 
 
 
1
  import numpy as np
 
 
 
 
 
 
 
2
  import cv2
3
- import gradio as gr
4
- from ultralytics import YOLO
5
- import os
6
  from huggingface_hub import hf_hub_download
7
- from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
9
 
10
- print("Loading environment variables...")
11
- load_dotenv()
12
 
13
- HF_REPO_ID = "tententgc/Iskyn"
14
- MODEL_FILENAME = "best.onnx"
 
 
 
 
 
15
 
16
- print(f"Downloading '{MODEL_FILENAME}' from '{HF_REPO_ID}'...")
 
 
 
 
 
17
 
18
- model_path = hf_hub_download(
19
- repo_id=HF_REPO_ID,
20
- filename=MODEL_FILENAME,
21
- token=os.getenv("HF_TOKEN")
22
- )
 
23
 
24
- print(f"Model downloaded to: {model_path}")
 
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
27
 
28
- print("Loading YOLO model...")
29
- onnx_model = YOLO(model_path)
30
- print("Model loaded successfully.")
 
 
 
 
 
31
 
 
 
 
 
32
 
 
 
 
 
33
 
34
- def predict_image(image_filepath, conf_threshold, iou_threshold):
35
- results = onnx_model.predict(
36
- image_filepath,
37
- conf=conf_threshold,
38
- iou=iou_threshold
39
- )
40
- result = results[0]
41
- im_array = result.plot()
42
- im_rgb = cv2.cvtColor(im_array, cv2.COLOR_BGR2RGB)
43
- return im_rgb
44
 
45
- iface = gr.Interface(
46
- fn=predict_image,
47
- inputs=[
48
- gr.Image(type="filepath", label="Upload Image"),
49
- gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
50
- gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold")
51
- ],
52
- outputs=gr.Image(type="numpy", label="Result"),
53
- title="Detection Face Skin",
54
- description="Upload an image and adjust the thresholds to fine-tune detection."
55
- )
56
 
57
- iface.launch()
 
 
 
 
1
+ # app.py
2
+ import io
3
+ import uvicorn
4
  import numpy as np
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ from typing import List
7
+ from fastapi import FastAPI, UploadFile, File, Request, Form
8
+ from fastapi.responses import HTMLResponse
9
+ from fastapi.staticfiles import StaticFiles
10
+ from fastapi.templating import Jinja2Templates
11
+ import onnxruntime as ort
12
  import cv2
 
 
 
13
  from huggingface_hub import hf_hub_download
14
+ import os
15
+ import uuid
16
+
17
# --- FastAPI and Template Setup ---
app = FastAPI(title="YOLOv8 ONNX Object Detection Demo")

# FIX: create the served directory before mounting. StaticFiles raises at
# import time when `directory` does not exist, and /predict_web writes result
# images into static/output — previously only the __main__ guard created it,
# so running via `uvicorn app:app` crashed on startup.
os.makedirs(os.path.join("static", "output"), exist_ok=True)

# Mount a static directory to serve saved images.
app.mount("/static", StaticFiles(directory="static"), name="static")

templates = Jinja2Templates(directory="templates")

# --- Model Loading and Configuration ---
# Download the ONNX model file and build an inference session. On any failure
# `session` is left as None and the endpoints report a friendly error instead
# of crashing the whole app at import time.
try:
    onnx_model_path = hf_hub_download(repo_id="tententgc/Iskyn", filename="best.onnx")
    session = ort.InferenceSession(onnx_model_path)
    print("ONNX model loaded successfully.")
except Exception as e:
    print(f"Failed to load ONNX model: {e}")
    session = None

if session:
    input_name = session.get_inputs()[0].name
    output_names = [output.name for output in session.get_outputs()]
    # NOTE(review): ONNX image inputs are NCHW, so shape[2:] is (height, width),
    # but downstream code unpacks this pair as (width, height). Harmless while
    # the model is square (640x640) — confirm before using non-square inputs.
    input_shape = session.get_inputs()[0].shape[2:]
else:
    input_name = None
    output_names = []
    input_shape = (640, 640)  # fallback so the app can still start without a model

# Class names for the model's output indices.
# IMPORTANT: update this with the actual class names your model was trained on.
CLASSES = [
    "melasma", "acne", "wrinkle"
]

# Map class names to colors used when plotting boxes.
COLORS = {
    "melasma": "red",
    "acne": "green",
    "wrinkle": "blue",
    # Add more classes and colors as needed
}
57
+
58
+ # --- Helper Functions ---
59
def preprocess_image(image: Image.Image, size: tuple) -> np.ndarray:
    """Resize a PIL image and convert it to a normalized NCHW float32 batch.

    Returns an array of shape (1, C, H, W) with values scaled into [0, 1].

    NOTE(review): *size* is handed straight to PIL's resize(), which expects
    (width, height), while the caller supplies the ONNX (height, width) pair —
    equivalent only for square models; confirm before changing input sizes.
    """
    # uint8 HWC pixels -> float32 in [0, 1]
    pixels = np.array(image.resize(size)).astype(np.float32) / 255.0
    # HWC -> CHW, then prepend the batch dimension expected by the model.
    return np.expand_dims(pixels.transpose(2, 0, 1), axis=0)
67
+
68
def _nms(boxes, scores, iou_threshold):
    """Class-agnostic non-maximum suppression on (x1, y1, x2, y2) boxes.

    Returns indices of the kept boxes, highest score first.
    """
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # Intersection rectangle of the best box with every remaining box.
        ix1 = np.maximum(boxes[best, 0], boxes[rest, 0])
        iy1 = np.maximum(boxes[best, 1], boxes[rest, 1])
        ix2 = np.minimum(boxes[best, 2], boxes[rest, 2])
        iy2 = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
        area_best = (boxes[best, 2] - boxes[best, 0]) * (boxes[best, 3] - boxes[best, 1])
        area_rest = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_best + area_rest - inter + 1e-9)  # epsilon avoids 0/0
        order = rest[iou <= iou_threshold]
    return keep


def postprocess_output(output, original_size, input_shape, conf_threshold=0.25, iou_threshold=0.45, class_names=None):
    """Convert raw YOLOv8 ONNX output into a list of detection dicts.

    Args:
        output: raw model output, shape (1, 4 + num_classes, num_candidates)
            with rows (cx, cy, w, h, class scores...).
        original_size: (width, height) of the uploaded image.
        input_shape: size the image was resized to before inference.
        conf_threshold: minimum class score to keep a candidate.
        iou_threshold: IoU above which overlapping boxes are suppressed.
        class_names: optional name list; defaults to the module-level CLASSES.

    Returns:
        List of {"class_name": str, "confidence": float, "box": [x1, y1, x2, y2]}
        with box coordinates scaled back to the original image.

    BUG FIX: the previous version passed (x1, y1, x2, y2) boxes to
    cv2.dnn.NMSBoxes, which expects (x, y, width, height) — so IoU and hence
    suppression were computed on the wrong rectangles (and coordinates were
    truncated to int32 first). NMS is now done in NumPy on true corner boxes.
    """
    names = CLASSES if class_names is None else class_names

    preds = np.squeeze(output).T  # -> (num_candidates, 4 + num_classes)
    scores = np.max(preds[:, 4:], axis=1)
    confident = scores > conf_threshold
    preds = preds[confident]
    scores = scores[confident]

    if not len(preds):
        return []

    # (cx, cy, w, h) -> (x1, y1, x2, y2)
    cx, cy, w, h = preds[:, 0], preds[:, 1], preds[:, 2], preds[:, 3]
    boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    class_ids = np.argmax(preds[:, 4:], axis=1)

    keep = _nms(boxes, scores, iou_threshold)

    # NOTE(review): input_shape comes from the ONNX NCHW shape, i.e. (H, W),
    # but is unpacked here as (W, H) — only correct for square models.
    original_width, original_height = original_size
    resized_width, resized_height = input_shape

    detections = []
    for i in keep:
        x1, y1, x2, y2 = boxes[i]
        detections.append({
            "class_name": names[class_ids[i]],
            "confidence": float(scores[i]),
            "box": [
                int(x1 * original_width / resized_width),
                int(y1 * original_height / resized_height),
                int(x2 * original_width / resized_width),
                int(y2 * original_height / resized_height),
            ],
        })
    return detections
109
 
110
def draw_boxes_on_image(image, detections):
    """Draw bounding boxes and "class: confidence" labels on a PIL image.

    Mutates *image* in place and returns it. Each detection is a dict with
    keys 'box' ([x1, y1, x2, y2]), 'class_name', and 'confidence'.
    """
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 30)
    except IOError:
        font = ImageFont.load_default()
        print("Arial font not found, using default font.")

    for detection in detections:
        box = detection['box']
        class_name = detection['class_name']
        confidence = detection['confidence']

        color = COLORS.get(class_name, "white")
        draw.rectangle(box, outline=color, width=3)

        label = f"{class_name}: {confidence:.2f}"

        # BUG FIX: textbbox() returns (left, top, right, bottom), not
        # (x, y, width, height) — the previous unpacking used the right/bottom
        # corners as if they were the text size. Derive the size from the corners.
        left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
        text_width = right - left
        text_height = bottom - top

        # Place the label just above the box; fall back to just inside the
        # top edge when there is no room above.
        text_position_y = box[1] - text_height - 5
        if text_position_y < 0:
            text_position_y = box[1] + 5

        draw.rectangle([box[0], text_position_y, box[0] + text_width, text_position_y + text_height], fill=color)
        draw.text((box[0], text_position_y), label, fill="black", font=font)
    return image
140
 
141
+ # --- FastAPI Endpoints ---
142
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the upload form with no result image and no error message."""
    context = {"request": request, "image_url": None, "error_message": None}
    return templates.TemplateResponse("index.html", context)
146
 
147
@app.post("/predict_web", response_class=HTMLResponse)
async def predict_web(request: Request, file: UploadFile = File(...)):
    """Handle an image upload, run ONNX detection, and render the result page.

    Rejects non-image uploads and reports a friendly message when the model
    failed to load. On success the annotated image is saved under
    static/output/ and its URL is passed back into the template.
    """
    if not session:
        return templates.TemplateResponse("index.html", {"request": request, "error_message": "ONNX model not loaded."})

    if not file.content_type.startswith("image/"):
        return templates.TemplateResponse("index.html", {"request": request, "error_message": "Invalid file type. Please upload an image."})

    try:
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        original_size = image.size

        # Preprocess, run inference, and post-process.
        preprocessed_image = preprocess_image(image, size=input_shape)
        outputs = session.run(output_names, {input_name: preprocessed_image})
        detections = postprocess_output(outputs, original_size, input_shape)

        # Draw boxes on a copy so the original stays untouched.
        plotted_image = draw_boxes_on_image(image.copy(), detections)

        # FIX: ensure the output directory exists here too — when the app is
        # started via `uvicorn app:app`, the __main__ guard that creates it
        # never runs and save() would fail with FileNotFoundError.
        os.makedirs(os.path.join("static", "output"), exist_ok=True)

        # Save under a unique name so concurrent requests don't collide.
        unique_filename = f"{uuid.uuid4()}.jpg"
        output_image_path = os.path.join("static", "output", unique_filename)
        plotted_image.save(output_image_path)

        image_url = f"/static/output/{unique_filename}"

        return templates.TemplateResponse("index.html", {"request": request, "image_url": image_url})

    except Exception as e:
        # Surface any failure to the user rather than returning a 500 page.
        return templates.TemplateResponse("index.html", {"request": request, "error_message": f"An error occurred: {e}"})
180
 
181
if __name__ == "__main__":
    # Make sure the directory that holds saved result images exists.
    os.makedirs(os.path.join("static", "output"), exist_ok=True)
    # NOTE(review): binds to localhost only — use host="0.0.0.0" (and the
    # platform's port) when deploying in a container or Space.
    uvicorn.run(app, host="127.0.0.1", port=8000)
requirements.txt CHANGED
@@ -47,7 +47,7 @@ h11==0.16.0
47
  hf-xet==1.1.10
48
  httpcore==1.0.9
49
  httpx==0.28.1
50
- huggingface-hub==0.34.4
51
  humanfriendly==10.0
52
  idna==3.10
53
  ifaddr==0.2.0
 
47
  hf-xet==1.1.10
48
  httpcore==1.0.9
49
  httpx==0.28.1
50
+ huggingface-hub==0.35.0
51
  humanfriendly==10.0
52
  idna==3.10
53
  ifaddr==0.2.0
templates/index.html ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>YOLO Object Detection Demo</title>
7
+ <style>
8
+ body { font-family: Arial, sans-serif; padding: 20px; text-align: center; }
9
+ h1 { color: #333; }
10
+ .container { max-width: 800px; margin: auto; }
11
+ form { margin-top: 20px; padding: 20px; border: 1px solid #ddd; border-radius: 8px; }
12
+ .image-display { margin-top: 20px; }
13
+ .image-display img { max-width: 100%; border: 1px solid #ccc; border-radius: 8px; }
14
+ .error-message { color: red; font-weight: bold; }
15
+ </style>
16
+ </head>
17
+ <body>
18
+ <div class="container">
19
+ <h1>YOLO Object Detection Demo</h1>
20
+ <p>Upload an image to perform object detection.</p>
21
+
22
+ <form action="/predict_web" method="post" enctype="multipart/form-data">
23
+ <input type="file" name="file" accept="image/*" required>
24
+ <br><br>
25
+ <button type="submit">Upload and Detect</button>
26
+ </form>
27
+
28
+ {% if image_url %}
29
+ <div class="image-display">
30
+ <h2>Detection Result:</h2>
31
+ <img src="{{ image_url }}" alt="Detected Objects">
32
+ </div>
33
+ {% endif %}
34
+
35
+ {% if error_message %}
36
+ <div class="error-message">
37
+ <p>{{ error_message }}</p>
38
+ </div>
39
+ {% endif %}
40
+
41
+ </div>
42
+ </body>
43
+ </html>
test_api.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import io
4
+ import os
5
+
6
# Target endpoint of the running FastAPI server.
# NOTE(review): this commit's app.py only registers "/" and "/predict_web"
# (both return HTML) — there is no JSON "/predict" route; confirm the URL.
url = "http://127.0.0.1:8000/predict"

# Input image to send and where to write the annotated copy.
image_path = "acne-face-2-18.jpg"
output_path = "result.jpg"

# Box color per class name.
# NOTE(review): colors are swapped relative to app.py, where melasma is red
# and acne is green — confirm which mapping is intended.
COLORS = {
    "acne": "red",
    "melasma": "green",
    "wrinkle": "blue"
}
19
+
20
def draw_boxes_on_image(image, detections):
    """Annotate a PIL image with detection boxes and "class: conf" labels.

    Mutates and returns *image*. Each detection is a dict with keys
    'box' ([x1, y1, x2, y2]), 'class_name', and 'confidence'.
    """
    canvas = ImageDraw.Draw(image)
    try:
        # Prefer a nicer TrueType font when it is installed.
        label_font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        label_font = ImageFont.load_default()
        print("Arial font not found, using default font.")

    for det in detections:
        box = det['box']
        class_name = det['class_name']
        confidence = det['confidence']

        # Unknown classes fall back to a visible solid color.
        outline_color = COLORS.get(class_name, "white")
        canvas.rectangle(box, outline=outline_color, width=3)

        caption = f"{class_name}: {confidence:.2f}"

        # textbbox() reports (left, top, right, bottom) for the rendered text.
        left, top, right, bottom = canvas.textbbox((0, 0), caption, font=label_font)
        caption_w = right - left
        caption_h = bottom - top

        # Prefer the caption just above the box's top-left corner (5 px pad);
        # draw it below the corner instead when it would run off the top.
        caption_x = box[0]
        caption_y = box[1] - caption_h - 5
        if caption_y < 0:
            caption_y = box[1] + 5

        # Filled background behind the text for readability.
        canvas.rectangle([caption_x, caption_y, caption_x + caption_w, caption_y + caption_h], fill=outline_color)
        canvas.text((caption_x, caption_y), caption, fill="black", font=label_font)

    return image
65
+
66
# Send one image to the API, then plot and save any returned detections.
try:
    # Fail fast with a readable message when the fixture image is missing.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Error: The image file was not found at {image_path}")

    # Open the image in binary mode and POST it as multipart form data.
    with open(image_path, "rb") as f:
        files = {"file": f}
        response = requests.post(url, files=files)

    if response.status_code == 200:
        detections = response.json().get("detections", [])

        if detections:
            print("Detections found:", detections)
            # Reload the original image for plotting.
            original_image = Image.open(image_path).convert("RGB")

            # Draw the detections and save the annotated copy.
            plotted_image = draw_boxes_on_image(original_image, detections)
            plotted_image.save(output_path)
            print(f"Success! Plotted image saved to: {output_path}")
        else:
            print("No objects were detected.")
    else:
        print(f"Error: API returned status code {response.status_code}")
        print("Response:", response.text)

except FileNotFoundError as e:
    # FIX: this exception was raised above with a display-ready message but
    # previously escaped uncaught (only RequestException was handled).
    print(e)
except requests.exceptions.RequestException as e:
    print(f"An error occurred while connecting to the API: {e}")