File size: 3,640 Bytes
b27cd24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
import os
from time import sleep
from tqdm import tqdm
from PIL import Image
from google import genai
from google.genai import types
from pydantic import BaseModel

# =========================
# Gemini setup
# =========================
# The API key is read from the environment so the credential never lives in
# source control. A key that was previously committed in this file should be
# treated as compromised and rotated.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError(
        "GEMINI_API_KEY is not set; export it before running this script."
    )

# Shared client used by analyze_image() for every request.
client = genai.Client(api_key=GEMINI_API_KEY)

MODEL_ID = "gemini-3-flash-preview"  # or newer if available


# =========================
# JSON schema (STRICT)
# =========================
class QueueAnswer(BaseModel):
    # Structured answer for one image; analyze_image() passes this class as
    # `response_schema` in the Gemini request.
    #
    # NOTE(review): documented with `#` comments rather than a class docstring
    # on purpose -- presumably a docstring would be emitted as the schema
    # description and so change what is sent to the model; confirm before
    # adding one.
    #
    # Every categorical field is typed as plain `str`, so the value lists
    # quoted below come only from the text produced by build_prompt(); nothing
    # in this class restricts them.
    number_of_people: int  # prompt: integer
    line_direction: str  # prompt: "towards" | "away" | "sideways-left" | "sideways-right"
    end_visible: str  # prompt: "yes" | "no"
    end_location: str  # prompt: "far left" | "center left" | "center" | "center right" | "far right" | "N/A"
    end_camera_direction: str  # prompt: "left" | "right" | "back" | "N/A"
    end_person_description: str  # prompt: free-form string
    start_visible: str  # prompt: "yes" | "no"
    start_location: str  # prompt: same location list as end_location
    start_camera_direction: str  # prompt: "left" | "right" | "back" | "N/A"
    start_person_description: str  # prompt: free-form string


# =========================
# Prompt (same as GPT version)
# =========================
def build_prompt():
    """Return the instruction text sent to the model with each image.

    The text asks for strict JSON and then lists every expected field, one
    per line, as ``name: type-or-allowed-values``.
    """
    yes_no = '["yes","no"]'
    locations = '["far left","center left","center","center right","far right","N/A"]'
    camera_sides = '["left","right","back","N/A"]'

    field_specs = [
        ("number_of_people", "integer"),
        ("line_direction", '["towards", "away", "sideways-left", "sideways-right"]'),
    ]
    # The end and the start of the line are described by the same four
    # fields, end first.
    for endpoint in ("end", "start"):
        field_specs.extend(
            [
                (f"{endpoint}_visible", yes_no),
                (f"{endpoint}_location", locations),
                (f"{endpoint}_camera_direction", camera_sides),
                (f"{endpoint}_person_description", "string"),
            ]
        )

    header = (
        "You are an expert at analyzing a single image of a line of people.\n"
        "Return STRICT JSON only.\n\n"
        "Fields:\n"
    )
    return header + "".join(f"{name}: {spec}\n" for name, spec in field_specs)


# =========================
# Single image inference
# =========================
def analyze_image(img_path):
    """Send one image to Gemini and return the raw response text.

    Args:
        img_path: Path of the image file to analyse.

    Returns:
        ``response.text`` of the Gemini reply. The request asks for
        ``application/json`` output constrained by ``QueueAnswer``, but the
        text is returned unparsed; decoding is left to the caller.

    Raises:
        Nothing is caught here: any exception from ``Image.open``,
        ``thumbnail`` or ``client.models.generate_content`` propagates.
    """
    # Context manager so the underlying file is released even when the
    # request fails. The request must stay inside the block because the
    # image is unusable once closed.
    with Image.open(img_path) as image:
        # Shrinks the image in place so it fits within 512x512 (aspect ratio
        # kept). NOTE(review): the original comment called this a "Gemini
        # requirement" -- unverified; confirm against the API limits.
        image.thumbnail((512, 512))

        response = client.models.generate_content(
            model=MODEL_ID,
            contents=[
                "Return ONLY JSON.",
                build_prompt(),
                image,
            ],
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=QueueAnswer,
                temperature=0.2,
            ),
        )

    return response.text


# =========================
# Batch processing
# =========================
def generate_reranking(image_paths, output_file):
    """Analyse every image and append one record per image to *output_file*.

    Each record is three lines: the image's base name, a single-line JSON
    object, and a blank line. Note that this is not one-object-per-line
    JSONL despite the ``.jsonl`` name used by the caller in this file.

    The JSON object is one of:
      * the decoded model answer,
      * ``{"error": "invalid_json", "raw": <text>}`` when the reply could not
        be decoded as JSON,
      * ``{"error": <message>}`` when ``analyze_image`` raised.

    Args:
        image_paths: Iterable of image file paths to process, in order.
        output_file: Path of the file to append to (opened in ``"a"`` mode,
            so earlier results are kept).

    A failure on one image is recorded and the loop moves on to the next
    image; only errors from writing the output file propagate.
    """
    with open(output_file, "a", encoding="utf-8") as f:
        for img_path in tqdm(image_paths):
            basename = os.path.basename(img_path)

            try:
                raw = analyze_image(img_path)
            except Exception as e:
                # Batch boundary: keep going, but leave a trace both on the
                # console and in the output file.
                print(f"Error: {img_path} -> {e}")
                record = {"error": str(e)}
            else:
                try:
                    record = json.loads(raw)
                except (TypeError, ValueError):
                    # ValueError covers json.JSONDecodeError; TypeError
                    # covers a reply whose text is None.
                    record = {"error": "invalid_json", "raw": raw}

            f.write(basename + "\n")
            f.write(json.dumps(record) + "\n\n")
            # Flush after every record, errors included, so an interrupted
            # run loses at most the image in flight.
            f.flush()

            # Short pause between requests.
            sleep(0.2)


# =========================
# Load images (same as before)
# =========================
root = "/scratch/ds5725/linefinder/LineFinder/Images"
subfolders = ["QueuesOutdoors", "QueuesInSupermarketNew", "QueuesInThemeParks"]

# Absolute path of every file found under each dataset subfolder (nested
# directories included), gathered into a single sorted list.
all_files = sorted(
    os.path.abspath(os.path.join(parent, filename))
    for subfolder in subfolders
    for parent, _dirs, filenames in os.walk(os.path.join(root, subfolder))
    for filename in filenames
)

# Results are appended to this file, resolved against the working directory.
generate_reranking(all_files, "gemini_line_luna_olivia.jsonl")