jenshen69 commited on
Commit
b8d4d64
·
verified ·
1 Parent(s): fab4ffd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -10
app.py CHANGED
@@ -3,24 +3,58 @@ from transformers import DetrImageProcessor, DetrForObjectDetection
3
  import torch
4
  from PIL import Image
5
 
 
6
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
7
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
8
 
9
  def detect_objects(image):
 
 
 
 
 
10
  inputs = processor(images=image, return_tensors="pt")
11
- outputs = model(**inputs)
 
 
 
12
  target_sizes = torch.tensor([image.size[::-1]])
13
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Подготовка данных для AnnotatedImage
16
- boxes = results["boxes"].tolist()
17
- labels = [model.config.id2label[label.item()] for label in results["labels"]]
18
- return (image, boxes, labels) # формат: (изображение, bbox, метки)
 
 
19
 
 
20
  demo = gr.Interface(
21
  fn=detect_objects,
22
- inputs=gr.Image(type="pil"),
23
- outputs=gr.AnnotatedImage(), # Показывает bbox на изображении
24
- title="Детектор драк"
 
 
 
 
 
 
 
 
 
25
  )
26
- demo.launch()
 
 
 
3
  import torch
4
  from PIL import Image
5
 
6
+ # Загрузка модели и процессора (кешируется при первом запуске)
7
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
8
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
9
 
10
  def detect_objects(image):
11
+ # Преобразуем входное изображение
12
+ if isinstance(image, str): # если путь к файлу
13
+ image = Image.open(image)
14
+
15
+ # Детекция объектов
16
  inputs = processor(images=image, return_tensors="pt")
17
+ with torch.no_grad():
18
+ outputs = model(**inputs)
19
+
20
+ # Постобработка результатов
21
  target_sizes = torch.tensor([image.size[::-1]])
22
+ results = processor.post_process_object_detection(
23
+ outputs,
24
+ target_sizes=target_sizes,
25
+ threshold=0.7
26
+ )[0]
27
+
28
+ # Форматирование результатов для AnnotatedImage
29
+ annotations = []
30
+ for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
31
+ box = [round(i, 2) for i in box.tolist()] # округляем координаты
32
+ label_text = f"{model.config.id2label[label.item()]} ({round(score.item(), 2)})"
33
+ annotations.append((box, label_text))
34
 
35
+ # Проверка на наличие драк (если обнаружено >1 человека)
36
+ people_count = sum(1 for label in results["labels"] if label.item() == 1)
37
+ if people_count >= 2:
38
+ annotations.append(([0, 0, 100, 30], "⚠️ Potential fight!"))
39
+
40
+ return (image, annotations)
41
 
42
+ # Создание интерфейса
43
  demo = gr.Interface(
44
  fn=detect_objects,
45
+ inputs=gr.Image(type="pil", label="Input Image"),
46
+ outputs=gr.AnnotatedImage(
47
+ label="Detection Results",
48
+ show_legend=True
49
+ ),
50
+ title="Fight Detection with DETR",
51
+ description="Upload an image to detect people and potential fights. Model: facebook/detr-resnet-50",
52
+ examples=[
53
+ ["example1.jpg"], # добавьте свои примеры во вкладке Files
54
+ ["example2.jpg"]
55
+ ],
56
+ allow_flagging="never"
57
  )
58
+
59
+ # Запуск приложения
60
+ demo.launch(debug=True)