# app_do / me_app.py
# (Uploaded as me_app.py by eng-hassan, commit 5e2ec8d verified)
import gradio as gr
import matplotlib.pyplot as plt
from io import BytesIO
from torchvision import transforms
from torchvision.models import resnet50
from PIL import Image
import torch
import torch.nn as nn
from ultralytics import YOLO
# ---------------------------------------------------------------------------
# Model setup (runs once at import time)
# ---------------------------------------------------------------------------

# Detection model: fine-tuned YOLO weights.
# NOTE(review): hard-coded absolute path — move to a config/env variable so the
# app runs outside this machine.
yolo_model = YOLO('/home/alsufyh/myfinal_project/YOLO11/runs/detect/train21/weights/best.pt')

# Classification model: ResNet-50 with the final fully-connected layer
# replaced to emit 5 classes.
resnet_model = resnet50()
resnet_model.fc = nn.Linear(2048, 5)  # Modify for 5 classes

state_dict_path = "/home/alsufyh/myfinal_project/classification/final/drug_classification_model.pth"
# Bug fix: map_location="cpu" lets a checkpoint that was saved on a GPU load
# on a CPU-only host (the original raised on machines without CUDA).
state_dict = torch.load(state_dict_path, map_location="cpu")
# NOTE(review): strict=False silently ignores missing/unexpected keys — if the
# checkpoint doesn't exactly match this architecture the model can end up with
# randomly initialized layers. Consider strict=True and surfacing the error.
resnet_model.load_state_dict(state_dict, strict=False)
resnet_model.eval()  # inference mode: freezes dropout / batch-norm statistics

# Preprocessing for ResNet: resize, tensorize, standard ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Function to process the image and text input
def process_image_and_notes(image, text):
    """Detect objects in *image* with YOLO, refine low-confidence boxes with
    ResNet, and return a summary plus an annotated copy of the image.

    Args:
        image: ``PIL.Image`` uploaded by the user.
        text:  Free-form notes entered by the user (echoed into the summary).

    Returns:
        tuple: ``(summary_str, annotated_pil_image)`` for the two Gradio outputs.
    """
    results = yolo_model.predict(image)

    final_results = []
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        ax.axis("off")

        for box in results[0].boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])  # bounding-box corners
            yolo_class = int(box.cls.item())        # YOLO-predicted class id
            confidence = box.conf.item()            # YOLO confidence score

            # Crop the region of interest for possible re-classification.
            crop = image.crop((x1, y1, x2, y2))

            # If YOLO is unsure, re-classify the crop with ResNet.
            # NOTE(review): this assumes ResNet's 5 class indices use the same
            # label ordering as YOLO's — verify both models share one mapping.
            if confidence < 0.8:
                input_tensor = transform(crop).unsqueeze(0)  # add batch dim
                with torch.no_grad():
                    output = resnet_model(input_tensor)
                    refined_conf, refined_cls = torch.max(
                        torch.softmax(output, dim=1), dim=1
                    )
                yolo_class = refined_cls.item()
                confidence = refined_conf.item()

            final_results.append({
                "object": yolo_class,
                "confidence": confidence
            })

            # Draw the bounding box and label onto the figure.
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       linewidth=2, edgecolor="red",
                                       facecolor="none"))
            ax.text(x1, y1 - 10, f"Class: {yolo_class}, Conf: {confidence:.2f}",
                    color="white", fontsize=12,
                    bbox=dict(facecolor="red", alpha=0.5))

        # Render the annotated figure into an in-memory PNG.
        buffer = BytesIO()
        fig.savefig(buffer, format="png")
        buffer.seek(0)
    finally:
        # Bug fix: the original never closed the figure, leaking one figure's
        # memory per request in this long-running Gradio server.
        plt.close(fig)

    # Bug fix: gr.Image expects a PIL image / numpy array / file path, not raw
    # PNG bytes, so decode the buffer back into a PIL image for display.
    annotated = Image.open(buffer)
    return f"Detected objects:\n{text}\nResults: {final_results}", annotated
# Function for thumbs-up feedback
def thumbs_up_fn(image, text, output):
    """Handle a positive-feedback click.

    Receives the current image, notes, and result so feedback could later be
    persisted; for now it only logs to stdout and returns ``None``.
    """
    print("Thumbs up feedback received.")
# Function for thumbs-down feedback
def thumbs_down_fn(image, text, output):
    """Handle a negative-feedback click.

    Receives the current image, notes, and result so feedback could later be
    persisted; for now it only logs to stdout and returns ``None``.
    """
    print("Thumbs down feedback received.")
# Gradio interface
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Narcotics Detection")

    gr.HTML('<h2>Upload Image</h2>')
    image_input = gr.Image(type="pil", label="Upload Image", show_label=False)

    gr.HTML('<h2>Enter Notes</h2>')
    text_input = gr.TextArea(label="Enter Notes", show_label=False)

    submit_button = gr.Button("Submit")
    result_output = gr.Textbox(label="Result", interactive=False)
    processed_image_output = gr.Image(label="Processed Image")

    # Feedback row stays hidden until the first result is produced.
    with gr.Row(visible=False) as feedback_row:
        gr.Markdown("**Was the result helpful?**")
        # Bug fix: the button labels were mojibake ("πŸ‘" / "πŸ‘Ž" — UTF-8
        # emoji bytes decoded with the wrong codec); restored to 👍 / 👎.
        thumbs_up = gr.Button("👍")
        thumbs_down = gr.Button("👎")

    # Run inference, then reveal the feedback row.
    submit_button.click(
        fn=process_image_and_notes,
        inputs=[image_input, text_input],
        outputs=[result_output, processed_image_output],
    )
    submit_button.click(lambda: gr.update(visible=True),
                        inputs=None, outputs=[feedback_row])

    # Feedback handlers produce no outputs; they only log server-side.
    thumbs_up.click(fn=thumbs_up_fn,
                    inputs=[image_input, text_input, result_output])
    thumbs_down.click(fn=thumbs_down_fn,
                      inputs=[image_input, text_input, result_output])

# share=True opens a public tunnel URL in addition to the local server.
demo.launch(share=True)