SathvikGanta commited on
Commit
b487594
·
verified ·
1 Parent(s): 457fad5

Update symbol_detection.py

Browse files
Files changed (1) hide show
  1. symbol_detection.py +10 -5
symbol_detection.py CHANGED
@@ -1,11 +1,16 @@
1
- from transformers import DetrFeatureExtractor, DetrForObjectDetection
 
2
 
3
- extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
 
4
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
5
 
6
  def detect_elements(image):
7
  """Detects elements within the diagram."""
8
- inputs = extractor(images=image, return_tensors="pt")
9
  outputs = model(**inputs)
10
- labels = outputs.logits.argmax(-1).numpy()
11
- return labels
 
 
 
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
 
4
+ # Use the new DetrImageProcessor
5
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
6
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
7
 
8
  def detect_elements(image):
9
  """Detects elements within the diagram."""
10
+ inputs = processor(images=image, return_tensors="pt")
11
  outputs = model(**inputs)
12
+
13
+ # Get the labels of the detected objects
14
+ logits = outputs.logits.argmax(-1).numpy()
15
+ return logits
16
+