Mohansai2004 committed on
Commit
0290b84
·
verified ·
1 Parent(s): b6e5d1c

Update app/caption_model.py

Browse files
Files changed (1) hide show
  1. app/caption_model.py +43 -1
app/caption_model.py CHANGED
@@ -11,7 +11,49 @@ def caption_image(image: Image.Image):
11
  raise ValueError("Input must be a valid PIL Image in RGB or grayscale format")
12
 
13
  # Run object detection
14
- results = detector(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Track highest score per object
17
  objects_dict = {}
 
11
  raise ValueError("Input must be a valid PIL Image in RGB or grayscale format")
12
 
13
  # Run object detection
14
+ from transformers import pipeline
15
+ from PIL import Image
16
+
17
+ # Load object detection model
18
+ MODEL_NAME = "hustvl/yolos-small"
19
+ detector = pipeline("object-detection", model=MODEL_NAME)
20
+
21
def caption_image(image: Image.Image):
    """Detect objects in an image and summarise them as a caption.

    Args:
        image: A PIL image whose mode is 'RGB' or 'L' (grayscale).

    Returns:
        A dict with three keys:
        "caption" - a human-readable summary listing each label with its
        score, or "No objects detected." when nothing was found;
        "objects" - a list of {"label", "score"} dicts, one per distinct
        label, ordered by descending score, scores rounded to two decimals;
        "confidence" - the highest score seen, rounded to two decimals
        (0.0 when there are no detections).

    Raises:
        ValueError: If ``image`` is not a PIL image in RGB or grayscale mode.
    """
    # Guard clause: only PIL images in a supported colour mode are accepted.
    is_supported = isinstance(image, Image.Image) and image.mode in ('RGB', 'L')
    if not is_supported:
        raise ValueError("Input must be a valid PIL Image in RGB or grayscale format")

    # Each detection is read below via its 'label' and 'score' entries.
    # NOTE(review): confirm the object-detection pipeline actually honours
    # `top_k`; it may only act on `threshold`.
    detections = detector(image, top_k=20, threshold=0.2)

    # A label can be detected several times; keep only its best score.
    best_by_label = {}
    for detection in detections:
        name = detection['label']
        confidence = detection['score']
        if name not in best_by_label or confidence > best_by_label[name]:
            best_by_label[name] = confidence

    # Order labels from most to least confident (stable for equal scores).
    ranked = list(best_by_label.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    objects_list = []
    for name, confidence in ranked:
        objects_list.append({"label": name, "score": round(confidence, 2)})

    # Build the readable caption from the already-rounded scores.
    if objects_list:
        parts = [f"{entry['label']} ({entry['score']:.2f})" for entry in objects_list]
        caption = "Detected objects: " + ", ".join(parts)
    else:
        caption = "No objects detected."

    # Overall confidence is the single best score, or 0.0 with no detections.
    top_score = max(best_by_label.values(), default=0.0)

    return {
        "caption": caption,
        "objects": objects_list,
        "confidence": round(top_score, 2)
    }
57
 
58
  # Track highest score per object
59
  objects_dict = {}