Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from ultralytics import YOLO | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import base64 | |
| from io import BytesIO | |
| import tempfile | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| from openpyxl import Workbook, load_workbook | |
# Load the YOLOv8 model once, when the module is imported.
model = YOLO("best.pt")

# Working folders for the raw uploads and the annotated outputs; each one is
# created right after its path is defined (no error if it already exists).
uploaded_folder = Path('Uploaded_Picture')
uploaded_folder.mkdir(parents=True, exist_ok=True)
predicted_folder = Path('Predicted_Picture')
predicted_folder.mkdir(parents=True, exist_ok=True)

# Excel file that serves as the patient/prediction database.
xlsx_db_file = Path('patient_predictions.xlsx')

# On first run, create the workbook with a single header row.
if not xlsx_db_file.exists():
    workbook = Workbook()
    sheet = workbook.active
    sheet.title = "Predictions"
    sheet.append(
        ["Name", "Age", "Medical Record", "Sex", "Result", "Image Path"]
    )
    workbook.save(xlsx_db_file)
def _safe_filename_part(value):
    """Return ``str(value)`` with anything but letters, digits, '.', '_', '-' and ' ' replaced by '_'.

    Patient fields come straight from the web form, so path separators and
    other special characters must not reach the file system.
    """
    return "".join(
        ch if ch.isalnum() or ch in "._- " else "_" for ch in str(value)
    )


