# Hugging Face Space application (the original listing showed "No application file").
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50
from ultralytics import YOLO
# --- Model setup -------------------------------------------------------------

# YOLO detector fine-tuned for this task.
# NOTE(review): absolute path — breaks on any other machine; consider an env
# var or a path relative to this file.
yolo_model = YOLO('/home/alsufyh/myfinal_project/YOLO11/runs/detect/train21/weights/best.pt')

# ResNet-50 classifier used to refine low-confidence YOLO detections.
resnet_model = resnet50()
resnet_model.fc = nn.Linear(2048, 5)  # replace the ImageNet head with a 5-class head

state_dict_path = "/home/alsufyh/myfinal_project/classification/final/drug_classification_model.pth"
# map_location="cpu" so a checkpoint saved on a GPU machine still loads on a
# CPU-only host (original torch.load would raise in that case).
state_dict = torch.load(state_dict_path, map_location="cpu")
# NOTE(review): strict=False silently ignores missing/unexpected keys — verify
# the checkpoint actually matches this architecture.
resnet_model.load_state_dict(state_dict, strict=False)
resnet_model.eval()  # inference mode: freezes dropout / batch-norm behavior

# Preprocessing for ResNet: resize to its expected 224x224 input, convert to a
# tensor, and apply standard ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
| # Function to process the image and text input | |
| def process_image_and_notes(image, text): | |
| # Perform YOLO detection | |
| results = yolo_model.predict(image) | |
| # Prepare results and visualization | |
| final_results = [] | |
| fig, ax = plt.subplots(figsize=(10, 10)) | |
| ax.imshow(image) | |
| ax.axis("off") | |
| for box in results[0].boxes: | |
| x1, y1, x2, y2 = map(int, box.xyxy[0]) # Bounding box coordinates | |
| yolo_class = int(box.cls.item()) # YOLO-predicted class | |
| confidence = box.conf.item() # YOLO confidence score | |
| # Crop region of interest (ROI) using YOLO bounding box | |
| crop = image.crop((x1, y1, x2, y2)) | |
| # If YOLO confidence is low, refine with ResNet | |
| if confidence < 0.8: | |
| # Apply transformations for ResNet | |
| input_tensor = transform(crop).unsqueeze(0) # Add batch dimension | |
| # Perform ResNet inference | |
| with torch.no_grad(): | |
| output = resnet_model(input_tensor) | |
| refined_confidence, refined_class = torch.max(torch.softmax(output, dim=1), dim=1) | |
| refined_class = refined_class.item() | |
| refined_confidence = refined_confidence.item() | |
| # Replace YOLO predictions with ResNet predictions | |
| yolo_class = refined_class | |
| confidence = refined_confidence | |
| # Append results for further analysis | |
| final_results.append({ | |
| "object": yolo_class, | |
| "confidence": confidence | |
| }) | |
| # Draw bounding box and labels | |
| ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor="red", facecolor="none")) | |
| ax.text(x1, y1 - 10, f"Class: {yolo_class}, Conf: {confidence:.2f}", color="white", fontsize=12, | |
| bbox=dict(facecolor="red", alpha=0.5)) | |
| # Save the plot to a buffer | |
| buffer = BytesIO() | |
| plt.savefig(buffer, format="png") | |
| buffer.seek(0) | |
| # Return the results and processed image | |
| return f"Detected objects:\n{text}\nResults: {final_results}", buffer.getvalue() | |
| # Function for thumbs-up feedback | |
| def thumbs_up_fn(image, text, output): | |
| print("Thumbs up feedback received.") | |
| # Function for thumbs-down feedback | |
| def thumbs_down_fn(image, text, output): | |
| print("Thumbs down feedback received.") | |
| # Gradio interface | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Narcotics Detection") | |
| gr.HTML('<h2>Upload Image</h2>') | |
| image_input = gr.Image(type="pil", label="Upload Image", show_label=False) | |
| gr.HTML('<h2>Enter Notes</h2>') | |
| text_input = gr.TextArea(label="Enter Notes", show_label=False) | |
| submit_button = gr.Button("Submit") | |
| result_output = gr.Textbox(label="Result", interactive=False) | |
| processed_image_output = gr.Image(label="Processed Image") | |
| # Feedback buttons | |
| with gr.Row(visible=False) as feedback_row: | |
| gr.Markdown("**Was the result helpful?**") | |
| thumbs_up = gr.Button("π") | |
| thumbs_down = gr.Button("π") | |
| # Button click actions | |
| submit_button.click( | |
| fn=process_image_and_notes, | |
| inputs=[image_input, text_input], | |
| outputs=[result_output, processed_image_output] | |
| ) | |
| submit_button.click(lambda: gr.update(visible=True), inputs=None, outputs=[feedback_row]) | |
| thumbs_up.click(fn=thumbs_up_fn, inputs=[image_input, text_input, result_output]) | |
| thumbs_down.click(fn=thumbs_down_fn, inputs=[image_input, text_input, result_output]) | |
| demo.launch(share=True) |