jeyasee commited on
Commit
394311c
·
1 Parent(s): 0a600c0

Add application file

Browse files
Files changed (2) hide show
  1. app.py +53 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import requests
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import gradio as gr
5
+ from transformers import DetrImageProcessor, DetrForObjectDetection
6
+
7
+ # Load model and processor
8
+ model_name = "facebook/detr-resnet-50"
9
+ processor = DetrImageProcessor.from_pretrained(model_name)
10
+ model = DetrForObjectDetection.from_pretrained(model_name)
11
+
12
+ # Helper to draw boxes
13
+ def draw_boxes(image, outputs, threshold=0.9):
14
+ draw = ImageDraw.Draw(image)
15
+ font = ImageFont.load_default()
16
+
17
+ labels = outputs["labels"]
18
+ boxes = outputs["boxes"]
19
+ scores = outputs["scores"]
20
+
21
+ for score, label, box in zip(scores, labels, boxes):
22
+ if score >= threshold:
23
+ box = [round(i, 2) for i in box.tolist()]
24
+ draw.rectangle(box, outline="red", width=3)
25
+ text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
26
+ draw.text((box[0], box[1]), text, fill="white", font=font)
27
+
28
+ return image
29
+
30
+ # Inference function
31
+ def detect_objects(image):
32
+ inputs = processor(images=image, return_tensors="pt")
33
+ with torch.no_grad():
34
+ outputs = model(**inputs)
35
+
36
+ # Post-process results
37
+ target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
38
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
39
+ processed_img = image.copy()
40
+ result_img = draw_boxes(processed_img, results)
41
+ return result_img
42
+
43
+ # Gradio interface
44
+ app = gr.Interface(
45
+ fn=detect_objects,
46
+ inputs=gr.Image(type="pil"),
47
+ outputs=gr.Image(type="pil"),
48
+ title="Image Detector Agent",
49
+ description="Upload an image to detect objects using a pretrained DETR model."
50
+ )
51
+
52
+ if __name__ == "__main__":
53
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ matplotlib
5
+ transformers
6
+ pandas
7
+ numpy
8
+ seaborn
9
+ scikit-learn
10
+ opencv-python
11
+ altair<5