def predict_image(input_image, name, age, medical_record, sex):
    """Run the model on an uploaded image, annotate it and record the result.

    Only the highest-confidence detection is used. It is marked with a
    circle at the centre of its bounding box and a text label; the patient
    banner and watermark are then added, the original and annotated images
    are saved, and one row is appended to the Excel database.

    Args:
        input_image: PIL image from the Gradio image input, or None.
        name: Patient name (free text from the form).
        age: Patient age.
        medical_record: Medical record number.
        sex: Patient sex as selected in the form.

    Returns:
        A ``(annotated_image, summary_text)`` tuple. When ``input_image`` is
        None the tuple is ``(None, "Please Input The Image")`` and nothing is
        saved.
    """
    if input_image is None:
        return None, "Please Input The Image"

    # Convert Gradio input image (PIL Image) to a 3-channel RGB numpy array.
    image_np = np.array(input_image)
    if image_np.ndim == 2:  # grayscale to RGB
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:  # RGBA to RGB
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    # Perform prediction
    results = model(image_np)

    image_with_boxes = image_np.copy()

    # Defaults for the "nothing detected" case. Without them the code below
    # would reference an unassigned ``label`` and raise NameError.
    label = "No detection"
    raw_predictions_str = "No detection"

    boxes = results[0].boxes
    if boxes is not None and len(boxes) > 0:
        # Keep only the detection the model is most confident about.
        best = max(boxes, key=lambda box: box.conf.item())
        confidence = best.conf.item()

        # Class 0 -> Immature, class 1 -> Mature, anything else -> Normal.
        # Colours are RGB because the array comes from a PIL image.
        class_index = best.cls.item()
        if class_index == 0:
            label, color = "Immature", (0, 0, 255)  # Blue
        elif class_index == 1:
            label, color = "Mature", (255, 0, 0)  # Red
        else:
            label, color = "Normal", (0, 255, 0)  # Green

        xmin, ymin, xmax, ymax = map(int, best.xyxy[0])

        # Circle radius is 1/12 of the mean of the box width and height.
        avg_dimension = ((xmax - xmin) + (ymax - ymin)) / 2
        radius = int(avg_dimension / 12)
        center_x = int((xmin + xmax) / 2)
        center_y = int((ymin + ymax) / 2)
        cv2.circle(image_with_boxes, (center_x, center_y), radius, color, 2)

        # White caption on a black background just above the box.
        font_scale = 1.0
        thickness = 2
        caption = f'{label} {confidence:.2f}'
        (text_width, text_height), baseline = cv2.getTextSize(
            caption, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
        )
        # Clamp to the top edge so the caption stays visible when the box
        # touches the top of the image; otherwise the layout is unchanged.
        label_top = max(ymin - text_height - baseline, 0)
        label_bottom = label_top + text_height + baseline
        cv2.rectangle(
            image_with_boxes,
            (xmin, label_top),
            (xmin + text_width, label_bottom),
            (0, 0, 0),
            cv2.FILLED,
        )
        cv2.putText(
            image_with_boxes,
            caption,
            (xmin, label_bottom - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            thickness,
        )

        raw_predictions_str = (
            f"Label: {label}, Confidence: {confidence:.2f}, "
            f"Circle Center: [{center_x}, {center_y}], Radius: {radius}"
        )

    # Convert to PIL image, then add the patient banner and watermark.
    pil_image_with_boxes = Image.fromarray(image_with_boxes)
    pil_image_with_boxes = add_text_and_watermark(
        pil_image_with_boxes, name, age, medical_record, sex, label
    )

    # Build the file name from sanitised fields so user text cannot escape
    # the target folders (e.g. via "/" or "..\\").
    image_name = "-".join(
        _safe_filename_part(part) for part in (name, age, sex, medical_record)
    ) + ".png"
    input_image.save(uploaded_folder / image_name)
    pil_image_with_boxes.save(predicted_folder / image_name)

    # Append the prediction to the XLSX database
    append_patient_info_to_xlsx(name, age, medical_record, sex, label, image_name)

    return pil_image_with_boxes, raw_predictions_str
def add_watermark(image):
    """Stamp the logo from 'image-logo.png' in the bottom-right corner.

    The logo is resized to 100 px wide (height scaled proportionally) and
    pasted 10 px in from the right and bottom edges, using its own alpha
    channel as the mask.

    Returns:
        The watermarked image converted to RGB. If any step fails, the error
        is printed and the image is returned without a watermark (it may
        already have been converted to RGBA at that point).
    """
    target_width = 100
    margin = 10
    try:
        logo = Image.open('image-logo.png').convert("RGBA")
        image = image.convert("RGBA")

        # Scale the logo to the target width, keeping its aspect ratio.
        scale = target_width / float(logo.size[0])
        target_height = int(float(scale) * logo.size[1])
        logo = logo.resize((target_width, target_height), Image.LANCZOS)

        corner = (
            image.width - logo.width - margin,
            image.height - logo.height - margin,
        )

        # Compose on a fully transparent canvas of the same size.
        canvas = Image.new('RGBA', (image.width, image.height), (0, 0, 0, 0))
        canvas.paste(image, (0, 0))
        canvas.paste(logo, corner, mask=logo)
        return canvas.convert("RGB")
    except Exception as e:
        # Best effort: a missing or broken logo must not block the prediction.
        print(f"Error adding watermark: {e}")
        return image
def add_text_and_watermark(image, name, age, medical_record, sex, label):
    """Draw a patient-info banner on *image* and then apply the watermark.

    The banner is white text on a black rectangle anchored at (20, 40) with
    10 px of padding. The text is drawn directly onto *image*; the value
    returned is whatever ``add_watermark`` returns for it.
    """
    painter = ImageDraw.Draw(image)

    # Prefer the bundled TrueType font; fall back to PIL's default font.
    try:
        font = ImageFont.truetype("font.ttf", size=24)
    except IOError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    text = f"Name: {name}, Age: {age}, Medical Record: {medical_record}, Sex: {sex}, Result: {label}"

    # Measure the rendered text to size the background rectangle.
    left, top, right, bottom = painter.textbbox((0, 0), text, font=font)
    banner_width = right - left
    banner_height = bottom - top

    origin_x, origin_y = 20, 40
    pad = 10
    background = [
        origin_x - pad,
        origin_y - pad,
        origin_x + banner_width + pad,
        origin_y + banner_height + pad,
    ]
    painter.rectangle(background, fill="black")
    painter.text((origin_x, origin_y), text, fill=(255, 255, 255, 255), font=font)

    return add_watermark(image)
def append_patient_info_to_xlsx(name, age, medical_record, sex, result, image_path):
    """Append one patient/prediction row to the Excel database.

    If the database file is missing it is first recreated with the header
    row. The new row is then added to the workbook's active sheet and the
    file is saved.

    Returns:
        The database file path as a string.
    """
    if not xlsx_db_file.exists():
        # Recreate the file with the header row before appending.
        new_book = Workbook()
        header_sheet = new_book.active
        header_sheet.title = "Predictions"
        header_sheet.append(
            ["Name", "Age", "Medical Record", "Sex", "Result", "Image Path"]
        )
        new_book.save(xlsx_db_file)

    book = load_workbook(xlsx_db_file)
    row = [name, age, medical_record, sex, result, str(image_path)]
    book.active.append(row)
    book.save(xlsx_db_file)
    return str(xlsx_db_file)
def download_folder(folder):
    """Zip the contents of *folder* into the system temp directory.

    The archive is named after the folder's last path component, so the
    result is the same whether *folder* is relative, nested or absolute.

    Args:
        folder: Directory to archive (``str`` or ``pathlib.Path``).

    Returns:
        Path of the created ``.zip`` file as a string.
    """
    # Use only the final component: an absolute or nested path would
    # otherwise override or extend the temp directory in os.path.join.
    folder_name = Path(os.path.abspath(folder)).name or "archive"
    # make_archive appends ".zip" itself, so pass the base name without it
    # rather than stripping ".zip" from a finished path with str.replace.
    archive_base = os.path.join(tempfile.gettempdir(), folder_name)
    return shutil.make_archive(archive_base, 'zip', folder)
def interface(name, age, medical_record, sex, input_image):
    """Gradio submit handler: run the prediction for the submitted form.

    Returns:
        ``(annotated_image, result_text, xlsx_path)``. When no image was
        supplied the tuple is ``(None, "Please upload an image.", None)``.
    """
    if input_image is not None:
        annotated, summary = predict_image(
            input_image, name, age, medical_record, sex
        )
        return annotated, summary, str(xlsx_db_file)

    return None, "Please upload an image.", None
def download_predicted_folder():
    """Archive the predicted-images folder via ``download_folder`` and return its result."""
    archive_path = download_folder(predicted_folder)
    return archive_path
def download_uploaded_folder():
    """Archive the uploaded-images folder via ``download_folder`` and return its result."""
    archive_path = download_folder(uploaded_folder)
    return archive_path
# NOTE(review): the original indentation was lost; the container nesting below
# is reconstructed from statement order. Confirm it against the deployed layout.
with gr.Blocks() as demo:
    # Header text.
    with gr.Column():
        gr.Markdown("# Cataract Detection System")
        gr.Markdown("Upload an image to detect cataract and add patient details.")
        gr.Markdown("This application uses YOLOv8 with mAP=0.981")

    # Patient details and the image to analyse.
    with gr.Column():
        name = gr.Textbox(label="Name")
        age = gr.Number(label="Age")
        medical_record = gr.Number(label="Medical Record")
        sex = gr.Radio(["Male", "Female"], label="Sex")
        input_image = gr.Image(type="pil", label="Upload an Image", image_mode="RGB")

    # Submit button and the annotated result.
    with gr.Column():
        submit_btn = gr.Button("Submit")
        output_image = gr.Image(type="pil", label="Predicted Image")

    with gr.Row():
        raw_result = gr.Textbox(label="Prediction Result")

    # Download controls.
    with gr.Row():
        download_xlsx_btn = gr.Button("Download Patient Information (XLSX)")
        download_uploaded_btn = gr.Button("Download Uploaded Images")
        download_predicted_btn = gr.Button("Download Predicted Images")

    # File outputs filled in by the handlers below.
    xlsx_file = gr.File(label="Patient Information XLSX File")
    uploaded_folder_file = gr.File(label="Uploaded Images Zip File")
    predicted_folder_file = gr.File(label="Predicted Images Zip File")

    # Event wiring.
    submit_btn.click(
        fn=interface,
        inputs=[name, age, medical_record, sex, input_image],
        outputs=[output_image, raw_result, xlsx_file],
    )
    download_xlsx_btn.click(
        fn=lambda: str(xlsx_db_file),
        outputs=xlsx_file,
    )
    download_uploaded_btn.click(
        fn=download_uploaded_folder,
        outputs=uploaded_folder_file,
    )
    download_predicted_btn.click(
        fn=download_predicted_folder,
        outputs=predicted_folder_file,
    )

demo.launch()