3v324v23 committed on
Commit
eff2be4
·
0 Parent(s):
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Large files, to be downloaded via huggingface.
g3/index/G3.index
g3/checkpoints/mercator_finetune_weight.pth
g3/data/mp16/MP16_Pro_filtered.csv
index
checkpoints
data

# venv and dev stuff
linuxenv
myenv
.venv
.env
# Secrets: API credentials must never be committed.
acmmm2025-grand-challenge-gg-credentials.json
cred.json
**/__pycache__/
pyproject.toml
uv.lock
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Slim Python base; bullseye variant provides the apt packages installed below.
FROM python:3.12-slim-bullseye
WORKDIR /code

# ffmpeg for media processing; xvfb provides the virtual display the
# entrypoint starts for Playwright's browser.
RUN apt-get update && apt-get install -y ffmpeg xvfb

# Install Python dependencies first so this layer is cached across source changes.
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
RUN playwright install chrome
RUN playwright install-deps
# NOTE(review): this installs all default Playwright browsers in addition to
# Chrome above — likely redundant; confirm which browsers are actually used.
RUN playwright install

COPY ./src /code/src
# One-time setup step baked into the image build.
RUN python /code/src/setup.py

COPY ./app.py /code/app.py
COPY ./entrypoint.sh /code/entrypoint.sh

RUN chmod +x /code/entrypoint.sh
ENTRYPOINT [ "/code/entrypoint.sh" ]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # G3 Geolocation Service
2
+
3
+ This is a containerized geolocation service based on the paper "G3: An Effective and Adaptive Framework for Worldwide Geolocalization Using Large Multi-Modality Models". The service is augmented with multilayer verification for location and evidence.
4
+
5
+ ## Prerequisites
6
+
7
+ - Docker with GPU support
8
+ - NVIDIA Container Toolkit (for GPU access)
9
+ - Required API keys (see Environment Variables section)
10
+
11
+ ## Quick Start
12
+
13
+ ### 1. Prepare Environment File
14
+
15
+ Create a `.env` file with the following variables:
16
+
17
+ ```bash
18
+ GOOGLE_CLOUD_API_KEY=your_google_cloud_api_key
19
+ GOOGLE_CSE_CX=your_google_custom_search_engine_id
20
+ SCRAPINGDOG_API_KEY=your_scrapingdog_api_key
21
+ IMGBB_API_KEY=your_imgbb_api_key
22
+ GOOGLE_APPLICATION_CREDENTIALS=/code/path/to/your/credentials.json
23
+ ```
24
+
25
+ ### 2. Prepare Google Cloud Credentials
26
+
27
+ Ensure you have a Google Cloud service account JSON credentials file ready for copying to the container.
28
+
29
+ ### 3. Build Docker Image
30
+
31
+ ```bash
32
+ docker build -t g3-geolocation .
33
+ ```
34
+
35
+ ### 4. Create Docker Container
36
+
37
+ ```bash
38
+ docker create --name g3-container -p 80:80 --gpus=all --env-file .env g3-geolocation
39
+ ```
40
+
41
+ ### 5. Copy Credentials to Container
42
+
43
+ ```bash
44
+ docker cp /path/to/your/credentials.json g3-container:/code/
45
+ ```
46
+
47
+ ### 6. Start Container
48
+
49
+ ```bash
50
+ docker start g3-container
51
+ ```
52
+
53
+ ## Usage
54
+
55
+ Once the container is running, the service will be available at `http://localhost:80`.
56
+
57
+ ### API Endpoints
58
+
59
+ - **POST** `/g3/predict` - Submit images/videos for geolocation prediction
60
+ - **GET** `/g3/openapi` - Get OpenAPI specification
61
+
62
+ ### Example Request
63
+
64
+ ```bash
65
+ curl -X POST "http://localhost:80/g3/predict" \
66
+ -H "Content-Type: multipart/form-data" \
67
+ -F "files=@your_image.jpg"
68
+ ```
69
+
70
+ ## Environment Variables
71
+
72
+ | Variable | Description | Required |
73
+ | -------------------------------- | ------------------------------------------ | -------- |
74
+ | `GOOGLE_CLOUD_API_KEY` | Google Cloud API key for Gemini and Custom Google Search API | Yes |
75
+ | `GOOGLE_CSE_CX` | Google Custom Search Engine ID | Yes |
76
+ | `SCRAPINGDOG_API_KEY` | ScrapingDog API key for web scraping | Yes |
77
+ | `IMGBB_API_KEY` | ImgBB API key for image hosting | Yes |
78
+ | `GOOGLE_APPLICATION_CREDENTIALS` | Path to Google Cloud credentials JSON file | Yes |
79
+
80
+ ## API Keys Setup
81
+
82
+ ### Google Cloud API Key
83
+
84
+ 1. Go to [Google Cloud Console](https://console.cloud.google.com/)
85
+ 2. Enable Gemini API and Vision API
86
+ 3. Create an API key in the Credentials section
87
+
88
+ ### Google Custom Search Engine
89
+
90
+ 1. Go to [Google Custom Search](https://cse.google.com/)
91
+ 2. Create a new search engine
92
+ 3. Copy the Search Engine ID (CX)
93
+
94
+ ### ScrapingDog API Key
95
+
96
+ 1. Sign up at [ScrapingDog](https://scrapingdog.com/)
97
+ 2. Get your API key from the dashboard
98
+
99
+ ### ImgBB API Key
100
+
101
+ 1. Sign up at [ImgBB](https://imgbb.com/)
102
+ 2. Get your API key from the API section
103
+
104
+ ## Container Management
105
+
106
+ ### View Logs
107
+
108
+ ```bash
109
+ docker logs g3-container
110
+ ```
111
+
112
+ ### Stop Container
113
+
114
+ ```bash
115
+ docker stop g3-container
116
+ ```
117
+
118
+ ### Remove Container
119
+
120
+ ```bash
121
+ docker rm g3-container
122
+ ```
123
+
124
+ ### Remove Image
125
+
126
+ ```bash
127
+ docker rmi g3-geolocation
128
+ ```
129
+
130
+ ## Troubleshooting
131
+
132
+ ### GPU Access Issues
133
+
134
+ Ensure NVIDIA Container Toolkit is properly installed:
135
+
136
+ ```bash
137
+ nvidia-smi
138
+ docker run --rm --gpus all nvidia/cuda:11.0-base-ubuntu20.04 nvidia-smi
139
+ ```
140
+
141
+ ### API Key Issues
142
+
143
+ - Verify all API keys are valid and have proper permissions
144
+ - Check that the credentials file is properly copied to the container
145
+ - Ensure the `GOOGLE_APPLICATION_CREDENTIALS` path matches the copied file location
146
+
147
+ ### Memory Issues
148
+
149
+ If you encounter out-of-memory errors, consider:
150
+
151
+ - Reducing image sizes before upload
152
+ - Using a machine with more RAM/VRAM
153
+ - Adjusting batch processing parameters
154
+
155
+ ## Citation
156
+
157
+ ```bib
158
+ @article{jia2024g3,
159
+ title={G3: an effective and adaptive framework for worldwide geolocalization using large multi-modality models},
160
+ author={Jia, Pengyue and Liu, Yiding and Li, Xiaopeng and Zhao, Xiangyu and Wang, Yuhao and Du, Yantong and Han, Xiao and Wei, Xuetao and Wang, Shuaiqiang and Yin, Dawei},
161
+ journal={Advances in Neural Information Processing Systems},
162
+ volume={37},
163
+ pages={53198--53221},
164
+ year={2024}
165
+ }
166
+ ```
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import shutil
4
+ import uuid
5
+ from contextlib import asynccontextmanager
6
+ from typing import Annotated, Optional
7
+
8
+ import torch
9
+ from dotenv import load_dotenv
10
+ from fastapi import FastAPI, File, HTTPException, UploadFile, status
11
+ from pydantic import BaseModel, Field
12
+
13
+ from src.g3_batch_prediction import G3BatchPredictor
14
+
15
+ from src.utils import load_images_as_base64
16
+
17
+
18
# Response model: one supporting analysis plus optional references (links or
# base64-encoded images). No class docstring on purpose: pydantic would surface
# it as the schema description and change the generated OpenAPI output.
class EvidenceResponse(BaseModel):
    analysis: Annotated[
        str,
        Field(description="A supporting analysis for the prediction."),
    ]
    # Mutable default is safe here: pydantic copies field defaults per instance.
    references: Annotated[
        list[str],
        Field(description="Links or base64-encoded JPEG supporting the analysis."),
    ] = []
27
+
28
+
29
# Response model: the predicted coordinates, a textual location description,
# and the list of supporting evidence items. Comments (not a docstring) so the
# generated OpenAPI schema is unchanged.
class LocationPredictionResponse(BaseModel):
    latitude: Annotated[
        float,
        Field(description="Latitude of the predicted location, in degree."),
    ]
    longitude: Annotated[
        float,
        Field(description="Longitude of the predicted location, in degree."),
    ]
    location: Annotated[
        str,
        Field(description="Textual description of the predicted location."),
    ]
    evidence: Annotated[
        list[EvidenceResponse],
        Field(description="List of supporting analyses for the prediction."),
    ]
46
+
47
+
48
# Top-level response returned by POST /g3/predict: the prediction plus optional
# transcript and processed media.
class PredictionResponse(BaseModel):
    prediction: Annotated[
        LocationPredictionResponse,
        Field(description="The location prediction and accompanying analysis."),
    ]
    transcript: Annotated[
        str | None,
        Field(description="The extracted and concatenated transcripts, if any."),
    ] = None
    # Declared in the same Annotated[...] style as the sibling fields for
    # consistency; the generated JSON schema is identical to the previous
    # `Optional[list[str]] = Field(default=None, ...)` form.
    media: Annotated[
        list[str] | None,
        Field(description="List of media files processed during prediction."),
    ] = None
61
+
62
+
63
# Module-level predictor shared by all requests; initialized in `lifespan`.
predictor: G3BatchPredictor
64
+
65
+
66
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: runs once at startup (before ``yield``) and once
    at shutdown (after).

    Startup loads environment config, dumps the OpenAPI spec to disk, and
    builds the heavyweight predictor shared by all requests.
    """
    load_dotenv()

    # Persist the OpenAPI schema so it can be inspected outside the app.
    # Explicit UTF-8 avoids depending on the container's locale default.
    with open("openapi.json", "wt", encoding="utf-8") as api_file:
        json.dump(app.openapi(), api_file, indent=4)

    # Heavyweight model/index load; exposed through the module-level global.
    global predictor
    predictor = G3BatchPredictor(device="cuda" if torch.cuda.is_available() else "cpu")

    yield

    # Release the predictor's resources on shutdown.
    del predictor
79
+
80
+
81
# FastAPI application; `lifespan` handles predictor construction/teardown.
app = FastAPI(
    lifespan=lifespan,
    title="G3",
    description="An endpoint to predict GPS coordinate from static image,"
    " using G3 Framework.",
)
87
+
88
+
89
@app.post(
    "/g3/predict",
    description="Provide location prediction.",
)
async def predict_endpoint(
    files: Annotated[
        list[UploadFile],
        File(description="Input images, videos and metadata json."),
    ],
) -> PredictionResponse:
    """Persist the uploaded files, run the G3 predictor, and return the result.

    Raises:
        HTTPException: 500 when an uploaded file cannot be written to disk.
    """
    # Write files to disk. The predictor's directories are reset first so
    # inputs from a previous request cannot leak into this prediction.
    try:
        predictor.clear_directories()
        # Create the input directory once, not once per uploaded file.
        os.makedirs(predictor.input_dir, exist_ok=True)
        for file in files:
            # Uploads may arrive without a filename; fall back to a random one.
            filename = file.filename if file.filename is not None else uuid.uuid4().hex
            filepath = predictor.input_dir / filename
            with open(filepath, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        # Chain the original error so the root cause survives in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to save file: {e}",
        ) from e

    # Run the prediction with the configured multimodal model.
    response = await predictor.predict(model_name="gemini-2.5-pro")
    prediction = LocationPredictionResponse(
        latitude=response.latitude,
        longitude=response.longitude,
        location=response.location,
        evidence=[
            EvidenceResponse(analysis=ev.analysis, references=ev.references)
            for ev in response.evidence
        ],
    )
    # Transcript extracted during prediction, if any (None otherwise).
    transcript = predictor.get_transcript()

    # Processed media, encoded as base64 strings for the response payload.
    images_b64 = load_images_as_base64()

    return PredictionResponse(prediction=prediction, transcript=transcript, media=images_b64)
134
+
135
+
136
@app.get(
    "/g3/openapi",
    description="Provide the OpenAPI JSON describing this service's endpoints.",
)
async def openapi():
    # Delegate to FastAPI's (cached) schema generator.
    schema = app.openapi()
    return schema
docker-compose.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
services:
  web:
    build: .
    ports:
      # Host port 8000 -> container port 80 (where the API listens).
      - "8000:80"
    environment:
      # Must match the path the credentials file is mounted at below.
      - GOOGLE_APPLICATION_CREDENTIALS=/code/keys/credentials.json
    volumes:
      - ./.env:/code/.env

      - ./keys:/code/keys

      - ./entrypoint.sh:/code/entrypoint.sh

    env_file:
      - ./.env

    restart: unless-stopped
    # GPU access; requires the NVIDIA Container Toolkit on the host.
    runtime: nvidia
entrypoint.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
set -euo pipefail

# --- cleanup any stale Xvfb lock/socket ---
for stale in /tmp/.X99-lock /tmp/.X11-unix/X99; do
    if [ -e "$stale" ]; then
        echo "[entrypoint] removing stale $stale" >&2
        rm -f "$stale"
    fi
done

# --- start the virtual display ---
echo "[entrypoint] starting Xvfb on :99" >&2
Xvfb :99 -screen 0 1920x1080x24 &

# --- point GUI apps at it ---
export DISPLAY=:99
echo "[entrypoint] DISPLAY set to $DISPLAY" >&2

# --- launch FastAPI ---
echo "[entrypoint] exec fastapi" >&2
exec fastapi run app.py --port 80
openapi.json ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "openapi": "3.1.0",
3
+ "info": {
4
+ "title": "G3",
5
+ "description": "An endpoint to predict GPS coordinate from static image, using G3 Framework.",
6
+ "version": "0.1.0"
7
+ },
8
+ "paths": {
9
+ "/g3/predict": {
10
+ "post": {
11
+ "summary": "Predict Endpoint",
12
+ "description": "Provide location prediction.",
13
+ "operationId": "predict_endpoint_g3_predict_post",
14
+ "requestBody": {
15
+ "content": {
16
+ "multipart/form-data": {
17
+ "schema": {
18
+ "$ref": "#/components/schemas/Body_predict_endpoint_g3_predict_post"
19
+ }
20
+ }
21
+ },
22
+ "required": true
23
+ },
24
+ "responses": {
25
+ "200": {
26
+ "description": "Successful Response",
27
+ "content": {
28
+ "application/json": {
29
+ "schema": {
30
+ "$ref": "#/components/schemas/PredictionResponse"
31
+ }
32
+ }
33
+ }
34
+ },
35
+ "422": {
36
+ "description": "Validation Error",
37
+ "content": {
38
+ "application/json": {
39
+ "schema": {
40
+ "$ref": "#/components/schemas/HTTPValidationError"
41
+ }
42
+ }
43
+ }
44
+ }
45
+ }
46
+ }
47
+ },
48
+ "/g3/openapi": {
49
+ "get": {
50
+ "summary": "Openapi",
51
+ "description": "Provide the OpenAPI JSON describing this service's endpoints.",
52
+ "operationId": "openapi_g3_openapi_get",
53
+ "responses": {
54
+ "200": {
55
+ "description": "Successful Response",
56
+ "content": {
57
+ "application/json": {
58
+ "schema": {}
59
+ }
60
+ }
61
+ }
62
+ }
63
+ }
64
+ }
65
+ },
66
+ "components": {
67
+ "schemas": {
68
+ "Body_predict_endpoint_g3_predict_post": {
69
+ "properties": {
70
+ "files": {
71
+ "items": {
72
+ "type": "string",
73
+ "format": "binary"
74
+ },
75
+ "type": "array",
76
+ "title": "Files",
77
+ "description": "Input images, videos and metadata json."
78
+ }
79
+ },
80
+ "type": "object",
81
+ "required": [
82
+ "files"
83
+ ],
84
+ "title": "Body_predict_endpoint_g3_predict_post"
85
+ },
86
+ "EvidenceResponse": {
87
+ "properties": {
88
+ "analysis": {
89
+ "type": "string",
90
+ "title": "Analysis",
91
+ "description": "A supporting analysis for the prediction."
92
+ },
93
+ "references": {
94
+ "items": {
95
+ "type": "string"
96
+ },
97
+ "type": "array",
98
+ "title": "References",
99
+ "description": "Links or base64-encoded JPEG supporting the analysis.",
100
+ "default": []
101
+ }
102
+ },
103
+ "type": "object",
104
+ "required": [
105
+ "analysis"
106
+ ],
107
+ "title": "EvidenceResponse"
108
+ },
109
+ "HTTPValidationError": {
110
+ "properties": {
111
+ "detail": {
112
+ "items": {
113
+ "$ref": "#/components/schemas/ValidationError"
114
+ },
115
+ "type": "array",
116
+ "title": "Detail"
117
+ }
118
+ },
119
+ "type": "object",
120
+ "title": "HTTPValidationError"
121
+ },
122
+ "LocationPredictionResponse": {
123
+ "properties": {
124
+ "latitude": {
125
+ "type": "number",
126
+ "title": "Latitude",
127
+ "description": "Latitude of the predicted location, in degree."
128
+ },
129
+ "longitude": {
130
+ "type": "number",
131
+ "title": "Longitude",
132
+ "description": "Longitude of the predicted location, in degree."
133
+ },
134
+ "location": {
135
+ "type": "string",
136
+ "title": "Location",
137
+ "description": "Textual description of the predicted location."
138
+ },
139
+ "evidence": {
140
+ "items": {
141
+ "$ref": "#/components/schemas/EvidenceResponse"
142
+ },
143
+ "type": "array",
144
+ "title": "Evidence",
145
+ "description": "List of supporting analyses for the prediction."
146
+ }
147
+ },
148
+ "type": "object",
149
+ "required": [
150
+ "latitude",
151
+ "longitude",
152
+ "location",
153
+ "evidence"
154
+ ],
155
+ "title": "LocationPredictionResponse"
156
+ },
157
+ "PredictionResponse": {
158
+ "properties": {
159
+ "prediction": {
160
+ "$ref": "#/components/schemas/LocationPredictionResponse",
161
+ "description": "The location prediction and accompanying analysis."
162
+ },
163
+ "transcript": {
164
+ "anyOf": [
165
+ {
166
+ "type": "string"
167
+ },
168
+ {
169
+ "type": "null"
170
+ }
171
+ ],
172
+ "title": "Transcript",
173
+ "description": "The extracted and concatenated transcripts, if any."
174
+ },
175
+ "media": {
176
+ "anyOf": [
177
+ {
178
+ "items": {
179
+ "type": "string"
180
+ },
181
+ "type": "array"
182
+ },
183
+ {
184
+ "type": "null"
185
+ }
186
+ ],
187
+ "title": "Media",
188
+ "description": "List of media files processed during prediction."
189
+ }
190
+ },
191
+ "type": "object",
192
+ "required": [
193
+ "prediction"
194
+ ],
195
+ "title": "PredictionResponse"
196
+ },
197
+ "ValidationError": {
198
+ "properties": {
199
+ "loc": {
200
+ "items": {
201
+ "anyOf": [
202
+ {
203
+ "type": "string"
204
+ },
205
+ {
206
+ "type": "integer"
207
+ }
208
+ ]
209
+ },
210
+ "type": "array",
211
+ "title": "Location"
212
+ },
213
+ "msg": {
214
+ "type": "string",
215
+ "title": "Message"
216
+ },
217
+ "type": {
218
+ "type": "string",
219
+ "title": "Error Type"
220
+ }
221
+ },
222
+ "type": "object",
223
+ "required": [
224
+ "loc",
225
+ "msg",
226
+ "type"
227
+ ],
228
+ "title": "ValidationError"
229
+ }
230
+ }
231
+ }
232
+ }
requirements.txt ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --no-hashes --format requirements-txt
3
+ annotated-types==0.7.0
4
+ # via pydantic
5
+ anyio==4.9.0
6
+ # via
7
+ # google-genai
8
+ # httpx
9
+ # starlette
10
+ # watchfiles
11
+ cachetools==5.5.2
12
+ # via google-auth
13
+ certifi==2025.7.14
14
+ # via
15
+ # httpcore
16
+ # httpx
17
+ # pyproj
18
+ # requests
19
+ # sentry-sdk
20
+ charset-normalizer==3.4.2
21
+ # via requests
22
+ click==8.2.1
23
+ # via
24
+ # rich-toolkit
25
+ # typer
26
+ # uvicorn
27
+ colorama==0.4.6 ; sys_platform == 'win32'
28
+ # via
29
+ # click
30
+ # tqdm
31
+ # uvicorn
32
+ decorator==5.2.1
33
+ # via moviepy
34
+ dnspython==2.7.0
35
+ # via email-validator
36
+ einops==0.8.1
37
+ # via acmmm25-grand-challenge-geolocation
38
+ email-validator==2.2.0
39
+ # via
40
+ # fastapi
41
+ # pydantic
42
+ faiss-gpu-cu12==1.11.0
43
+ # via acmmm25-grand-challenge-geolocation
44
+ fastapi==0.116.1
45
+ # via acmmm25-grand-challenge-geolocation
46
+ fastapi-cli==0.0.8
47
+ # via fastapi
48
+ fastapi-cloud-cli==0.1.4
49
+ # via fastapi-cli
50
+ ffmpy==0.6.0
51
+ # via katna
52
+ filelock==3.18.0
53
+ # via
54
+ # huggingface-hub
55
+ # torch
56
+ # transformers
57
+ fsspec==2025.7.0
58
+ # via
59
+ # huggingface-hub
60
+ # torch
61
+ ftfy==6.3.1
62
+ # via open-clip-torch
63
+ geographiclib==2.0
64
+ # via geopy
65
+ geopy==2.4.1
66
+ # via acmmm25-grand-challenge-geolocation
67
+ google-api-core==2.25.1
68
+ # via
69
+ # google-cloud-videointelligence
70
+ # google-cloud-vision
71
+ google-auth==2.40.3
72
+ # via
73
+ # google-api-core
74
+ # google-cloud-videointelligence
75
+ # google-cloud-vision
76
+ # google-genai
77
+ google-cloud-videointelligence==2.16.2
78
+ # via acmmm25-grand-challenge-geolocation
79
+ google-cloud-vision==3.10.2
80
+ # via acmmm25-grand-challenge-geolocation
81
+ google-genai==1.26.0
82
+ # via acmmm25-grand-challenge-geolocation
83
+ googleapis-common-protos==1.70.0
84
+ # via
85
+ # google-api-core
86
+ # grpcio-status
87
+ greenlet==3.2.3
88
+ # via playwright
89
+ grpcio==1.73.1
90
+ # via
91
+ # google-api-core
92
+ # grpcio-status
93
+ grpcio-status==1.73.1
94
+ # via google-api-core
95
+ h11==0.16.0
96
+ # via
97
+ # httpcore
98
+ # uvicorn
99
+ hf-xet==1.1.5 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
100
+ # via huggingface-hub
101
+ httpcore==1.0.9
102
+ # via httpx
103
+ httptools==0.6.4
104
+ # via uvicorn
105
+ httpx==0.28.1
106
+ # via
107
+ # fastapi
108
+ # fastapi-cloud-cli
109
+ # google-genai
110
+ huggingface-hub==0.33.4
111
+ # via
112
+ # open-clip-torch
113
+ # timm
114
+ # tokenizers
115
+ # transformers
116
+ idna==3.10
117
+ # via
118
+ # anyio
119
+ # email-validator
120
+ # httpx
121
+ # requests
122
+ imageio==2.37.0
123
+ # via
124
+ # moviepy
125
+ # scikit-image
126
+ imageio-ffmpeg==0.6.0
127
+ # via
128
+ # katna
129
+ # moviepy
130
+ imutils==0.5.4
131
+ # via katna
132
+ jinja2==3.1.6
133
+ # via
134
+ # fastapi
135
+ # torch
136
+ joblib==1.5.1
137
+ # via scikit-learn
138
+ katna==0.9.2
139
+ # via acmmm25-grand-challenge-geolocation
140
+ lazy-loader==0.4
141
+ # via scikit-image
142
+ llvmlite==0.44.0
143
+ # via numba
144
+ markdown-it-py==3.0.0
145
+ # via rich
146
+ markupsafe==3.0.2
147
+ # via jinja2
148
+ mdurl==0.1.2
149
+ # via markdown-it-py
150
+ more-itertools==10.7.0
151
+ # via openai-whisper
152
+ moviepy==2.2.1
153
+ # via acmmm25-grand-challenge-geolocation
154
+ mpmath==1.3.0
155
+ # via sympy
156
+ networkx==3.5
157
+ # via
158
+ # scikit-image
159
+ # torch
160
+ numba==0.61.2
161
+ # via openai-whisper
162
+ numpy==1.26.4
163
+ # via
164
+ # faiss-gpu-cu12
165
+ # imageio
166
+ # katna
167
+ # moviepy
168
+ # numba
169
+ # openai-whisper
170
+ # opencv-contrib-python
171
+ # opencv-python
172
+ # pandas
173
+ # scikit-image
174
+ # scikit-learn
175
+ # scipy
176
+ # tifffile
177
+ # torchvision
178
+ # transformers
179
+ nvidia-cublas-cu12==12.6.4.1
180
+ # via
181
+ # faiss-gpu-cu12
182
+ # nvidia-cudnn-cu12
183
+ # nvidia-cusolver-cu12
184
+ # torch
185
+ nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
186
+ # via torch
187
+ nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
188
+ # via torch
189
+ nvidia-cuda-runtime-cu12==12.6.77
190
+ # via
191
+ # faiss-gpu-cu12
192
+ # torch
193
+ nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
194
+ # via torch
195
+ nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
196
+ # via torch
197
+ nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
198
+ # via torch
199
+ nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
200
+ # via torch
201
+ nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
202
+ # via torch
203
+ nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
204
+ # via
205
+ # nvidia-cusolver-cu12
206
+ # torch
207
+ nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
208
+ # via torch
209
+ nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
210
+ # via torch
211
+ nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
212
+ # via
213
+ # nvidia-cufft-cu12
214
+ # nvidia-cusolver-cu12
215
+ # nvidia-cusparse-cu12
216
+ # torch
217
+ nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
218
+ # via torch
219
+ open-clip-torch==2.32.0
220
+ # via acmmm25-grand-challenge-geolocation
221
+ openai-whisper==20250625
222
+ # via acmmm25-grand-challenge-geolocation
223
+ opencv-contrib-python==4.11.0.86
224
+ # via katna
225
+ opencv-python==4.11.0.86
226
+ # via acmmm25-grand-challenge-geolocation
227
+ packaging==25.0
228
+ # via
229
+ # faiss-gpu-cu12
230
+ # huggingface-hub
231
+ # lazy-loader
232
+ # scikit-image
233
+ # transformers
234
+ pandas==2.3.1
235
+ # via acmmm25-grand-challenge-geolocation
236
+ pillow==11.3.0
237
+ # via
238
+ # acmmm25-grand-challenge-geolocation
239
+ # imageio
240
+ # moviepy
241
+ # scikit-image
242
+ # torchvision
243
+ playwright==1.53.0
244
+ # via acmmm25-grand-challenge-geolocation
245
+ proglog==0.1.12
246
+ # via moviepy
247
+ proto-plus==1.26.1
248
+ # via
249
+ # google-api-core
250
+ # google-cloud-videointelligence
251
+ # google-cloud-vision
252
+ protobuf==6.31.1
253
+ # via
254
+ # google-api-core
255
+ # google-cloud-videointelligence
256
+ # google-cloud-vision
257
+ # googleapis-common-protos
258
+ # grpcio-status
259
+ # proto-plus
260
+ psutil==7.0.0
261
+ # via katna
262
+ pyasn1==0.6.1
263
+ # via
264
+ # pyasn1-modules
265
+ # rsa
266
+ pyasn1-modules==0.4.2
267
+ # via google-auth
268
+ pydantic==2.11.7
269
+ # via
270
+ # fastapi
271
+ # fastapi-cloud-cli
272
+ # google-genai
273
+ pydantic-core==2.33.2
274
+ # via pydantic
275
+ pyee==13.0.0
276
+ # via playwright
277
+ pygments==2.19.2
278
+ # via rich
279
+ pyproj==3.7.1
280
+ # via acmmm25-grand-challenge-geolocation
281
+ python-dateutil==2.9.0.post0
282
+ # via pandas
283
+ python-dotenv==1.1.1
284
+ # via
285
+ # acmmm25-grand-challenge-geolocation
286
+ # moviepy
287
+ # uvicorn
288
+ python-multipart==0.0.20
289
+ # via fastapi
290
+ pytz==2025.2
291
+ # via pandas
292
+ pyyaml==6.0.2
293
+ # via
294
+ # acmmm25-grand-challenge-geolocation
295
+ # huggingface-hub
296
+ # timm
297
+ # transformers
298
+ # uvicorn
299
+ regex==2024.11.6
300
+ # via
301
+ # open-clip-torch
302
+ # tiktoken
303
+ # transformers
304
+ requests==2.32.4
305
+ # via
306
+ # google-api-core
307
+ # google-genai
308
+ # huggingface-hub
309
+ # katna
310
+ # tiktoken
311
+ # transformers
312
+ rich==14.0.0
313
+ # via
314
+ # rich-toolkit
315
+ # typer
316
+ rich-toolkit==0.14.8
317
+ # via
318
+ # fastapi-cli
319
+ # fastapi-cloud-cli
320
+ rignore==0.6.2
321
+ # via fastapi-cloud-cli
322
+ rsa==4.9.1
323
+ # via google-auth
324
+ safetensors==0.5.3
325
+ # via
326
+ # open-clip-torch
327
+ # timm
328
+ # transformers
329
+ scikit-image==0.25.2
330
+ # via katna
331
+ scikit-learn==1.7.0
332
+ # via
333
+ # acmmm25-grand-challenge-geolocation
334
+ # katna
335
+ scipy==1.16.0
336
+ # via
337
+ # katna
338
+ # scikit-image
339
+ # scikit-learn
340
+ sentry-sdk==2.33.0
341
+ # via fastapi-cloud-cli
342
+ setuptools==80.9.0
343
+ # via
344
+ # torch
345
+ # triton
346
+ shellingham==1.5.4
347
+ # via typer
348
+ six==1.17.0
349
+ # via python-dateutil
350
+ sniffio==1.3.1
351
+ # via anyio
352
+ starlette==0.47.1
353
+ # via fastapi
354
+ sympy==1.14.0
355
+ # via torch
356
+ tenacity==8.5.0
357
+ # via google-genai
358
+ threadpoolctl==3.6.0
359
+ # via scikit-learn
360
+ tifffile==2025.6.11
361
+ # via scikit-image
362
+ tiktoken==0.9.0
363
+ # via openai-whisper
364
+ timm==1.0.17
365
+ # via open-clip-torch
366
+ tokenizers==0.21.2
367
+ # via transformers
368
+ torch==2.7.1
369
+ # via
370
+ # acmmm25-grand-challenge-geolocation
371
+ # open-clip-torch
372
+ # openai-whisper
373
+ # timm
374
+ # torchvision
375
+ torchvision==0.22.1
376
+ # via
377
+ # acmmm25-grand-challenge-geolocation
378
+ # open-clip-torch
379
+ # timm
380
+ tqdm==4.67.1
381
+ # via
382
+ # acmmm25-grand-challenge-geolocation
383
+ # huggingface-hub
384
+ # open-clip-torch
385
+ # openai-whisper
386
+ # proglog
387
+ # transformers
388
+ transformers==4.53.2
389
+ # via acmmm25-grand-challenge-geolocation
390
+ triton==3.3.1 ; (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'linux2'
391
+ # via
392
+ # openai-whisper
393
+ # torch
394
+ typer==0.16.0
395
+ # via
396
+ # fastapi-cli
397
+ # fastapi-cloud-cli
398
+ typing-extensions==4.14.1
399
+ # via
400
+ # anyio
401
+ # fastapi
402
+ # google-genai
403
+ # huggingface-hub
404
+ # pydantic
405
+ # pydantic-core
406
+ # pyee
407
+ # rich-toolkit
408
+ # starlette
409
+ # torch
410
+ # typer
411
+ # typing-inspection
412
+ typing-inspection==0.4.1
413
+ # via pydantic
414
+ tzdata==2025.2
415
+ # via pandas
416
+ urllib3==2.5.0
417
+ # via
418
+ # requests
419
+ # sentry-sdk
420
+ uvicorn==0.35.0
421
+ # via
422
+ # fastapi
423
+ # fastapi-cli
424
+ # fastapi-cloud-cli
425
+ uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
426
+ # via uvicorn
427
+ watchfiles==1.1.0
428
+ # via uvicorn
429
+ wcwidth==0.2.13
430
+ # via ftfy
431
+ websockets==15.0.1
432
+ # via
433
+ # google-genai
434
+ # uvicorn
src/data_processor.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import hashlib
6
+ import shutil
7
+ from pathlib import Path
8
+
9
+ import faiss
10
+ import torch
11
+ from PIL import Image
12
+ from torch import nn
13
+
14
+ from .prompt.fetch.content_fetch import fetch_links_to_json
15
+ from .prompt.fetch.satellite_fetch import fetch_satellite_image
16
+ from .prompt.preprocess.keyframe_extract import extract_and_save_keyframes
17
+ from .prompt.preprocess.video_transcribe import transcribe_video_directory
18
+ from .prompt.search.image_search import image_search_directory
19
+ from .prompt.search.index_search import save_results_to_json, search_index_directory
20
+ from .prompt.search.text_search import text_search_image, text_search_link
21
+
22
+ logger = logging.getLogger("uvicorn.error")
23
+
24
+
25
class DataProcessor:
    """Pipeline helper that turns raw input media into prompt-ready artifacts.

    Orchestrates keyframe extraction, video transcription, image/text web
    search, and FAISS index search over a directory of input images/videos.
    """

    def __init__(
        self,
        model: nn.Module,
        input_dir: Path,
        prompt_dir: Path,
        cache_dir: Path,
        image_dir: Path,
        audio_dir: Path,
        index_path: Path,
        database_csv_path: Path,
        device: torch.device,
    ):
        """Store pipeline paths and load the FAISS index.

        Args:
            model: Embedding model used for index search.
            input_dir: Directory containing the raw input images/videos.
            prompt_dir: Directory where prompt artifacts (JSON, transcripts) go.
            cache_dir: Directory used to cache processed input/prompt data.
            image_dir: Directory where extracted/collected images are stored.
            audio_dir: Directory where video transcripts are stored.
            index_path: Path to the serialized FAISS index file.
            database_csv_path: CSV with GPS metadata backing the index.
            device: Torch device used for inference.

        Raises:
            RuntimeError: If the FAISS index cannot be read from ``index_path``.
        """
        self.input_dir = input_dir
        self.prompt_dir = prompt_dir
        self.cache_dir = cache_dir
        self.image_dir = image_dir
        self.audio_dir = audio_dir
        self.model = model
        self.device = device
        self.database_csv_path = database_csv_path

        try:
            self.index = faiss.read_index(str(index_path))
            logger.info(f"✅ Successfully loaded FAISS index from: {index_path}")
        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise RuntimeError(
                f"Failed to load FAISS index from {index_path}: {e}"
            ) from e

        # Lowercase extensions recognized as still images / videos.
        self.image_extension = {
            ".jpg",
            ".jpeg",
            ".png",
            ".bmp",
            ".tiff",
            ".tif",
            ".webp",
        }
        self.video_extension = {
            ".mp4",
            ".avi",
            ".mov",
            ".mkv",
        }
+ def __extract_keyframes(self):
70
+ """
71
+ Extract keyframes from all videos in the input directory.
72
+ Put all images and keyframes into the prompt directory.
73
+ """
74
+ output_dir = self.image_dir
75
+ os.makedirs(output_dir, exist_ok=True)
76
+
77
+ # Determine starting index based on existing files
78
+ current_files = list(output_dir.glob("image_*.*"))
79
+ idx = len(current_files)
80
+
81
+ # Process images
82
+ for file_name in os.listdir(self.input_dir):
83
+ file_path = os.path.join(self.input_dir, file_name)
84
+ if os.path.isfile(file_path) and file_name.lower().endswith(
85
+ tuple(self.image_extension)
86
+ ):
87
+ out_path = output_dir / f"image_{idx:03d}.jpg"
88
+ Image.open(file_path).convert("RGB").save(out_path)
89
+ idx += 1
90
+
91
+ # Process videos
92
+ for file_name in os.listdir(self.input_dir):
93
+ file_path = os.path.join(self.input_dir, file_name)
94
+ if os.path.isfile(file_path) and file_name.lower().endswith(
95
+ tuple(self.video_extension)
96
+ ):
97
+ if idx is None:
98
+ idx = 0
99
+ idx = extract_and_save_keyframes(
100
+ video_path=file_path, output_dir=str(output_dir), start_index=idx
101
+ )
102
+ logger.info(f"✅ Extracted keyframes and images to: {output_dir}")
103
+
104
+ def __transcribe_videos(self):
105
+ """
106
+ Transcribe all videos in the input directory.
107
+ Save transcripts into the prompt directory.
108
+ """
109
+ audio_dir = self.audio_dir
110
+ os.makedirs(audio_dir, exist_ok=True)
111
+
112
+ if audio_dir.is_dir() and any(audio_dir.iterdir()):
113
+ logger.info(f"🔄 Found existing transcripts in directory: {audio_dir}")
114
+ return
115
+
116
+ transcribe_video_directory(
117
+ video_dir=str(self.input_dir),
118
+ output_dir=str(audio_dir),
119
+ model_name="base", # Use the base Whisper model for transcription
120
+ )
121
+ logger.info(f"✅ Successfully transcribed videos to: {audio_dir}")
122
+
123
+ def __image_search(self):
124
+ """
125
+ Perform image search on all images in the input directory.
126
+ Save search results into the prompt directory.
127
+ """
128
+ image_dir = self.image_dir
129
+
130
+ if os.environ["IMGBB_API_KEY"] is None:
131
+ raise ValueError(
132
+ "IMGBB_API_KEY environment variable is not set or is None."
133
+ )
134
+ if os.environ["SCRAPINGDOG_API_KEY"] is None:
135
+ raise ValueError(
136
+ "SCRAPINGDOG_API_KEY environment variable is not set or is None."
137
+ )
138
+ image_search_directory(
139
+ directory=str(image_dir),
140
+ output_dir=str(self.prompt_dir),
141
+ filename="metadata.json",
142
+ imgbb_key=os.environ["IMGBB_API_KEY"],
143
+ scrapingdog_key=os.environ["SCRAPINGDOG_API_KEY"],
144
+ max_workers=4,
145
+ target_links=20,
146
+ )
147
+ logger.info(f"✅ Successfully performed image search on: {image_dir}")
148
+
149
+ def __text_search(self):
150
+ """
151
+ Perform text search with metadata to get related links.
152
+ """
153
+ query = ""
154
+ metadata_file = self.prompt_dir / "metadata.json"
155
+ if not metadata_file.exists():
156
+ query = ""
157
+ else:
158
+ with open(metadata_file, "r") as f:
159
+ metadata = json.load(f)
160
+ description = metadata.get("description", "")
161
+ location = metadata.get("location", "")
162
+ query = f"{description} in {location}".strip()
163
+
164
+ text_search_link(
165
+ query=query,
166
+ output_dir=str(self.prompt_dir),
167
+ filename="text_search.json",
168
+ num_results=10,
169
+ api_key=os.environ["GOOGLE_CLOUD_API_KEY"],
170
+ cx=os.environ["GOOGLE_CSE_CX"],
171
+ )
172
+
173
+ async def __fetch_related_link_content(
174
+ self, image_prediction: bool = True, text_prediction: bool = True
175
+ ):
176
+ """
177
+ Fetch related link content for all images and text in the prompt directory.
178
+ """
179
+
180
+ async def fetch_and_save_links(links, output_filename):
181
+ if links:
182
+ await fetch_links_to_json(
183
+ links=list(links),
184
+ output_path=str(self.prompt_dir / output_filename),
185
+ max_content_length=5000,
186
+ )
187
+ logger.info(
188
+ f"Fetched content for {len(links)} links into {output_filename}"
189
+ )
190
+
191
+ # Image links
192
+ image_links = set()
193
+ image_search_file = self.prompt_dir / "metadata.json"
194
+ if image_prediction:
195
+ if not image_search_file.exists():
196
+ self.__image_search()
197
+ with open(image_search_file, "r") as f:
198
+ image_search_data = json.load(f)
199
+ image_links.update(image_search_data.get("all_links", []))
200
+ logger.info(f"Found {len(image_links)} image links to fetch content from.")
201
+ await fetch_and_save_links(image_links, "image_search_content.json")
202
+
203
+ # Text links
204
+ text_links = set()
205
+ text_search_file = self.prompt_dir / "text_search.json"
206
+ if text_prediction:
207
+ if not text_search_file.exists():
208
+ self.__text_search()
209
+ with open(text_search_file, "r") as f:
210
+ text_search_data = json.load(f)
211
+ text_links.update(filter(None, text_search_data.get("links", [])))
212
+ logger.info(f"Found {len(text_links)} text links to fetch content from.")
213
+ await fetch_and_save_links(text_links, "text_search_content.json")
214
+
215
+ if not image_links and not text_links:
216
+ logger.info("No links found in image or text search results.")
217
+
218
+ def __index_search(self):
219
+ """
220
+ Perform FAISS index search on all images in the prompt directory.
221
+ Save search results into the report directory.
222
+ """
223
+ if not self.index:
224
+ raise RuntimeError(
225
+ "FAISS index is not loaded. Cannot perform index search."
226
+ )
227
+
228
+ output_path = self.prompt_dir / "index_search.json"
229
+ if output_path.exists():
230
+ logger.info(
231
+ f"Index search results already exist at {output_path}, skipping search."
232
+ )
233
+ return
234
+
235
+ if not os.path.exists(self.database_csv_path):
236
+ raise FileNotFoundError(
237
+ f"Database CSV file not found: {self.database_csv_path}"
238
+ )
239
+
240
+ candidates_gps, reverse_gps = search_index_directory(
241
+ model=self.model,
242
+ device=self.device,
243
+ index=self.index,
244
+ image_dir=str(self.image_dir),
245
+ database_csv_path=str(self.database_csv_path),
246
+ top_k=20,
247
+ max_elements=20,
248
+ )
249
+
250
+ save_results_to_json(candidates_gps, reverse_gps, str(output_path))
251
+ logger.info(
252
+ f"✅ Successfully performed index search. Results saved to: {output_path}"
253
+ )
254
+
255
+ async def __fetch_satellite_image_async(
256
+ self,
257
+ latitude: float,
258
+ longitude: float,
259
+ zoom: int,
260
+ output_path: Path,
261
+ ) -> None:
262
+ """
263
+ Asynchronously fetches a satellite image without blocking the event loop.
264
+
265
+ Runs the synchronous `fetch_satellite_image` function in a background thread.
266
+
267
+ Args:
268
+ latitude (float): Latitude of the location.
269
+ longitude (float): Longitude of the location.
270
+ zoom (int): Zoom level of the satellite image.
271
+ output_path (Path): Path to save the image file.
272
+ """
273
+ await asyncio.to_thread(
274
+ fetch_satellite_image,
275
+ latitude,
276
+ longitude,
277
+ zoom,
278
+ str(output_path),
279
+ )
280
+
281
+ async def __search_images_async(
282
+ self,
283
+ location: str,
284
+ num_images: int,
285
+ api_key: str | None,
286
+ cse_cx: str | None,
287
+ output_dir: Path,
288
+ image_id_offset: int,
289
+ ) -> list[str]:
290
+ """
291
+ Asynchronously searches for images based on a text location query.
292
+
293
+ Args:
294
+ location (str): Text location to search.
295
+ num_images (int): Number of images to fetch.
296
+ api_key (str): Google Cloud API key.
297
+ cse_cx (str): Google Custom Search Engine ID.
298
+ output_dir (Path): Directory where images will be saved.
299
+ image_id_offset (int): Offset for image filenames.
300
+
301
+ Returns:
302
+ Any: The result of `text_search_image`, if it returns a value.
303
+ """
304
+ return await asyncio.to_thread(
305
+ text_search_image,
306
+ location,
307
+ num_images,
308
+ api_key,
309
+ cse_cx,
310
+ str(output_dir),
311
+ image_id_offset,
312
+ )
313
+
314
+ def __compute_sha256(self, filepath: Path) -> str:
315
+ """
316
+ Compute the SHA-256 hash of a file.
317
+ """
318
+ if not filepath.is_file():
319
+ raise ValueError(f"File does not exist: {filepath}")
320
+
321
+ sha256 = hashlib.sha256()
322
+ with open(filepath, "rb") as f:
323
+ for chunk in iter(lambda: f.read(4096), b""):
324
+ sha256.update(chunk)
325
+ return sha256.hexdigest()
326
+
327
+ def __compare_directories(self, dir1: Path, dir2: Path) -> bool:
328
+ """
329
+ Compare two directories to check if they contain the same files with identical content.
330
+ Args:
331
+ dir1 (Path): First directory to compare.
332
+ dir2 (Path): Second directory to compare.
333
+ Returns:
334
+ bool: True if both directories contain the same files with identical content, False otherwise.
335
+ """
336
+ if not dir1.is_dir() or not dir2.is_dir():
337
+ return False
338
+
339
+ files1 = sorted(p for p in dir1.iterdir() if p.is_file())
340
+ files2 = sorted(p for p in dir2.iterdir() if p.is_file())
341
+
342
+ # Check if filenames match exactly
343
+ names1 = {p.name for p in files1}
344
+ names2 = {p.name for p in files2}
345
+ if names1 != names2:
346
+ return False
347
+
348
+ # Compare each matching file
349
+ for filename in names1:
350
+ path1 = dir1 / filename
351
+ path2 = dir2 / filename
352
+
353
+ # Skip directories
354
+ if not path1.is_file() or not path2.is_file():
355
+ continue
356
+
357
+ hash1 = self.__compute_sha256(path1)
358
+ hash2 = self.__compute_sha256(path2)
359
+
360
+ if hash1 != hash2:
361
+ return False # Found mismatch
362
+ return True # All matching files are identical
363
+
364
+ def __copy_directory(self, src: Path, dest: Path):
365
+ """
366
+ Recursively copy all files from src to dest.
367
+ """
368
+ if not src.is_dir():
369
+ raise ValueError(f"Source path is not a directory: {src}")
370
+
371
+ # Delete everything in dest first
372
+ if dest.exists():
373
+ for item in dest.iterdir():
374
+ if item.is_file() or item.is_symlink():
375
+ item.unlink()
376
+ elif item.is_dir():
377
+ shutil.rmtree(item)
378
+
379
+ # Ensure dest exists
380
+ dest.mkdir(parents=True, exist_ok=True)
381
+
382
+ for item in src.iterdir():
383
+ if item.is_dir():
384
+ self.__copy_directory(item, dest / item.name)
385
+ else:
386
+ dest_file = dest / item.name
387
+ if not dest_file.exists() or not self.__compare_directories(
388
+ item, dest_file
389
+ ):
390
+ shutil.copy2(item, dest_file)
391
+
392
+ async def preprocess_input_data(
393
+ self,
394
+ image_prediction: bool = True,
395
+ text_prediction: bool = True,
396
+ ):
397
+ """
398
+ Preprocess all input data:
399
+ - Extract keyframes from videos.
400
+ - Transcribe videos.
401
+ - Fetch related link content from images.
402
+ Save images and extracted keyframes into the output directory
403
+ """
404
+ os.makedirs(self.prompt_dir, exist_ok=True)
405
+ os.makedirs(self.cache_dir, exist_ok=True)
406
+
407
+ cache_dir_input = self.cache_dir / "input_data"
408
+ cache_dir_prompt = self.cache_dir / "prompt_data"
409
+ if self.__compare_directories(self.input_dir, cache_dir_input):
410
+ logger.info("Input data already processed, skipping...")
411
+ self.__copy_directory(cache_dir_prompt, self.prompt_dir)
412
+ return
413
+ else:
414
+ logger.info("Processing input data...")
415
+
416
+ metadata_dest = self.prompt_dir / "metadata.json"
417
+ if not metadata_dest.exists():
418
+ for file in os.listdir(self.input_dir):
419
+ if file.endswith(".json"):
420
+ file_path = os.path.join(self.input_dir, file)
421
+ with open(file_path, "r") as src_file:
422
+ with open(metadata_dest, "w") as dest_file:
423
+ dest_file.write(src_file.read())
424
+ break
425
+
426
+ self.__extract_keyframes()
427
+ self.__transcribe_videos()
428
+ await self.__fetch_related_link_content(
429
+ image_prediction=image_prediction, text_prediction=text_prediction
430
+ )
431
+ self.__index_search()
432
+
433
+ logger.info("✅ Preprocessing completed")
434
+ logger.info(f"Saving processed data to cache directory: {self.cache_dir}")
435
+ self.__copy_directory(self.input_dir, cache_dir_input)
436
+ self.__copy_directory(self.prompt_dir, cache_dir_prompt)
437
+
438
+ async def prepare_location_images(
439
+ self,
440
+ prediction: dict,
441
+ image_prediction: bool = True,
442
+ text_prediction: bool = True,
443
+ ) -> int:
444
+ """
445
+ Prepare verification data from the prediction with parallel fetching.
446
+
447
+ Args:
448
+ prediction (dict): Prediction dictionary with latitude, longitude, location, reason, and metadata
449
+ image_prediction (bool): Whether to include original images in verification
450
+ text_prediction (bool): Whether to include text-based verification
451
+
452
+ Returns:
453
+ int: Satellite image ID for reference in prompts
454
+ """
455
+ image_dir = self.image_dir
456
+ satellite_image_id = len(list(self.image_dir.glob("image_*.*")))
457
+
458
+ # Execute both operations in parallel
459
+ logger.info("🔄 Fetching satellite image and location images in parallel...")
460
+
461
+ # Ensure required API keys are present
462
+ if not os.environ.get("GOOGLE_CLOUD_API_KEY"):
463
+ raise ValueError(
464
+ "GOOGLE_CLOUD_API_KEY environment variable is not set or is None."
465
+ )
466
+ if not os.environ.get("GOOGLE_CSE_CX"):
467
+ raise ValueError(
468
+ "GOOGLE_CSE_CX environment variable is not set or is None."
469
+ )
470
+
471
+ await asyncio.gather(
472
+ self.__fetch_satellite_image_async(
473
+ prediction["latitude"],
474
+ prediction["longitude"],
475
+ zoom=200,
476
+ output_path=image_dir / f"image_{satellite_image_id:03d}.jpg",
477
+ ),
478
+ self.__search_images_async(
479
+ location=prediction["location"],
480
+ num_images=5,
481
+ api_key=os.environ["GOOGLE_CLOUD_API_KEY"],
482
+ cse_cx=os.environ["GOOGLE_CSE_CX"],
483
+ output_dir=image_dir,
484
+ image_id_offset=satellite_image_id + 1,
485
+ ),
486
+ )
487
+ logger.info("✅ Verification data preparation completed")
488
+ return satellite_image_id
src/g3/G3.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import cast
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
5
+
6
+ from .locationencoder import LocationEncoder
7
+
8
class G3(torch.nn.Module):
    """Geolocation model aligning CLIP image/text embeddings with GPS embeddings."""

    def __init__(
        self,
        device: str,
        positional_encoding_type: str = "sh",
        neural_network_type: str = "siren",
        hparams: dict | None = None,
    ):
        """Build the frozen CLIP backbone plus trainable projection heads.

        Args:
            device: Device identifier forwarded to the location encoder.
            positional_encoding_type: Positional encoding name; only the
                part before the first underscore is passed on.
            neural_network_type: Network type for the location encoder.
            hparams: Optional hyperparameters for the location encoder.
        """
        super().__init__()
        self.device = device

        # Pretrained CLIP backbone: vision/text towers plus their projections.
        pretrained = "openai/clip-vit-large-patch14"
        clip_model = cast(CLIPModel, CLIPModel.from_pretrained(pretrained))
        self.vision_model = clip_model.vision_model
        self.text_model = clip_model.text_model
        self.vision_processor = cast(
            CLIPImageProcessor, CLIPImageProcessor.from_pretrained(pretrained)
        )
        self.text_processor = cast(
            CLIPTokenizer, CLIPTokenizer.from_pretrained(pretrained)
        )
        self.vision_projection = clip_model.visual_projection
        self.text_projection = clip_model.text_projection

        # Learnable log-temperature for each contrastive pairing.
        self.logit_scale1 = nn.Parameter(torch.tensor(3.99))
        self.logit_scale2 = nn.Parameter(torch.tensor(3.99))
        self.logit_scale3 = nn.Parameter(torch.tensor(3.99))

        self.location_encoder = LocationEncoder(
            positional_encoding_type=positional_encoding_type.split("_")[0],
            neural_network_type=neural_network_type,
            hparams=hparams,
            device=device,
        )  # output batch_size, 3, 512

        def _mlp(in_dim: int, out_dim: int) -> nn.Sequential:
            # Two-layer projection head with a ReLU in between.
            return nn.Sequential(
                nn.Linear(in_dim, in_dim), nn.ReLU(), nn.Linear(in_dim, out_dim)
            )

        # Trainable heads (registration order preserved for state_dict parity).
        self.vision_projection_else_1 = _mlp(768, 768)
        self.text_projection_else = _mlp(768, 768)
        self.vision_projection_else_2 = _mlp(768, 768)
        self.location_projection_else = _mlp(512, 768)

        # Freeze the CLIP backbone; only the new heads and encoder train.
        self.vision_model.requires_grad_(False)
        self.vision_projection.requires_grad_(False)
        self.text_model.requires_grad_(False)
        self.text_projection.requires_grad_(False)
+ def forward(self, images, texts, longitude, latitude):
64
+ vision_output = self.vision_model(images)[1]
65
+ text_output = self.text_model(**texts)[1]
66
+ image_embeds = self.vision_projection(vision_output)
67
+ text_embeds = self.text_projection(text_output) # batch_size, 512
68
+ this_batch_locations = torch.stack((latitude, longitude), dim=1)
69
+ location_embeds = self.location_encoder(this_batch_locations)
70
+
71
+ # phase _1
72
+ image_embeds_1 = self.vision_projection_else_1(image_embeds)
73
+ text_embeds_1 = self.text_projection_else(
74
+ text_embeds.reshape(text_embeds.shape[0], -1)
75
+ )
76
+
77
+ # normalized features
78
+ image_embeds_1 = image_embeds_1 / image_embeds_1.norm(p=2, dim=-1, keepdim=True)
79
+ text_embeds_1 = text_embeds_1 / text_embeds_1.norm(p=2, dim=-1, keepdim=True)
80
+
81
+ # image with texts
82
+ logit_scale = self.logit_scale1.exp()
83
+ logits_per_texts_with_images = (
84
+ torch.matmul(text_embeds_1, image_embeds_1.t()) * logit_scale
85
+ )
86
+ logits_per_images_with_texts = logits_per_texts_with_images.t()
87
+ loss_phase_1 = self.clip_loss(logits_per_texts_with_images)
88
+
89
+ # phase _2
90
+ image_embeds_2 = self.vision_projection_else_2(image_embeds)
91
+ location_embeds_2 = self.location_projection_else(
92
+ location_embeds.reshape(location_embeds.shape[0], -1)
93
+ )
94
+
95
+ # normalized features
96
+ image_embeds_2 = image_embeds_2 / image_embeds_2.norm(p=2, dim=-1, keepdim=True)
97
+ location_embeds_2 = location_embeds_2 / location_embeds_2.norm(
98
+ p=2, dim=-1, keepdim=True
99
+ )
100
+
101
+ # image with location
102
+ logit_scale = self.logit_scale2.exp()
103
+ logits_per_locations_with_images = (
104
+ torch.matmul(location_embeds_2, image_embeds_2.t()) * logit_scale
105
+ )
106
+ logits_per_images_with_locations = logits_per_locations_with_images.t()
107
+ loss_phase_2 = None
108
+ loss_phase_2 = self.clip_loss(logits_per_locations_with_images)
109
+
110
+ loss = loss_phase_1 + loss_phase_2
111
+
112
+ return {
113
+ "logits_per_texts_with_images": logits_per_texts_with_images,
114
+ "logits_per_images_with_texts": logits_per_images_with_texts,
115
+ "logits_per_locations_with_images": logits_per_locations_with_images,
116
+ "logits_per_images_with_locations": logits_per_images_with_locations,
117
+ "logits_per_locations_with_texts": None,
118
+ "logits_per_texts_with_locations": None,
119
+ "loss": loss,
120
+ "vision_output": vision_output,
121
+ "text_output": text_output,
122
+ "image_embeds": image_embeds,
123
+ "text_embeds": text_embeds,
124
+ }
125
+
126
+ def contrastive_loss(self, logits: torch.Tensor) -> torch.Tensor:
127
+ return nn.functional.cross_entropy(
128
+ logits, torch.arange(len(logits), device=logits.device)
129
+ )
130
+
131
+ def clip_loss(self, similarity: torch.Tensor) -> torch.Tensor:
132
+ caption_loss = self.contrastive_loss(similarity)
133
+ image_loss = self.contrastive_loss(similarity.t())
134
+ return (caption_loss + image_loss) / 2.0
src/g3/dataset.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import tarfile
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+ from typing import Callable, Optional
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torchvision.transforms as T
12
+ import transformers
13
+ from PIL import Image, ImageFile
14
+ from torch.utils.data import DataLoader, get_worker_info
15
+ from torchvision.datasets import VisionDataset
16
+ from torchvision.io import ImageReadMode, read_image
17
+ from tqdm import tqdm
18
+ from transformers import (
19
+ CLIPImageProcessor,
20
+ CLIPModel,
21
+ CLIPTextModel,
22
+ CLIPTokenizer,
23
+ CLIPVisionModel,
24
+ )
25
+
26
+ ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow truncated images to be loaded
27
+
28
+ from io import BytesIO
29
+ from typing import Any, Dict, Iterator, Optional, Tuple
30
+
31
+ import torch
32
+ import torchvision.transforms as T
33
+ from datasets import load_dataset
34
+ from huggingface_hub import login
35
+ from PIL import Image
36
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
37
+
38
+ __all__ = [
39
+ "MP16StreamingDataset",
40
+ "mp16_collate",
41
+ ]
42
+
43
+
44
class MP16StreamingDataset(IterableDataset):
    """Stream **MP-16** samples from the HuggingFace Hub and yield a simple
    tuple per example::

        (image, text, longitude, latitude)

    * **image** – either a tensor (``C×H×W``) if *vision_processor* is set or if
      the fallback transform is used, otherwise a PIL image.
    * **text** – caption string (either provided by the dataset or generated
      from location fields).
    * **longitude**, **latitude** – floats.

    The class is an :class:`torch.utils.data.IterableDataset`, so wrap it in a
    :class:`~torch.utils.data.DataLoader` for batching.
    """

    def __init__(
        self,
        repo_id: str = "tduongvn/MP16-Pro-shards",
        split: str = "train",
        vision_processor: Optional[Any] = None,
        shuffle_buffer: int = 10_000,
        HF_TOKEN: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.repo_id = repo_id
        self.split = split
        self.vision_processor = vision_processor
        # Size of the streaming shuffle buffer (trades memory for randomness).
        self.shuffle_buffer = shuffle_buffer
        self.HF_TOKEN = HF_TOKEN

        # Base transform when we *don't* have a fancy processor
        self.fallback_transform = T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomResizedCrop(size=224),
                T.ToTensor(),
            ]
        )

        # Prepare an initial dataset iterator for the main process
        self._base_iter = self._new_iterator()

    def _new_iterator(self):
        """Open a fresh shuffled streaming iterator over the Hub dataset.

        Logs in first when an HF token was supplied (needed for gated/private
        repos).
        """
        if self.HF_TOKEN is not None:
            login(token=self.HF_TOKEN)
        return (
            load_dataset(self.repo_id, split=self.split, streaming=True)
            .shuffle(buffer_size=self.shuffle_buffer)
            .__iter__()
        )

    def _decode_image(self, img_bytes):
        """bytes → PIL.Image or tensor (if processor is set)."""
        img = Image.open(BytesIO(img_bytes)).convert("RGB")
        if self.vision_processor is not None:
            return self.vision_processor(images=img, return_tensors="pt")[
                "pixel_values"
            ].squeeze(0)
        return self.fallback_transform(img)

    def _caption(self, ex_json: Dict[str, Any]) -> str:
        """Build a caption from whichever of city/state/country are present.

        NOTE(review): when all three fields are missing this returns the bare
        prefix "A street view photo taken in " — confirm that is acceptable.
        """
        parts = [ex_json.get(k) for k in ("city", "state", "country") if ex_json.get(k)]
        return "A street view photo taken in " + ", ".join(parts)

    def __iter__(self) -> Iterator[Tuple[Any, str, float, float]]:
        # Each DataLoader worker gets its own iterator to avoid state clashes.
        worker = get_worker_info()
        iterator = self._new_iterator() if worker is not None else self._base_iter

        for ex in iterator:
            # Dataset structure: {'jpg': <PIL or bytes>, 'json': {...}, ...}
            img_field = ex["jpg"]
            if isinstance(img_field, Image.Image):
                img = img_field.convert("RGB")
                if self.vision_processor is not None:
                    img = self.vision_processor(images=img, return_tensors="pt")[
                        "pixel_values"
                    ].squeeze(0)
                else:
                    img = self.fallback_transform(img)
            else:  # bytes
                img = self._decode_image(img_field)

            meta = ex["json"] if "json" in ex else {}
            # NOTE(review): float(...) raises TypeError when neither the
            # lower- nor upper-case key is present — confirm every shard
            # carries coordinates.
            lon = float(meta.get("lon", meta.get("LON")))
            lat = float(meta.get("lat", meta.get("LAT")))
            text = meta.get("text") or self._caption(meta)

            yield img, text, lon, lat

    # No __len__ – this is a stream.
142
+
143
+
144
+ # ─────────────────────────────────────────────────────────────────────────────
145
+ # Collate ─┘
146
+
147
+
148
def make_mp16_collate(text_processor):
    """Build a DataLoader ``collate_fn`` for (image, text, lon, lat) samples.

    The returned callable stacks images into a batch tensor, tokenizes the
    captions with *text_processor* (CLIP-style, max length 77), and converts
    coordinates to float32 tensors.
    """

    def collate(samples):
        imgs, captions, longs, latts = zip(*samples)

        pixel_batch = torch.stack(imgs)  # (B, C, H, W)

        tokens = text_processor(
            list(captions),
            padding="longest",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )

        lon_tensor = torch.tensor(longs, dtype=torch.float32)
        lat_tensor = torch.tensor(latts, dtype=torch.float32)

        return pixel_batch, tokens, lon_tensor, lat_tensor

    return collate
168
+
169
+
170
class MP16Dataset(VisionDataset):
    """Map-style dataset over the MP16-Pro tarball + metadata CSV.

    Each item is ``(image, caption, longitude, latitude)``. Images are read
    directly out of an uncompressed tar archive via a per-worker tarfile
    handle; member offsets are cached in a pickle index so the tar only has
    to be scanned once.
    """

    def __init__(
        self,
        root_path="data/mp16/",
        text_data_path="MP16_Pro_places365.csv",
        image_data_path="mp-16-images.tar",
        member_info_path="tar_index.pkl",
        vision_processor=None,
        text_processor=None,
    ):
        # NOTE(review): passes the dataset instance itself as VisionDataset's
        # `root` argument — looks unintended, confirm against torchvision API.
        super().__init__(self)
        self.root_path = root_path
        self.text_data_path = text_data_path
        self.image_data_path = image_data_path
        self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path))
        # Tar members are stored flat, with '/' in the original IDs replaced
        # by '_', so normalize the CSV IDs the same way.
        self.text_data["IMG_ID"] = self.text_data["IMG_ID"].apply(
            lambda x: x.replace("/", "_")
        )
        # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images
        print("read text data success")
        # One tar handle per DataLoader worker (tarfile objects are not safe
        # to share across processes); None keys the main process.
        worker = get_worker_info()
        worker = worker.id if worker else None
        self.tar_obj = {worker: tarfile.open(os.path.join(root_path, image_data_path))}
        # self.tar = tarfile.open(os.path.join(root_path, image_data_path))

        if os.path.exists(os.path.join(self.root_path, member_info_path)):
            # Fast path: reuse the cached name → TarInfo index.
            with open(os.path.join(self.root_path, member_info_path), "rb") as f:
                self.tar_index = pickle.load(f)
            all_image_names = list(self.tar_index.keys())
            print("load tar index success")
        else:
            # Slow path: scan the whole tar once, keeping only .jpg members
            # larger than 5 KiB, then persist the index for next time.
            print("no exist tar index success, need building...")
            self.tar_index = {}
            all_image_names = []
            for member in tqdm(self.tar_obj[worker]):
                if member.name.endswith(".jpg") and member.size > 5120:
                    self.tar_index[member.name.split("/")[1]] = member
                    all_image_names.append(member.name.split("/")[1])
            print("tar index buidling success")
            with open(os.path.join(self.root_path, member_info_path), "wb") as f:
                pickle.dump(self.tar_index, f)
        all_image_names = set(all_image_names)

        # Keep only rows with a known country and an image in the archive.
        self.text_data = self.text_data[self.text_data["country"].notnull()]
        self.text_data = self.text_data[self.text_data["IMG_ID"].isin(all_image_names)]
        print("data columns: ", self.text_data.shape[0])

        # location from str to float
        self.text_data.loc[:, "LON"] = self.text_data["LON"].astype(float)
        self.text_data.loc[:, "LAT"] = self.text_data["LAT"].astype(float)
        print("location from str to float success")

        # image transform
        self.transform = T.Resize(size=(512, 512))
        self.transform_totensor = T.ToTensor()

        self.vision_processor = vision_processor
        self.text_processor = text_processor

        # Define the contrast transforms here
        # (SimCLR-style augmentations; currently not applied in __getitem__.)
        self.contrast_transforms = T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomResizedCrop(size=224),
                T.RandomApply(
                    [
                        T.ColorJitter(
                            brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1
                        )
                    ],
                    p=0.8,
                ),
                T.RandomGrayscale(p=0.2),
                T.GaussianBlur(kernel_size=9),
                T.ToTensor(),
                # T.Normalize((0.5,), (0.5,))
            ]
        )

        # self.text_data.to_csv('/data/mp-16/MP16_Pro_filtered.csv', index=False)

    def caption_generation(self, row):
        # Placeholder — captions are currently built inline in __getitem__.
        pass

    def __getitem__(self, index):
        """Return ``(image, caption, longitude, latitude)`` for row *index*."""
        image_path = self.text_data.iloc[index]["IMG_ID"]
        text = ""
        neighbourhood, city, county, state, region, country, continent = (
            self.text_data.iloc[index][
                [
                    "neighbourhood",
                    "city",
                    "county",
                    "state",
                    "region",
                    "country",
                    "continent",
                ]
            ]
        )
        # location_elements = [element for element in [neighbourhood, city, state, country] if element is not np.nan and str(element) != 'nan']
        # Build the caption from whichever of city/state/country are non-NaN.
        location_elements = [
            element
            for element in [city, state, country]
            if element is not np.nan and str(element) != "nan"
        ]
        text = "A street view photo taken in " + ", ".join(location_elements)

        longitude = self.text_data.iloc[index]["LON"]
        latitude = self.text_data.iloc[index]["LAT"]
        # read the image from self.tar
        # Lazily open a tar handle for this worker if it doesn't have one yet.
        worker = get_worker_info()
        worker = worker.id if worker else None
        if worker not in self.tar_obj:
            self.tar_obj[worker] = tarfile.open(
                os.path.join(self.root_path, self.image_data_path)
            )
        image = self.tar_obj[worker].extractfile(self.tar_index[image_path])
        image = Image.open(image)

        if image.mode != "RGB":
            image = image.convert("RGB")

        if self.vision_processor:
            image = self.vision_processor(images=image, return_tensors="pt")[
                "pixel_values"
            ].reshape(3, 224, 224)

        return image, text, longitude, latitude

    def __len__(self):
        return len(self.text_data)
302
+
303
+
304
class im2gps3kDataset(VisionDataset):
    """Evaluation dataset for the im2gps3k benchmark.

    Items are ``(image, IMG_ID, longitude, latitude)``; the image filename is
    returned in place of a caption so predictions can be matched back to rows.
    """

    def __init__(
        self,
        root_path="./data/im2gps3k",
        text_data_path="im2gps3k_places365.csv",
        image_data_path="images/",
        vision_processor=None,
        text_processor=None,
    ):
        # NOTE(review): passes the dataset instance itself as VisionDataset's
        # `root` argument — looks unintended, confirm against torchvision API.
        super().__init__(self)
        print("start loading im2gps...")
        self.root_path = root_path
        self.text_data_path = text_data_path
        self.image_data_path = image_data_path
        self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path))
        # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images
        print("read text data success")

        # location from str to float
        self.text_data.loc[:, "LAT"] = self.text_data["LAT"].astype(float)
        self.text_data.loc[:, "LON"] = self.text_data["LON"].astype(float)
        print("location from str to float success")

        self.vision_processor = vision_processor
        self.text_processor = text_processor

        # Ten-crop transform kept for optional TTA (disabled in __getitem__).
        self.tencrop = T.TenCrop(224)

    def __getitem__(self, index):
        """Return ``(image, IMG_ID, longitude, latitude)`` for row *index*."""
        image_path = self.text_data.iloc[index]["IMG_ID"]
        text = image_path

        longitude = self.text_data.iloc[index]["LON"]
        latitude = self.text_data.iloc[index]["LAT"]

        image = Image.open(
            os.path.join(self.root_path, self.image_data_path, image_path)
        )

        if image.mode != "RGB":
            image = image.convert("RGB")

        # image = self.tencrop(image) # for tencrop

        if self.vision_processor:
            # -1 first dim keeps compatibility with ten-crop output as well.
            image = self.vision_processor(images=image, return_tensors="pt")[
                "pixel_values"
            ].reshape(-1, 224, 224)

        return image, text, longitude, latitude

    def __len__(self):
        return len(self.text_data)
357
+
358
+
359
class yfcc4kDataset(VisionDataset):
    """Evaluation dataset for the YFCC4k benchmark.

    Items are ``(image, IMG_ID, longitude, latitude)``; the image filename is
    returned in place of a caption so predictions can be matched back to rows.
    """

    def __init__(
        self,
        root_path="./data/yfcc4k",
        text_data_path="yfcc4k_places365.csv",
        image_data_path="images/",
        vision_processor=None,
        text_processor=None,
    ):
        # NOTE(review): passes the dataset instance itself as VisionDataset's
        # `root` argument — looks unintended, confirm against torchvision API.
        super().__init__(self)
        print("start loading yfcc4k...")
        self.root_path = root_path
        self.text_data_path = text_data_path
        self.image_data_path = image_data_path
        self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path))
        # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images
        print("read text data success")

        # location from str to float
        self.text_data.loc[:, "LAT"] = self.text_data["LAT"].astype(float)
        self.text_data.loc[:, "LON"] = self.text_data["LON"].astype(float)
        print("location from str to float success")

        self.vision_processor = vision_processor
        self.text_processor = text_processor

    def __getitem__(self, index):
        """Return ``(image, IMG_ID, longitude, latitude)`` for row *index*."""
        image_path = self.text_data.iloc[index]["IMG_ID"]
        text = image_path

        longitude = self.text_data.iloc[index]["LON"]
        latitude = self.text_data.iloc[index]["LAT"]

        image = Image.open(
            os.path.join(self.root_path, self.image_data_path, image_path)
        )

        if image.mode != "RGB":
            image = image.convert("RGB")

        if self.vision_processor:
            image = self.vision_processor(images=image, return_tensors="pt")[
                "pixel_values"
            ].reshape(-1, 224, 224)

        return image, text, longitude, latitude

    def __len__(self):
        return len(self.text_data)
src/g3/hparams.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sh_siren:
2
+ # legendre_polys: 30
3
+ # harmonics_calculation: analytic
4
+ # hidden_dim: 512
5
+ # num_layers: 3
6
+ # lr: 7.887855321604208e-05
7
+ # wd: 1.3475466222160537e-06
8
+
9
+ sh_siren:
10
+ legendre_polys: 40
11
+ harmonics_calculation: analytic
12
+ hidden_dim: 512
13
+ output_dim: 256
14
+ num_layers: 2
15
+ lr: 0.0001
16
+ wd: 0.01
17
+
18
+
19
+ projection_eep_rffmlp:
20
+ projection: eep
21
+ sigma:
22
+ - 1
23
+ - 16
24
+ - 256
25
+ hidden_dim: 1024
26
+ lr: 0.00003
27
+ wd: 0.000001
28
+
29
+ projection_mercator_rffmlp:
30
+ projection: mercator
31
+ sigma:
32
+ - 1
33
+ - 16
34
+ - 256
35
+ hidden_dim: 1024
36
+ lr: 0.00003
37
+ wd: 0.000001
38
+
39
+ projection_ecef_rffmlp:
40
+ projection: ecef
41
+ sigma:
42
+ - 1
43
+ - 16
44
+ - 256
45
+ hidden_dim: 1024
46
+ lr: 0.00003
47
+ wd: 0.000001
48
+
49
# NOTE: this key duplicates `projection_eep_rffmlp` defined earlier in this
# file with identical values. YAML silently keeps only the last occurrence —
# remove one copy, or rename this entry if a distinct configuration was meant.
projection_eep_rffmlp:
  projection: eep
  sigma:
    - 1
    - 16
    - 256
  hidden_dim: 1024
  lr: 0.00003
  wd: 0.000001
src/g3/locationencoder.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .nn.mlp import MLP
4
+ from .nn.rff_mlp import RFFMLP
5
+ from .nn.siren import SirenNet
6
+ from .pe.projection import Projection
7
+ from .pe.projection_rff import ProjectionRFF
8
+ from .pe.spherical_harmonics import SphericalHarmonics
9
+
10
+
11
+ def get_positional_encoding(positional_encoding_type, hparams, device="cuda"):
12
+ """
13
+ Returns a positional encoding module based on the specified encoding type.
14
+
15
+ Args:
16
+ encoding_type (str): The type of positional encoding to use. Options are 'rff', 'siren', 'sh', 'capsule'.
17
+ input_dim (int): The input dimension for the positional encoding.
18
+ output_dim (int): The output dimension for the positional encoding.
19
+ hparams: Additional arguments for specific encoding types.
20
+
21
+ Returns:
22
+ nn.Module: The positional encoding module.
23
+ """
24
+ if positional_encoding_type == "projectionrff":
25
+ return ProjectionRFF(
26
+ projection=hparams["projection"],
27
+ sigma=hparams["sigma"],
28
+ hparams=hparams,
29
+ device=device,
30
+ )
31
+ elif positional_encoding_type == "projection":
32
+ return Projection(
33
+ projection=hparams["projection"], hparams=hparams, device=device
34
+ )
35
+ elif positional_encoding_type == "sh":
36
+ return SphericalHarmonics(
37
+ legendre_polys=hparams["legendre_polys"],
38
+ harmonics_calculation=hparams["harmonics_calculation"],
39
+ hparams=hparams,
40
+ device=device,
41
+ )
42
+ else:
43
+ raise ValueError(f"Unsupported encoding type: {positional_encoding_type}")
44
+
45
+
46
def get_neural_network(
    neural_network_type: str,
    input_dim: int,
    hparams: dict,
    device="cuda",
):
    """
    Return a neural network module for the requested network type.

    Args:
        neural_network_type (str): One of 'siren', 'mlp' or 'rffmlp'.
            (The previous docstring claimed only 'siren' was supported.)
        input_dim (int): The input dimension for the neural network.
        hparams (dict): Hyper-parameters — 'output_dim'/'hidden_dim'/
            'num_layers' for 'siren', 'hidden_dim' for 'mlp',
            'hidden_dim'/'sigma' for 'rffmlp'.
        device (str): Device passed through to the network module.

    Returns:
        nn.Module: The neural network module.

    Raises:
        ValueError: If the network type is not supported.
    """
    if neural_network_type == "siren":
        return SirenNet(
            input_dim=input_dim,
            output_dim=hparams["output_dim"],
            hidden_dim=hparams["hidden_dim"],
            num_layers=hparams["num_layers"],
            hparams=hparams,
            device=device,
        )
    elif neural_network_type == "mlp":
        return MLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            hparams=hparams,
            device=device,
        )
    elif neural_network_type == "rffmlp":
        return RFFMLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    else:
        raise ValueError(f"Unsupported network type: {neural_network_type}")
90
+
91
+
92
class LocationEncoder(nn.Module):
    """Encode coordinate pairs into a 512-d feature vector.

    A positional encoder maps raw coordinates to one embedding per hierarchy
    level; a per-level neural network maps each embedding to a 512-d feature
    and the per-level features are summed.
    """

    def __init__(
        self,
        positional_encoding_type: str = "sh",
        neural_network_type: str = "siren",
        hparams: dict | None = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = device

        # NOTE(review): hparams may still be None here; the positional
        # encoders index into it, so a None hparams fails for most types.
        self.position_encoder = get_positional_encoding(
            positional_encoding_type=positional_encoding_type,
            hparams=hparams,
            device=device,
        )

        if hparams is None:
            hparams = {}

        # One sub-network per hierarchy level reported by the encoder.
        self.neural_network = nn.ModuleList(
            [
                get_neural_network(
                    neural_network_type, input_dim=dim, hparams=hparams, device=device
                )
                for dim in self.position_encoder.embedding_dim
            ]
        )

    def forward(self, x):
        """Return summed per-level features, shape (batch, 512)."""
        embedding = self.position_encoder(x)

        if embedding.ndim == 2:
            # Single-level encoders return (batch, n); add a leading level
            # axis so the loop below handles both cases uniformly.
            embedding = embedding.unsqueeze(0)

        # NOTE(review): the 512 feature width must match every sub-network's
        # output dimension — confirm against hparams['output_dim'].
        location_features = torch.zeros(embedding.shape[1], 512).to(self.device)

        # Bug fix: the loop variable was named `nn`, shadowing the imported
        # `torch.nn` module inside this scope; renamed to `net`.
        for net, level_embedding in zip(self.neural_network, embedding):
            location_features += net(level_embedding)

        return location_features
src/g3/nn/mlp.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
class MLP(nn.Module):
    """Plain feed-forward perceptron: three hidden ReLU layers plus a linear head.

    Note: the previous docstring claimed batch normalization was applied —
    there is none; the network is Linear/ReLU only.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Width of the three hidden layers.
        output_dim: Size of the output features.
        hparams: Accepted for interface uniformity with the other networks;
            unused here.
        device: Stored for interface uniformity; unused here.
    """

    def __init__(self, input_dim=512, hidden_dim=1024, output_dim=512, hparams=None, device='cuda'):
        super().__init__()
        self.device = device
        # Feature extractor: three Linear→ReLU stages.
        self.capsule = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Linear projection head to the output width.
        self.head = nn.Sequential(nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        """Return the MLP output, shape (..., output_dim)."""
        x = self.capsule(x)
        x = self.head(x)
        return x
src/g3/nn/rff_mlp.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ from ..rff.layers import GaussianEncoding
4
+
5
class LocationEncoderCapsule(nn.Module):
    """Single RFF-MLP "capsule": a Gaussian random-Fourier-feature encoding at
    one bandwidth *sigma*, followed by a three-hidden-layer ReLU MLP and a
    linear head.
    """

    def __init__(self, input_dim=2, hidden_dim=1024, output_dim=512, sigma=2**0):
        super(LocationEncoderCapsule, self).__init__()
        # encoded_size is output_dim/2 — the following Linear expects
        # output_dim inputs, implying GaussianEncoding doubles its encoded
        # size (cos/sin pairs); confirm against the rff library.
        rff_encoding = GaussianEncoding(sigma=sigma, input_size=input_dim, encoded_size=int(output_dim/2))
        self.capsule = nn.Sequential(rff_encoding,
                                     nn.Linear(output_dim, hidden_dim),
                                     nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim),
                                     nn.ReLU(),
                                     nn.Linear(hidden_dim, hidden_dim),
                                     nn.ReLU())
        self.head = nn.Sequential(nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        """Encode *x* and return features of shape (..., output_dim)."""
        x = self.capsule(x)
        x = self.head(x)
        return x
22
+
23
class RFFMLP(nn.Module):
    """Sum of random-Fourier-feature MLP capsules, one per sigma bandwidth.

    Each ``LocationEncoderCapsule`` encodes the input at a different Gaussian
    bandwidth; the capsule outputs are summed into one feature vector of
    shape (batch, output_dim).
    """

    def __init__(self, input_dim=2, hidden_dim=1024, output_dim=512, sigma=[2**0, 2**4, 2**8], hparams=None, device='cuda'):
        super(RFFMLP, self).__init__()
        self.num_hierarchies = len(sigma)
        self.device = device
        # Bug fix: remember output_dim — forward() previously hard-coded 512,
        # causing a shape mismatch whenever output_dim != 512.
        self.output_dim = output_dim

        # One capsule per bandwidth, registered as LocEnc0, LocEnc1, ...
        for i, s in enumerate(sigma):
            self.add_module('LocEnc' + str(i), LocationEncoderCapsule(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, sigma=s))

    def forward(self, input):
        """Return the summed capsule features, shape (batch, output_dim)."""
        location_features = torch.zeros(input.shape[0], self.output_dim).to(self.device)

        for i in range(self.num_hierarchies):
            location_features += self._modules['LocEnc' + str(i)](input)
        return location_features
src/g3/nn/siren.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+
7
+ # helpers
8
+
9
def exists(val):
    """True iff *val* carries a value (i.e. is not None)."""
    return not (val is None)
11
+
12
def cast_tuple(val, repeat = 1):
    """Return *val* unchanged if it is a tuple, else repeat it into one."""
    if isinstance(val, tuple):
        return val
    return (val,) * repeat
14
+
15
+ # sin activation
16
+
17
class Sine(nn.Module):
    """Sine activation with tunable frequency: x ↦ sin(w0 · x)."""

    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        scaled = x * self.w0
        return torch.sin(scaled)
23
+
24
+ # siren layer
25
+
26
class Siren(nn.Module):
    """Single SIREN layer: ``sin(w0 * (W x + b))`` with the initialization
    scheme from Sitzmann et al. (2020).

    Args:
        input_dim: Incoming feature size.
        output_dim: Produced feature size.
        w0: Frequency factor inside the sine activation.
        c: Constant controlling the uniform init bound for non-first layers.
        is_first: First layer uses the wider ``1/dim`` init bound.
        use_bias: Whether to learn an additive bias.
        activation: Optional replacement for the default ``Sine(w0)``.
        dropout: If True, applies default-rate dropout before the activation.
    """

    def __init__(self, input_dim, output_dim, w0 = 1., c = 6., is_first = False, use_bias = True, activation = None, dropout = False):
        super().__init__()
        self.input_dim = input_dim
        self.is_first = is_first
        self.output_dim = output_dim
        self.dropout = dropout

        # Create plain tensors first so they can be initialized in-place
        # before being wrapped as parameters.
        weight = torch.zeros(output_dim, input_dim)
        bias = torch.zeros(output_dim) if use_bias else None
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(self, weight, bias, c, w0):
        """SIREN init: U(-1/dim, 1/dim) for the first layer, otherwise
        U(-sqrt(c/dim)/w0, sqrt(c/dim)/w0)."""
        dim = self.input_dim

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        """Linear transform, optional dropout, then the (sine) activation."""
        out = F.linear(x, self.weight, self.bias)
        if self.dropout:
            out = F.dropout(out, training=self.training)
        out = self.activation(out)
        return out
57
+
58
+ # siren network
59
+
60
class SirenNet(nn.Module):
    """Stack of SIREN layers mapping ``input_dim`` → ``output_dim``.

    The first layer uses ``w0_initial`` (typically 30) as its sine frequency;
    all later layers use ``w0``. Optional per-layer multiplicative modulations
    can be supplied at call time via ``mods``. ``hparams``/``device`` are
    accepted for interface uniformity with the other networks in this package.
    """

    def __init__(self, input_dim = 512, hidden_dim = 1024, output_dim = 512, num_layers = 3, w0 = 1., w0_initial = 30., use_bias = True, final_activation = None, degreeinput = False, dropout = False, hparams=None, device='cuda'):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        # When True, inputs are degrees and get shifted into a radian range.
        self.degreeinput = degreeinput
        self.device = device

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_input_dim = input_dim if is_first else hidden_dim

            self.layers.append(Siren(
                input_dim = layer_input_dim,
                output_dim = hidden_dim,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first,
                dropout = dropout
            ))

        # Final layer maps hidden_dim → output_dim; identity activation unless
        # an explicit final_activation was provided.
        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.last_layer = Siren(input_dim = hidden_dim, output_dim = output_dim, w0 = w0, use_bias = use_bias, activation = final_activation, dropout = False)

    def forward(self, x, mods = None):
        """Forward pass; ``mods`` may be a single value or one per layer."""

        # do some normalization to bring degrees in a -pi to pi range
        if self.degreeinput:
            x = torch.deg2rad(x) - torch.pi

        # Broadcast a single modulation across all layers.
        mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                # Feature-wise multiplicative modulation, broadcast over batch.
                x *= rearrange(mod, 'd -> () d')

        return self.last_layer(x)
src/g3/pe/projection.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import pandas as pd
5
+ import itertools
6
+ from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPModel
7
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
8
+ from pyproj import Proj, Transformer
9
+
10
+ SF = 66.50336
11
+
12
class Projection(nn.Module):
    """Positional encoding that maps (lat, lon) degrees to normalized
    projected coordinates: Web-Mercator, Equal Earth ('eep'), or ECEF.

    ``embedding_dim`` reports the per-level output widths consumed by
    LocationEncoder (2 for planar projections, 3 for ECEF). ``hparams`` is
    accepted for interface uniformity but unused here.
    """

    def __init__(self, projection="mercator", hparams=None, device='cuda'):
        super(Projection, self).__init__()
        self.device = device
        self.projection = projection.lower()

        # Source CRS: WGS-84 geographic coordinates.
        proj_wgs84 = Proj('epsg:4326')

        if self.projection == "mercator":
            proj_target = Proj('epsg:3857')
            # Half-extent of the Web-Mercator plane in metres.
            self.normalizer = 20037508.3427892
            self.embedding_dim = [2]
        elif self.projection == "eep":
            # Equal Earth projection.
            proj_target = Proj('epsg:8857')
            self.normalizer = 180/SF
            self.embedding_dim = [2]
        elif self.projection == "ecef":
            # Earth-centred, Earth-fixed geocentric XYZ.
            proj_target = Proj('epsg:4978')
            self.normalizer = 6378137.0 # radius of Earth, not exact for ECEF but usable
            self.embedding_dim = [3]
        else:
            raise ValueError(f"Unsupported projection: {self.projection}")

        self.transformer = Transformer.from_proj(proj_wgs84, proj_target, always_xy=True)

    def forward(self, input):
        """Project a (batch, 2) tensor of [lat, lon] degrees; returns a
        normalized (batch, 2) or (batch, 3) tensor on ``self.device``."""
        # Column 0 is latitude, column 1 is longitude.
        lat = input[:, 0].float().detach().cpu().numpy()
        lon = input[:, 1].float().detach().cpu().numpy()
        # lon (batch), lat (batch)

        # Shape: (batch, 2) or (batch, 3) depending on projection
        if self.projection == "ecef":
            # Assume sea-level altitude for the third input coordinate.
            alt = np.zeros_like(lat)
            projected = self.transformer.transform(lon, lat, alt)
            location = list(zip(*projected))  # X, Y, Z
            location = torch.Tensor(location).to(self.device)
        else:
            # NOTE(review): output columns are [y, x] (northing first) —
            # confirm downstream consumers expect this order.
            projected = self.transformer.transform(lon, lat)
            location = [[y, x] for x, y in zip(*projected)]
            location = torch.Tensor(location).to(self.device)

        # Scale into an approximately unit range.
        location = location / self.normalizer
        return location
src/g3/pe/projection_rff.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import pandas as pd
5
+ import itertools
6
+ from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPModel
7
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
8
+ from ..rff.layers import GaussianEncoding
9
+ from pyproj import Proj, Transformer
10
+
11
+ SF = 66.50336
12
+
13
class ProjectionRFF(nn.Module):
    """Positional encoding: project (lat, lon) to map coordinates, then apply
    Gaussian random Fourier features at several bandwidths.

    Produces one 512-d embedding per sigma value; ``embedding_dim`` reports
    those widths for LocationEncoder. ``hparams`` is accepted for interface
    uniformity but unused here.
    """

    def __init__(self, projection="ecef", sigma=[2**0, 2**4, 2**8], hparams=None, device='cuda'):
        super(ProjectionRFF, self).__init__()
        self.device = device
        self.sigma = sigma
        self.num_hierarchies = len(self.sigma)
        self.projection = projection.lower()
        # One 512-wide embedding per bandwidth (GaussianEncoding with
        # encoded_size=256 — presumably doubled to 512 by cos/sin pairs;
        # confirm against the rff library).
        self.embedding_dim = [512] * self.num_hierarchies

        # Source CRS: WGS-84 geographic coordinates.
        proj_wgs84 = Proj('epsg:4326')
        if self.projection == "mercator":
            proj_target = Proj('epsg:3857')
            input_dim = 2
            # Half-extent of the Web-Mercator plane in metres.
            self.normalizer = 20037508.3427892
        elif self.projection == "eep":
            # Equal Earth projection.
            proj_target = Proj('epsg:8857')
            input_dim = 2
            self.normalizer = 180/SF
        elif self.projection == "ecef":
            # Earth-centred, Earth-fixed geocentric XYZ.
            proj_target = Proj('epsg:4978')
            input_dim = 3
            self.normalizer = 6378137.0 # radius of Earth, not exact for ECEF but usable
        else:
            raise ValueError(f"Unsupported projection: {self.projection}")

        self.transformer = Transformer.from_proj(proj_wgs84, proj_target, always_xy=True)
        # One Gaussian RFF encoder per bandwidth: LocEnc0, LocEnc1, ...
        for i, s in enumerate(self.sigma):
            self.add_module('LocEnc' + str(i), GaussianEncoding(sigma=s, input_size=input_dim, encoded_size=256))

    def forward(self, input):
        """Encode a (batch, 2) tensor of [lat, lon] degrees; returns a
        (num_hierarchies, batch, 512) tensor."""
        # Column 0 is latitude, column 1 is longitude.
        lat = input[:, 0].float().detach().cpu().numpy()
        lon = input[:, 1].float().detach().cpu().numpy()
        # lon (batch), lat (batch)

        # Shape: (batch, 2) or (batch, 3) depending on projection
        if self.projection == "ecef":
            # Assume sea-level altitude for the third input coordinate.
            alt = np.zeros_like(lat)
            projected = self.transformer.transform(lon, lat, alt)
            location = list(zip(*projected))  # X, Y, Z
            location = torch.Tensor(location).to(self.device)
        else:
            # NOTE(review): columns are ordered [y, x] (northing first) —
            # confirm the RFF encoders are order-agnostic as assumed.
            projected = self.transformer.transform(lon, lat)
            location = [[y, x] for x, y in zip(*projected)]
            location = torch.Tensor(location).to(self.device)

        # Scale into an approximately unit range before encoding.
        location = location / self.normalizer
        out = []

        for i in range(self.num_hierarchies):
            out.append(self._modules['LocEnc' + str(i)](location))

        location_features = torch.stack(out, dim=0)  # (hierarchies, batch, 512)
        return location_features
src/g3/pe/spherical_harmonics.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .spherical_harmonics_ylm import SH as SH_analytic
4
+ from .spherical_harmonics_closed_form import SH as SH_closed_form
5
+
6
class SphericalHarmonics(nn.Module):
    """Positional encoding via real spherical harmonics of coordinates given
    in degrees; produces an (batch, L*L) embedding."""

    def __init__(self, legendre_polys: int = 20, harmonics_calculation="analytic", hparams=None, device='cuda'):
        """
        legendre_polys: determines the number of legendre polynomials.
            more polynomials lead more fine-grained resolutions
        harmonics_calculation: how the harmonics are evaluated —
            'analytic' uses pre-computed equations. This is exact, but works
            only up to degree 50; 'closed-form' uses one equation but is
            computationally slower (especially for high degrees).
        hparams: accepted for interface uniformity; unused here.
        """
        super(SphericalHarmonics, self).__init__()
        self.device = device
        self.L, self.M = int(legendre_polys), int(legendre_polys)
        self.embedding_dim = [self.L * self.M]

        if harmonics_calculation == "closed-form":
            self.SH = SH_closed_form
        elif harmonics_calculation == "analytic":
            self.SH = SH_analytic
        else:
            # Bug fix: previously an unknown value silently left self.SH
            # unset, failing much later with an AttributeError in forward().
            raise ValueError(
                f"Unsupported harmonics_calculation: {harmonics_calculation}"
            )

    def forward(self, lonlat):
        """Encode a (batch, 2) degree tensor; returns (batch, L*L), detached."""
        lon, lat = lonlat[:, 0], lonlat[:, 1]  # lon: (batch), lat: (batch)
        # NOTE(review): Projection/ProjectionRFF read latitude from column 0
        # and longitude from column 1, while here column 0 is treated as
        # longitude — confirm the caller's column order is consistent.

        # convert degree to rad, shifting into [0, 2π] / [0, π] ranges
        phi = torch.deg2rad(lon + 180)
        theta = torch.deg2rad(lat + 90)

        Y = []  # one (batch,) tensor per (l, m) pair; L*L in total
        for l in range(self.L):
            for m in range(-l, l + 1):
                y = self.SH(m, l, phi, theta)
                if isinstance(y, float):
                    # Constant harmonic — broadcast to the batch shape.
                    y = y * torch.ones_like(phi)
                Y.append(y)

        return torch.stack(Y, dim=-1).detach()  # (batch, L * L)
src/g3/pe/spherical_harmonics_closed_form.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+ ####################### Spherical Harmonics utilities ########################
5
+ # Code copied from https://github.com/BachiLi/redner/blob/master/pyredner/utils.py
6
+ # Code adapted from "Spherical Harmonic Lighting: The Gritty Details", Robin Green
7
+ # http://silviojemma.com/public/papers/lighting/spherical-harmonic-lighting.pdf
8
def associated_legendre_polynomial(l, m, x):
    """Evaluate the associated Legendre polynomial P_l^m elementwise on x.

    Standard three-step recurrence: build P_m^m first, then P_{m+1}^m,
    then climb degrees with the general recurrence up to P_l^m.
    (Adapted from "Spherical Harmonic Lighting: The Gritty Details".)
    """
    # Step 1: P_m^m(x) = (-1)^m (2m-1)!! (1 - x^2)^{m/2}
    p_mm = torch.ones_like(x)
    if m > 0:
        sin_term = torch.sqrt((1 - x) * (1 + x))
        double_fact = 1.0
        for _ in range(m):
            p_mm = p_mm * (-double_fact) * sin_term
            double_fact += 2.0
    if l == m:
        return p_mm

    # Step 2: P_{m+1}^m(x) = x (2m + 1) P_m^m(x)
    p_m1m = x * (2.0 * m + 1.0) * p_mm
    if l == m + 1:
        return p_m1m

    # Step 3: (l - m) P_l^m = (2l - 1) x P_{l-1}^m - (l + m - 1) P_{l-2}^m
    p_lm = torch.zeros_like(x)
    for deg in range(m + 2, l + 1):
        p_lm = ((2.0 * deg - 1.0) * x * p_m1m - (deg + m - 1.0) * p_mm) / (deg - m)
        p_mm, p_m1m = p_m1m, p_lm
    return p_lm
27
+
28
def SH_renormalization(l, m):
    """Normalization constant K_l^m for the real spherical harmonics."""
    factorial_ratio = math.factorial(l - m) / math.factorial(l + m)
    return math.sqrt((2.0 * l + 1.0) * factorial_ratio / (4 * math.pi))
31
+
32
def SH(m, l, phi, theta):
    """Real spherical harmonic Y_l^m at azimuth ``phi`` and polar angle ``theta``.

    Positive orders use the cosine branch, negative orders the sine branch,
    both scaled by sqrt(2); the zonal case (m == 0) needs no azimuthal term.
    """
    cos_theta = torch.cos(theta)
    if m == 0:
        return SH_renormalization(l, m) * associated_legendre_polynomial(l, m, cos_theta)
    if m > 0:
        order = m
        azimuthal = torch.cos(m * phi)
    else:
        order = -m
        azimuthal = torch.sin(-m * phi)
    return math.sqrt(2.0) * SH_renormalization(l, order) * azimuthal * \
        associated_legendre_polynomial(l, order, cos_theta)
src/g3/pe/spherical_harmonics_generate_ylms.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
This script prints the source code for spherical_harmonics_ylm.py to console.

spherical_harmonics pre-computes the analytical solutions to each real spherical harmonic with sympy;
the generated module contains one function per degree l and order m.

Marc Russwurm
"""
from datetime import datetime
import sys

from sympy import assoc_legendre
from sympy import cos, sin, sqrt, pi, factorial, Abs
from sympy import Symbol

# Symbolic angle variables shared by every generated expression:
# theta is the polar angle, phi the azimuthal angle.
theta = Symbol("theta")
phi = Symbol("phi")

def calc_ylm(l, m):
    """
    Build the symbolic expression for the real spherical harmonic Y_l^m.

    see last equation of https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
    """
    if m < 0:
        # Negative orders take the sine branch of the real form.
        Plm = assoc_legendre(l, Abs(m), cos(theta))
        # \bar{P}_l^{|m|}: fully normalized associated Legendre polynomial.
        Plm_bar = sqrt(((2 * l + 1) / (4 * pi)) * (factorial(l - Abs(m)) / factorial(l + Abs(m)))) * Plm

        Ylm = (-1)**m * sqrt(2) * Plm_bar * sin(Abs(m) * phi)
    elif m == 0:
        # Zonal harmonic: no azimuthal dependence.
        Ylm = sqrt((2*l + 1) / (4 * pi)) * assoc_legendre(l, m, cos(theta))
    else:  # m > 0
        Plm = assoc_legendre(l, m, cos(theta))
        Plm_bar = sqrt(((2 * l + 1) / (4 * pi)) * (factorial(l - m) / factorial(l + m))) * Plm

        Ylm = (-1)**m * sqrt(2) * Plm_bar * cos(m * phi)
    return Ylm

def print_function(l, m):
    # Emit one torch.jit.script function per (l, m) pair; .evalf() turns the
    # exact symbolic constants into floating-point literals.
    fname = f"Yl{l}_m{m}".replace("-", "_minus_")
    print()
    print("@torch.jit.script")
    print(f"def {fname}(theta, phi):")
    print("    return " + str(calc_ylm(l, m).evalf()))

# max number of Legendre Polynomials
L = 101

# Module header of the generated file, including a small dispatch helper
# (get_SH / SH) that looks the generated functions up by name at runtime.
head = """\"\"\"
analytic expressions of spherical harmonics generated with sympy file
Marc Russwurm generated """ + str(datetime.date(datetime.now())) + """

run
python """ + sys.argv[0] + """ > spherical_harmonics_ylm.py

to generate the source code
\"\"\"

import torch
from torch import cos, sin

def get_SH(m,l):
    fname = f"Yl{l}_m{m}".replace("-","_minus_")
    return globals()[fname]

def SH(m, l, phi, theta):
    Ylm = get_SH(m,l)
    return Ylm(theta, phi)
"""
print(head)
print()

for l in range(L):
    for m in range(-l,l+1):
        print_function(l,m)
src/g3/pe/spherical_harmonics_ylm.py ADDED
The diff for this file is too large to render. See raw diff
 
src/g3/rff/functional.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from torch import Tensor
5
+
6
+
7
def sample_b(sigma: float, size: tuple) -> Tensor:
    r"""Draw a matrix of shape :attr:`size` from :math:`\mathcal{N}(0, \sigma^2)`.

    Args:
        sigma (float): standard deviation of the Gaussian
        size (tuple): shape of the sampled matrix

    See :class:`~rff.layers.GaussianEncoding` for more details
    """
    standard_normal = torch.randn(size)
    return standard_normal * sigma
17
+
18
+
19
@torch.jit.script
def gaussian_encoding(
        v: Tensor,
        b: Tensor) -> Tensor:
    r"""Random Fourier feature map
    :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{B} \mathbf{v}} , \sin{2 \pi \mathbf{B} \mathbf{v}})`.

    Args:
        v (Tensor): input of shape :math:`(N, *, \text{input_size})`
        b (Tensor): projection matrix of shape
            :math:`(\text{encoded_layer_size}, \text{input_size})`

    Returns:
        Tensor: encoding of shape :math:`(N, *, 2 \cdot \text{encoded_layer_size})`

    See :class:`~rff.layers.GaussianEncoding` for more details.
    """
    projected = 2 * np.pi * v @ b.T
    return torch.cat((torch.cos(projected), torch.sin(projected)), dim=-1)
36
+
37
+
38
@torch.jit.script
def basic_encoding(
        v: Tensor) -> Tensor:
    r"""Basic Fourier feature map
    :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{v}} , \sin{2 \pi \mathbf{v}})`.

    Args:
        v (Tensor): input of shape :math:`(N, *, \text{input_size})`

    Returns:
        Tensor: encoding of shape :math:`(N, *, 2 \cdot \text{input_size})`

    See :class:`~rff.layers.BasicEncoding` for more details.
    """
    scaled = 2 * np.pi * v
    return torch.cat((torch.cos(scaled), torch.sin(scaled)), dim=-1)
53
+
54
+
55
@torch.jit.script
def positional_encoding(
        v: Tensor,
        sigma: float,
        m: int) -> Tensor:
    r"""Multi-frequency positional encoding
    :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`
    for :math:`j \in \{0, \dots, m-1\}`.

    Args:
        v (Tensor): input of shape :math:`(N, *, \text{input_size})`
        sigma (float): constant chosen based upon the domain of :attr:`v`
        m (int): number of frequencies per input dimension

    Returns:
        Tensor: encoding of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`

    See :class:`~rff.layers.PositionalEncoding` for more details.
    """
    exponents = torch.arange(m, device=v.device)
    frequencies = 2 * np.pi * sigma ** (exponents / m)
    angles = frequencies * torch.unsqueeze(v, -1)
    encoded = torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)
    return encoded.flatten(-2, -1)
src/g3/rff/layers.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from typing import Optional
4
+ from torch import Tensor
5
+ from . import functional
6
+
7
class GaussianEncoding(nn.Module):
    """Layer for mapping coordinates using random Fourier features"""

    def __init__(self, sigma: Optional[float] = None,
                 input_size: Optional[float] = None,
                 encoded_size: Optional[float] = None,
                 b: Optional[Tensor] = None):
        r"""
        Args:
            sigma (Optional[float]): standard deviation
            input_size (Optional[float]): the number of input dimensions
            encoded_size (Optional[float]): the number of dimensions the `b` matrix maps to
            b (Optional[Tensor], optional): Optionally specify a :attr:`b` matrix already sampled
        Raises:
            ValueError:
                If :attr:`b` is provided together with any of :attr:`sigma`,
                :attr:`input_size`, or :attr:`encoded_size`; or if :attr:`b`
                is omitted and any of those three arguments is missing.
        """
        super().__init__()
        if b is not None:
            # A pre-sampled matrix excludes the sampling parameters.
            if sigma is not None or input_size is not None or encoded_size is not None:
                raise ValueError('Only specify the "b" argument when using it.')
        else:
            if sigma is None or input_size is None or encoded_size is None:
                raise ValueError(
                    'Arguments "sigma," "input_size," and "encoded_size" are required.')
            b = functional.sample_b(sigma, (encoded_size, input_size))
        # Registered as a frozen parameter so it moves with .to(device)
        # and is saved in state_dict, but is never trained.
        self.b = nn.parameter.Parameter(b, requires_grad=False)

    def forward(self, v: Tensor) -> Tensor:
        r"""Computes :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{B} \mathbf{v}} , \sin{2 \pi \mathbf{B} \mathbf{v}})`

        Args:
            v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})`

        Returns:
            Tensor: Tensor mapping using random fourier features of shape :math:`(N, *, 2 \cdot \text{encoded_size})`
        """
        return functional.gaussian_encoding(v, self.b)
+
48
+
49
class BasicEncoding(nn.Module):
    """Layer for mapping coordinates using the basic encoding"""

    def forward(self, v: Tensor) -> Tensor:
        r"""Computes :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{v}} , \sin{2 \pi \mathbf{v}})`

        Args:
            v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})`

        Returns:
            Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot \text{input_size})`
        """
        # Stateless layer: delegates directly to the functional form.
        encoded = functional.basic_encoding(v)
        return encoded
+
63
+
64
class PositionalEncoding(nn.Module):
    """Layer for mapping coordinates using the positional encoding"""

    def __init__(self, sigma: float, m: int):
        r"""
        Args:
            sigma (float): frequency constant
            m (int): number of frequencies to map to
        """
        super().__init__()
        # Stored as plain attributes: these are hyperparameters, not weights.
        self.sigma, self.m = sigma, m

    def forward(self, v: Tensor) -> Tensor:
        r"""Computes :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`

        Args:
            v (Tensor): input tensor of shape :math:`(N, *, \text{input_size})`

        Returns:
            Tensor: mapped tensor of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`
        """
        return functional.positional_encoding(v, self.sigma, self.m)
src/g3_batch_prediction.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import shutil
5
+ from pathlib import Path
6
+ import time
7
+ import numpy as np
8
+ import torch
9
+ import yaml
10
+ from google import genai
11
+ from google.genai import types
12
+ from pydantic import ValidationError
13
+ from tqdm.asyncio import tqdm as atqdm
14
+
15
+ from .data_processor import DataProcessor
16
+ from .g3.G3 import G3
17
+ from .prompt import (
18
+ Evidence,
19
+ GPSPrediction,
20
+ LocationPrediction,
21
+ diversification_prompt,
22
+ location_prompt,
23
+ verification_prompt,
24
+ )
25
+ from .utils import (
26
+ calculate_similarity_scores,
27
+ extract_and_parse_json,
28
+ get_gps_from_location,
29
+ handle_async_api_call_with_retry,
30
+ image_to_base64,
31
+ )
32
+
33
+ logger = logging.getLogger("uvicorn.error")
34
+
35
+
36
class G3BatchPredictor:
    """
    Batch prediction class for processing all images and videos in a directory.

    This class:
    1. Preprocesses all images and videos in a directory.
    2. Extracts keyframes from videos and combines them with images.
    3. Passes all keyframes and images to the Gemini model for prediction.
    """

    def __init__(
        self,
        device: str = "cuda",
        input_dir: str = "data/input_data",
        prompt_dir: str = "data/prompt_data",
        cache_dir: str = "data/cache",
        index_path: str = "data/index/G3.index",
        hparams_path: str = "g3/hparams.yaml",
        database_csv_path: str = "data/dataset/mp16/MP16_Pro_filtered.csv",
        checkpoint_path: str = "data/checkpoints/mercator_finetune_weight.pth",
    ):
        """
        Initialize the BatchKeyframePredictor.

        Args:
            device (str): Device to run model on ("cuda" or "cpu")
            input_dir (str): Directory of raw input media, relative to this file
            prompt_dir (str): Directory where preprocessed prompt assets are written
            cache_dir (str): Directory for intermediate cached artifacts
            index_path (str): Path to FAISS index for RAG (required)
            hparams_path (str): YAML file keyed by "<pe>_<nn>" combination
            database_csv_path (str): CSV backing the retrieval database
            checkpoint_path (str): Path to G3 model checkpoint
        """
        self.device = torch.device(device)
        # All configured relative paths resolve against this source file's dir.
        self.base_path = Path(__file__).parent
        self.checkpoint_path = self.base_path / checkpoint_path

        self.input_dir = self.base_path / input_dir
        self.prompt_dir = self.base_path / prompt_dir
        self.cache_dir = self.base_path / cache_dir
        self.image_dir = self.prompt_dir / "images"
        self.audio_dir = self.prompt_dir / "audio"

        os.makedirs(self.input_dir, exist_ok=True)
        os.makedirs(self.prompt_dir, exist_ok=True)
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.audio_dir, exist_ok=True)

        # Initialize G3 model with the hparams entry for the chosen
        # positional-encoding / neural-network combination.
        hparams = yaml.safe_load(open(self.base_path / hparams_path, "r"))
        pe = "projection_mercator"
        nn = "rffmlp"

        self.model = G3(
            device=device,
            positional_encoding_type=pe,
            neural_network_type=nn,
            hparams=hparams[f"{pe}_{nn}"],
        )
        self.__load_checkpoint()

        self.data_processor = DataProcessor(
            model=self.model,
            input_dir=self.input_dir,
            prompt_dir=self.prompt_dir,
            cache_dir=self.cache_dir,
            image_dir=self.image_dir,
            audio_dir=self.audio_dir,
            index_path=self.base_path / index_path,
            database_csv_path=self.base_path / database_csv_path,
            device=self.device,
        )

        # Media types the pipeline recognizes (lowercase suffixes).
        self.image_extension = {
            ".jpg",
            ".jpeg",
            ".png",
            ".bmp",
            ".tiff",
            ".tif",
            ".webp",
        }
        self.video_extension = {
            ".mp4",
            ".avi",
            ".mov",
            ".mkv",
        }

    def __load_checkpoint(self):
        """
        Load the G3 model checkpoint, move the model to the configured
        device, and switch it to eval mode.

        Raises:
            FileNotFoundError: if the checkpoint file is missing.
        """
        if not os.path.exists(self.checkpoint_path):
            raise FileNotFoundError(
                f"Checkpoint file not found: {self.checkpoint_path}"
            )
        self.model.load_state_dict(
            torch.load(self.checkpoint_path, map_location=self.device)
        )
        self.model.to(self.device)
        self.model.eval()
        logger.info(
            f"✅ Successfully loaded G3 model checkpoint from: {self.checkpoint_path}"
        )

    async def llm_predict(
        self,
        model_name: str = "gemini-2.5-pro",
        n_search: int | None = None,
        n_coords: int | None = None,
        image_prediction: bool = True,
        text_prediction: bool = True,
    ) -> LocationPrediction:
        """
        Generate a prediction using the Gemini LLM with Pydantic structured output.

        Args:
            model_name: LLM model name to use
            n_search: Number of search results to include
            n_coords: Number of coordinates to include
            image_prediction: Whether to use images in prediction
            text_prediction: Whether to use text in prediction

        Returns:
            LocationPrediction: validated prediction, or the zeroed fallback
            if all retries fail.
        """
        prompt = diversification_prompt(
            prompt_dir=str(self.prompt_dir),
            n_coords=n_coords,
            n_search=n_search,
            image_prediction=image_prediction,
            text_prediction=text_prediction,
        )

        images = []
        if image_prediction:
            image_dir = self.image_dir
            if not image_dir.exists():
                raise ValueError(f"Image directory does not exist: {image_dir}")

            # Attach every preprocessed .jpg as an inline image part.
            for image_file in image_dir.glob("*.jpg"):
                with open(image_file, "rb") as f:
                    image = types.Part.from_bytes(data=f.read(), mime_type="image/jpeg")
                images.append(image)

        client = genai.Client(api_key=os.environ["GOOGLE_CLOUD_API_KEY"])

        async def api_call():
            # Run the synchronous SDK call in the default executor so the
            # event loop is not blocked while waiting on the network.
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                None,
                lambda: client.models.generate_content(
                    model=model_name,
                    contents=[*images, prompt],
                    config=types.GenerateContentConfig(
                        tools=[
                            types.Tool(url_context=types.UrlContext()),
                        ],
                        temperature=0.1,
                        top_p=0.95,
                    ),
                ),
            )

            raw_text = response.text.strip() if response.text is not None else ""
            parsed_json = extract_and_parse_json(raw_text)

            try:
                validated = LocationPrediction.model_validate(parsed_json)
                return validated
            except (ValidationError, ValueError):
                # Re-raising lets the retry helper decide whether to retry.
                raise ValueError("Empty or invalid LLM response")

        return await handle_async_api_call_with_retry(
            api_call,
            fallback_result=LocationPrediction(
                latitude=0.0, longitude=0.0, location="", evidence=[]
            ),
            error_context=f"LLM prediction with {model_name}",
        )

    async def diversification_predict(
        self,
        model_name: str = "gemini-2.5-flash",
        image_prediction: bool = True,
        text_prediction: bool = True,
    ) -> LocationPrediction:
        """
        Diversification prediction without preprocessing (assumes preprocessing already done).
        Runs different sample sizes in parallel for faster execution.

        Args:
            model_name (str): LLM model name to use
            image_prediction (bool): Whether to use images in prediction
            text_prediction (bool): Whether to use text in prediction

        Returns:
            LocationPrediction: the candidate whose coordinates score highest
            under the G3 similarity model.
        """

        # Function to try a specific sample size with retry logic.
        # NOTE(review): a pydantic model instance is always truthy, so the
        # retry branch below appears unreachable — llm_predict already retries
        # internally and falls back to a zeroed prediction. Confirm intent.
        async def try_sample_size(num_sample):
            while True:
                prediction = await self.llm_predict(
                    model_name=model_name,
                    n_search=num_sample,
                    n_coords=num_sample,
                    image_prediction=image_prediction,
                    text_prediction=text_prediction,
                )

                if prediction:
                    coords = (prediction.latitude, prediction.longitude)
                    return (num_sample, coords, prediction)
                else:
                    logger.info(
                        f"Invalid or empty prediction format with {num_sample} samples, retrying..."
                    )

        # Run all sample sizes in parallel
        num_samples = [10, 15, 20]
        logger.info(
            f"🚀 Running {len(num_samples)} sample sizes in parallel: {num_samples}"
        )

        tasks = [try_sample_size(num_sample) for num_sample in num_samples]

        # Minimal file-like adapter so tqdm's progress output goes to the logger.
        class LW:
            def write(self, msg: str) -> int:
                logger.info(msg)
                return len(msg)

            def flush(self):
                pass

        results = await atqdm.gather(
            *tasks,
            desc="🔄 Running diversification predictions",
            file=LW(),
        )

        # Build predictions dictionary from parallel results.
        # Keyed by (lat, lon); identical coordinates from different sample
        # sizes collapse to a single entry (last one wins).
        predictions_dict = {}
        for num_sample, coords, prediction in results:
            predictions_dict[coords] = prediction
            logger.info(f"✅ Collected prediction with {num_sample} samples: {coords}")

        # Convert predictions to coordinate list for similarity scoring
        predicted_coords = list(predictions_dict.keys())
        logger.info(f"Predicted coordinates: {predicted_coords}")

        if not predicted_coords:
            raise ValueError("No valid predictions obtained from any sample size")

        # Calculate similarity scores between each candidate coordinate and
        # the query images, using the G3 model.
        avg_similarities = calculate_similarity_scores(
            model=self.model,
            device=self.device,
            predicted_coords=predicted_coords,
            image_dir=self.image_dir,
        )

        # Find best prediction
        best_idx = np.argmax(avg_similarities)
        best_coords = predicted_coords[best_idx]
        best_prediction = predictions_dict[best_coords]

        logger.info(f"🎯 Best prediction selected: {best_coords}")
        logger.info(f"   Similarity scores: {avg_similarities}")
        logger.info(f"   Best index: {best_idx}")

        return best_prediction

    async def location_predict(
        self,
        model_name: str = "gemini-2.5-flash",
        location: str = "specified location",
    ) -> GPSPrediction:
        """
        Generate a location-based prediction using the Gemini LLM with centralized retry logic.

        Tries a direct geocoding lookup first; only falls back to the LLM
        (with Google Search grounding) when geocoding fails.

        Args:
            model_name (str): LLM model name to use
            location (str): Location to use in the prompt

        Returns:
            GPSPrediction: coordinates plus supporting analysis/references.
        """
        if not location:
            raise ValueError("Location must be specified for location-based prediction")

        lat, lon = get_gps_from_location(location)
        if lat is not None and lon is not None:
            # Geocoder resolved the place name; skip the LLM entirely.
            logger.info(
                f"Using GPS coordinates for location '{location}': ({lat}, {lon})"
            )
            return GPSPrediction(
                latitude=lat, longitude=lon, analysis="", references=[]
            )
        else:
            prompt = location_prompt(location)
            client = genai.Client(api_key=os.environ["GOOGLE_CLOUD_API_KEY"])

            async def api_call():
                # Run the synchronous API call in a thread executor to make it truly async
                loop = asyncio.get_event_loop()
                response = await loop.run_in_executor(
                    None,
                    lambda: client.models.generate_content(
                        model=model_name,
                        contents=[prompt],
                        config=types.GenerateContentConfig(
                            tools=[
                                types.Tool(google_search=types.GoogleSearch()),
                            ],
                            temperature=0.1,
                            top_p=0.95,
                        ),
                    ),
                )

                raw_text = response.text.strip() if response.text is not None else ""
                parsed_json = extract_and_parse_json(raw_text)

                try:
                    validated = GPSPrediction.model_validate(parsed_json)
                    return validated
                except (ValidationError, ValueError):
                    raise ValueError("Empty or invalid LLM response")

            return await handle_async_api_call_with_retry(
                api_call,
                fallback_result=GPSPrediction(
                    latitude=0.0, longitude=0.0, analysis="", references=[]
                ),
                error_context=f"Location prediction for '{location}' with {model_name}",
            )

    async def verification_predict(
        self,
        prediction: LocationPrediction,
        model_name: str = "gemini-2.5-flash",
        image_prediction: bool = True,
        text_prediction: bool = True,
    ) -> LocationPrediction:
        """
        Generate verification prediction based on the provided prediction.

        Args:
            prediction (LocationPrediction): candidate to verify
            model_name (str): LLM model name to use for verification
            image_prediction (bool): Whether to use images in verification
            text_prediction (bool): Whether to use text in verification

        Returns:
            LocationPrediction: verified (possibly revised) prediction.
        """
        # Prepare verification data (now async); fetches e.g. a satellite
        # image for the candidate coordinates and returns its identifier.
        satellite_image_id = await self.data_processor.prepare_location_images(
            prediction=prediction.model_dump(),
            image_prediction=image_prediction,
            text_prediction=text_prediction,
        )

        image_dir = self.image_dir

        images = []
        if image_prediction:
            if not image_dir.exists():
                raise ValueError(f"Image directory does not exist: {image_dir}")

            for image_file in image_dir.glob("*.jpg"):
                with open(image_file, "rb") as f:
                    image = types.Part.from_bytes(data=f.read(), mime_type="image/jpeg")
                images.append(image)

        # Prepare verification prompt
        prompt = verification_prompt(
            satellite_image_id=satellite_image_id,
            prediction=prediction.model_dump(),
            prompt_dir=str(self.prompt_dir),
            image_prediction=image_prediction,
            text_prediction=text_prediction,
        )

        client = genai.Client(api_key=os.environ["GOOGLE_CLOUD_API_KEY"])

        async def api_call():
            # Run the synchronous API call in a thread executor to make it truly async
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                None,
                lambda: client.models.generate_content(
                    model=model_name,
                    contents=[*images, prompt],
                    config=types.GenerateContentConfig(
                        tools=[
                            types.Tool(url_context=types.UrlContext()),
                        ],
                        temperature=0.1,
                        top_p=0.95,
                    ),
                ),
            )

            raw_text = response.text.strip() if response.text is not None else ""
            parsed_json = extract_and_parse_json(raw_text)

            try:
                validated = LocationPrediction.model_validate(parsed_json)
                return validated
            except (ValidationError, ValueError):
                raise ValueError("Empty or invalid LLM response")

        return await handle_async_api_call_with_retry(
            api_call,
            fallback_result=LocationPrediction(
                latitude=0.0, longitude=0.0, location="", evidence=[]
            ),
            error_context=f"Verification prediction with {model_name}",
        )

    async def predict(
        self,
        model_name: str = "gemini-2.5-flash",
        image_prediction: bool = True,
        text_prediction: bool = True,
    ) -> LocationPrediction:
        """
        Complete prediction pipeline: preprocess, diversify, geolocate, verify.

        Args:
            model_name (str): LLM model name to use
            image_prediction (bool): Whether to use images in prediction
            text_prediction (bool): Whether to use text in prediction

        Returns:
            LocationPrediction: final verified prediction.
        """
        logger.info(
            f"🚀 Starting multi-modal prediction pipeline with model: {model_name}"
        )
        await self.data_processor.preprocess_input_data()
        # Step 1: Run diversification prediction (this is already parallel internally)
        logger.info(
            f"\n🔄 Running diversification prediction for Image={image_prediction}, Text={text_prediction}..."
        )
        diversification_result = await self.diversification_predict(
            model_name=model_name,
            image_prediction=image_prediction,
            text_prediction=text_prediction,
        )

        # Step 2: Run location prediction to refine the coordinates for the
        # predicted place name.
        location_prediction = await self.location_predict(
            model_name=model_name, location=diversification_result.location
        )

        logger.info("✅ Location prediction completed:")

        # Step 3: Update coordinates and evidence from location prediction.
        # model_copy keeps the diversification evidence intact.
        result = diversification_result.model_copy()
        result.longitude = location_prediction.longitude
        result.latitude = location_prediction.latitude

        # Step 4: Normalize and append location evidence
        if location_prediction.analysis and location_prediction.references:
            location_evidence = Evidence(
                analysis=location_prediction.analysis,
                references=location_prediction.references,
            )
        else:
            location_evidence = Evidence(
                analysis="No specific location analysis provided.",
                references=[],
            )

        # Append to result evidence
        result.evidence.append(location_evidence)

        # Step 5: Run verification prediction
        logger.info(
            f"\n🔄 Running verification prediction for Image={image_prediction}, Text={text_prediction}..."
        )
        result = await self.verification_predict(
            prediction=result,
            model_name=model_name,
            image_prediction=image_prediction,
            text_prediction=text_prediction,
        )

        logger.info(
            f"\n🎯 Final prediction for Image={image_prediction}, Text={text_prediction}:"
        )

        return result

    def get_response(self, prediction: LocationPrediction) -> LocationPrediction:
        """
        Convert image references in the prediction to base64 strings.

        Mutates the given prediction in place and also returns it.
        """
        for evidence in prediction.evidence:
            for i, ref in enumerate(evidence.references):
                if ref.startswith("image"):
                    evidence.references[i] = image_to_base64(self.image_dir / ref)
        return prediction

    def get_transcript(self) -> str:
        """
        Get the transcript from the transcript files in the audio directory.

        Concatenates every non-empty *.txt file, each preceded by a header
        naming its source file.
        """
        transcript = ""
        for transcript_file in self.audio_dir.glob("*.txt"):
            with open(transcript_file, "r", encoding="utf-8") as f:
                logger.info(f"Reading transcript from {transcript_file.name}")
                transcript_data = f.read().strip()
                if transcript_data:
                    transcript += f"Transcript for {transcript_file.name}\n"
                    transcript += transcript_data
        return transcript

    def clear_directories(self):
        """
        Clear the input and prompt directories.

        Removes the directory trees entirely; they are recreated on the
        next __init__.
        """
        delete_dirs = [self.input_dir, self.prompt_dir]
        for dir_path in delete_dirs:
            if os.path.exists(dir_path):
                shutil.rmtree(dir_path)
                logger.info(f"Deleted folder: {dir_path}")
            else:
                logger.info(f"Folder does not exist: {dir_path}")
src/prompt/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .factory import diversification_prompt as diversification_prompt
2
+ from .factory import location_prompt as location_prompt
3
+ from .factory import verification_prompt as verification_prompt
4
+ from .factory import Evidence as Evidence
5
+ from .factory import GPSPrediction as GPSPrediction
6
+ from .factory import LocationPrediction as LocationPrediction
src/prompt/factory.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from .template import DIVERSIFICATION_PROMPT, LOCATION_PROMPT, VERIFICATION_PROMPT
7
+
8
+
9
class Evidence(BaseModel):
    """One piece of geolocation evidence produced by the model."""

    # Free-text reasoning supporting the prediction.
    analysis: str
    # Supporting references: image filenames (e.g. "image_000.jpg") or URLs.
    # Image filenames are later swapped for base64 data before returning a
    # response. (Pydantic copies mutable defaults, so [] is safe here.)
    references: list[str] = []
12
+
13
+
14
class LocationPrediction(BaseModel):
    """Structured location prediction with its supporting evidence."""

    latitude: float
    longitude: float
    # Human-readable place description.
    location: str
    # Evidence items backing this prediction.
    evidence: list[Evidence]
19
+
20
+
21
class GPSPrediction(BaseModel):
    """Flat GPS prediction: coordinates plus a single analysis blob."""

    latitude: float
    longitude: float
    # Free-text reasoning for the coordinates.
    analysis: str
    # Filenames/URLs the analysis refers to.
    references: list[str]
26
+
27
+
28
+ def rag_prompt(index_search_json: str, n_coords: int | None = None) -> str:
29
+ """
30
+ Creates a formatted string with GPS coordinates for similar and dissimilar images.
31
+
32
+ Args:
33
+ candidates_gps (list[tuple]): List of (lat, lon) tuples for similar images.
34
+ reverse_gps (list[tuple]): List of (lat, lon) tuples for dissimilar images.
35
+ n_coords (int, optional): Number of coords to include from each list. Defaults to all.
36
+
37
+ Returns:
38
+ str: Formatted string with coordinates for reference.
39
+ """
40
+ if not os.path.exists(index_search_json):
41
+ return ""
42
+
43
+ with open(index_search_json, "r", encoding="utf-8") as file:
44
+ data = json.load(file)
45
+
46
+ candidates_gps = data.get("candidates_gps", [])
47
+ reverse_gps = data.get("reverse_gps", [])
48
+
49
+ if n_coords is not None:
50
+ candidates_gps = candidates_gps[: min(n_coords, len(candidates_gps))]
51
+ reverse_gps = reverse_gps[: min(n_coords, len(reverse_gps))]
52
+ else:
53
+ candidates_gps = candidates_gps
54
+ reverse_gps = reverse_gps
55
+
56
+ candidates_str = (
57
+ "[" + ", ".join(f"[{lat}, {lon}]" for (lat, lon) in candidates_gps) + "]"
58
+ )
59
+ reverse_str = "[" + ", ".join(f"[{lat}, {lon}]" for (lat, lon) in reverse_gps) + "]"
60
+ return f"For your reference, these are coordinates of some similar images: {candidates_str}, and these are coordinates of some dissimilar images: {reverse_str}."
61
+
62
+
63
def metadata_prompt(metadata_file_path: str) -> str:
    """
    Read a metadata JSON file and return a formatted string combining all fields.

    Args:
        metadata_file_path (str): Path to the metadata JSON file

    Returns:
        str: Formatted string with all metadata fields combined, or "" when
        the file is missing, empty, unreadable, or has no known fields.
    """
    if not metadata_file_path or not os.path.exists(metadata_file_path):
        return ""

    try:
        with open(metadata_file_path, "r", encoding="utf-8") as file:
            metadata = json.load(file)

        if not metadata:
            return ""

        # (json key, output label) pairs, in rendering order. Replaces six
        # copy-pasted if-blocks with one data-driven loop.
        fields = (
            ("location", "Location"),
            ("violence level", "Violence level"),
            ("title", "Title"),
            ("social media link", "Social media link"),
            ("description", "Description"),
            ("category", "Category"),
        )
        metadata_parts = [
            f"{label}: {metadata[key]}" for key, label in fields if metadata.get(key)
        ]

        if not metadata_parts:
            return ""

        return "Metadata for the image is: " + ". ".join(metadata_parts) + "."

    except Exception:
        # Best-effort: any IO/parse problem yields an empty prompt fragment.
        return ""
111
+
112
+
113
+ def search_prompt(search_candidates: list[str], n_search: int | None = None) -> str:
114
+ """
115
+ Formats search candidate links into a prompt string.
116
+
117
+ Args:
118
+ search_candidates (list[str]): List of candidate URLs from image search
119
+ n_search (int): Number of results to include (default: 5)
120
+
121
+ Returns:
122
+ str: Formatted string with candidate links, each on a new line
123
+
124
+ Example:
125
+ >>> candidates = search_prompt(["https://example1.com", "https://example2.com"], n_search=3)
126
+ >>> print(candidates)
127
+ Similar image can be found in those links:
128
+ https://example1.com
129
+ https://example2.com
130
+ """
131
+
132
+ if not search_candidates or not isinstance(search_candidates, list):
133
+ return ""
134
+
135
+ EXCLUDE_DOMAINS = [
136
+ "x.com",
137
+ "twitter.com",
138
+ "linkedin.com",
139
+ "bbc.com",
140
+ "bbc.co.uk",
141
+ "instagram.com",
142
+ "tiktok.com",
143
+ ]
144
+
145
+ for domain in EXCLUDE_DOMAINS:
146
+ search_candidates = [url for url in search_candidates if domain not in url]
147
+
148
+ if n_search is not None:
149
+ search_candidates = search_candidates[: min(n_search, len(search_candidates))]
150
+
151
+ try:
152
+ prompt = "\n".join(search_candidates)
153
+ return prompt
154
+
155
+ except Exception:
156
+ return ""
157
+
158
+
159
def image_search_prompt(image_search_json: str, n_search: int | None = None) -> str:
    """
    Read reverse-image-search results and combine matching-page links.

    Args:
        image_search_json (str): Path to a JSON file holding a list of result
            dicts with "pages_with_matching_images", "full_matching_images",
            and/or "partial_matching_images" keys.
        n_search (int, optional): Max number of links to include.

    Returns:
        str: Prompt listing pages with matching images, or "" when the file
        is missing or contains no matches.
    """
    # Mirror the other prompt builders: a missing file yields "" instead of
    # raising FileNotFoundError.
    if not os.path.exists(image_search_json):
        return ""

    pages_with_matching_images: set[str] = set()
    full_matching_images: set[str] = set()
    partial_matching_images: set[str] = set()

    with open(image_search_json, "r", encoding="utf-8") as file:
        data_list = json.load(file)
    for json_data in data_list:
        # Independent ifs: a single result dict may carry more than one of
        # these keys; the previous elif chain silently dropped later ones.
        if "pages_with_matching_images" in json_data:
            pages_with_matching_images.update(json_data["pages_with_matching_images"])
        if "full_matching_images" in json_data:
            full_matching_images.update(json_data["full_matching_images"])
        if "partial_matching_images" in json_data:
            partial_matching_images.update(json_data["partial_matching_images"])

    if (
        not pages_with_matching_images
        and not full_matching_images
        and not partial_matching_images
    ):
        return ""

    prompt = "Those are pages with matching images:\n"
    prompt += search_prompt(list(pages_with_matching_images), n_search=n_search)
    # NOTE(review): full/partial matching image URLs are collected but, as in
    # the original (commented-out) code, deliberately not emitted — confirm
    # before re-enabling them.
    return prompt
200
+
201
+
202
def search_content_prompt(search_content_json: str) -> str:
    """
    Read a JSON file of fetched search content and render it for a prompt.

    Args:
        search_content_json (str): Path to the JSON file with search content

    Returns:
        str: Pretty-printed JSON of the content list, or "" when the file is
        missing, empty, not a list, or unreadable.
    """
    if not os.path.exists(search_content_json):
        return ""

    try:
        with open(search_content_json, "r", encoding="utf-8") as fh:
            content = json.load(fh)

        if not isinstance(content, list) or not content:
            return ""

        return json.dumps(content, indent=2)

    except Exception:
        # Best-effort: unreadable content simply contributes nothing.
        return ""
227
+
228
+
229
def transcript_prompt(audio_dir: str) -> str:
    """
    Combine every transcript ``.txt`` file under *audio_dir* into one prompt.

    Args:
        audio_dir (str): Path to the audio directory containing transcript files

    Returns:
        str: "This is the transcript of the video: ..." or "" when the
        directory is missing or holds no transcript text.
    """
    if not os.path.exists(audio_dir):
        return ""

    snippets = []
    for name in os.listdir(audio_dir):
        if not name.endswith(".txt"):
            continue
        with open(os.path.join(audio_dir, name), "r", encoding="utf-8") as fh:
            snippets.append(fh.read().strip())

    combined = "\n".join(snippets)
    if not combined:
        return ""
    return f"This is the transcript of the video: {combined}"
256
+
257
+
258
def combine_prompt_data(
    prompt_dir: str,
    n_search: int | None = None,
    n_coords: int | None = None,
    image_prediction: bool = True,
    text_prediction: bool = True,
) -> str:
    """
    Combine every prompt data source under *prompt_dir* into one string.

    Sections (each skipped when empty): RAG coordinates from
    index_search.json, metadata.json, image_search_content.json,
    text_search_content.json, and transcripts under audio/.

    Args:
        prompt_dir (str): Directory holding the collected prompt data files.
        n_search (int, optional): Accepted for interface symmetry with the
            other prompt builders, but not forwarded to anything here —
            TODO confirm whether it should limit search content.
        n_coords (int, optional): Number of RAG coordinates to include;
            None disables the RAG section entirely.
        image_prediction (bool): Include image-search derived content.
        text_prediction (bool): Include metadata and text-search content.

    Returns:
        str: Combined prompt string (may be empty).
    """

    prompt_parts = []

    # 1. RAG prompt (optional)
    if n_coords is not None:
        rag_text = rag_prompt(os.path.join(prompt_dir, "index_search.json"), n_coords)
        prompt_parts.append(rag_text)

    # 2. Metadata prompt
    if text_prediction:
        metadata_text = metadata_prompt(os.path.join(prompt_dir, "metadata.json"))
        if metadata_text:
            prompt_parts.append(metadata_text)

    # 3. Search prompt
    if image_prediction:
        image_search_text = search_content_prompt(
            os.path.join(prompt_dir, "image_search_content.json")
        )
        if image_search_text:
            prompt_parts.append(image_search_text)

    if text_prediction:
        search_content_text = search_content_prompt(
            os.path.join(prompt_dir, "text_search_content.json")
        )
        if search_content_text:
            prompt_parts.append(search_content_text)

    # 4. Transcript prompt
    transcript_text = transcript_prompt(os.path.join(prompt_dir, "audio"))
    if transcript_text:
        prompt_parts.append(transcript_text)

    # Combine all parts with double newlines for readability
    combined_prompt = "\n\n".join(part for part in prompt_parts if part.strip())

    return combined_prompt
323
+
324
+
325
def diversification_prompt(
    prompt_dir: str,
    n_search: int | None = None,
    n_coords: int | None = None,
    image_prediction: bool = True,
    text_prediction: bool = True,
) -> str:
    """
    Build the diversification prompt from all data under *prompt_dir*.

    Args:
        prompt_dir (str): Directory with the collected prompt data files.
        n_search (int, optional): Forwarded to combine_prompt_data.
        n_coords (int, optional): Number of RAG coordinates to include;
            None disables the RAG section.
        image_prediction (bool): Include image-search derived content.
        text_prediction (bool): Include metadata and text-search content.

    Returns:
        str: The DIVERSIFICATION_PROMPT template filled with the combined
        prompt data.
    """

    prompt_data = combine_prompt_data(
        prompt_dir,
        n_search=n_search,
        n_coords=n_coords,
        image_prediction=image_prediction,
        text_prediction=text_prediction,
    )

    prompt = DIVERSIFICATION_PROMPT.strip().format(prompt_data=prompt_data)

    return prompt
364
+
365
+
366
def location_prompt(location: str) -> str:
    """
    Render the location prompt template for *location*.

    Args:
        location (str): The location to include in the prompt.

    Returns:
        str: Filled-in LOCATION_PROMPT, or "" for a falsy location.
    """
    if not location:
        return ""
    return LOCATION_PROMPT.strip().format(location=location)
383
+
384
+
385
def verification_prompt(
    satellite_image_id: int,
    prediction: dict,
    prompt_dir: str,
    n_search: int | None = None,
    n_coords: int | None = None,
    image_prediction: bool = True,
    text_prediction: bool = True,
) -> str:
    """
    Build the prompt asking the model to verify a location prediction.

    Args:
        satellite_image_id (int): Id of the satellite image fetched for the
            predicted coordinates; injected as a zero-padded 3-digit string.
        prediction (dict): The prediction to verify (must be JSON-serializable).
        prompt_dir (str): Directory with the collected prompt data files.
        n_search (int, optional): Forwarded to combine_prompt_data.
        n_coords (int, optional): Number of RAG coordinates to include.
        image_prediction (bool): Include image-search derived content.
        text_prediction (bool): Include metadata and text-search content.

    Returns:
        str: Formatted verification prompt string.
    """
    prompt_data = combine_prompt_data(
        prompt_dir,
        n_search=n_search,
        n_coords=n_coords,
        image_prediction=image_prediction,
        text_prediction=text_prediction,
    )

    prompt = VERIFICATION_PROMPT.strip().format(
        prompt_data=prompt_data,
        prediction=json.dumps(prediction, indent=2),
        satellite_image_id=f"{satellite_image_id:03d}",
    )

    return prompt
src/prompt/fetch/content_fetch.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Literal, TypedDict
6
+
7
+ from playwright.async_api import Page, async_playwright
8
+
9
+ READABILITY_JS_URL = "https://unpkg.com/@mozilla/readability@0.4.4/Readability.js"
10
+ logger = logging.getLogger("uvicorn.error")
11
+
12
+
13
class PageText(TypedDict):
    """A fetched page: its URL plus the readable text extracted from it."""

    url: str
    text: str
16
+
17
+
18
+ WaitUntil = Literal["load", "domcontentloaded", "networkidle", "commit"]
19
+
20
+
21
async def _inject_readability(page: Page) -> None:
    """
    Load Readability.js into *page* and pre-parse a cloned DOM.

    The parsed article ends up on ``window.__readability__``. Non-HTML root
    documents are skipped since Readability cannot process them.
    """
    root_is_html = await page.evaluate(
        "() => document.documentElement.nodeName === 'HTML'"
    )
    if not root_is_html:
        return

    await page.add_script_tag(url=READABILITY_JS_URL)
    await page.add_script_tag(
        content="window.__readability__ = new Readability(document.cloneNode(true));"
    )
30
+
31
+
32
async def _fetch_text(page: Page, url: str, wait_until: WaitUntil) -> str:
    """
    Navigate *page* to *url* and extract its readable text.

    Extraction strategy, in order:
      1. Readability.js article text,
      2. tweet text nodes (Twitter/X pages),
      3. the full <body> inner text.

    Args:
        page: Playwright page to drive.
        url: Address to load.
        wait_until: Playwright load-state to wait for before extracting.

    Returns:
        str: Extracted text (possibly empty).
    """
    await page.goto(url, wait_until=wait_until)
    await page.wait_for_timeout(1000)  # give late-running JS a moment

    # Attempt Readability.js parsing first.
    # Catch Exception, not BaseException: swallowing BaseException would also
    # swallow asyncio.CancelledError and break cancellation of the fetch task.
    try:
        await _inject_readability(page)
        readability_text = await page.evaluate(
            "() => window.__readability__.parse()?.textContent"
        )
        if readability_text:
            return readability_text.strip()
    except Exception:
        pass

    # Fallback: Twitter specific logic.
    try:
        tweet_text = await page.locator(
            "article div[data-testid='tweetText']"
        ).all_inner_texts()
        if tweet_text:
            return "\n".join(tweet_text)
    except Exception:
        pass

    # Final fallback: full body text.
    return await page.evaluate("() => document.body.innerText")
59
+
60
+
61
async def fetch_text(
    url: str, headless: bool = False, wait_until: WaitUntil = "load"
) -> PageText:
    """
    Fetch the readable text of a single URL in a fresh Chrome context.

    Args:
        url: Address to fetch.
        headless: Run the browser headless.
        wait_until: Playwright load-state to wait for.

    Returns:
        PageText: The URL paired with its extracted text.
    """
    async with async_playwright() as pw:
        browser = await pw.chromium.launch_persistent_context(
            user_data_dir="",
            channel="chrome",
            headless=headless,
            no_viewport=True,
        )
        try:
            page = await browser.new_page()
            text = await _fetch_text(page, url, wait_until)
        finally:
            # Always release the browser, even when navigation/extraction
            # raises — previously a failure leaked the whole context.
            await browser.close()

    return PageText(url=url, text=text)
76
+
77
+
78
async def fetch_texts(
    urls: list[str], headless: bool = False, wait_until: WaitUntil = "load"
) -> list[PageText | BaseException]:
    """
    Fetch readable text for many URLs concurrently in one browser context.

    Args:
        urls: Addresses to fetch.
        headless: Accepted for interface symmetry with fetch_text; see the
            NOTE below — the context is always launched headless here.
        wait_until: Playwright load-state to wait for.

    Returns:
        list: For each URL (same order), a PageText on success or the
        exception raised while fetching it.
    """
    async with async_playwright() as pw:
        # NOTE(review): headless is forced to True regardless of the
        # parameter — presumably required in the container environment;
        # confirm before honoring the argument.
        browser = await pw.chromium.launch_persistent_context(
            user_data_dir="",
            channel="chrome",
            headless=True,
            no_viewport=True,
        )
        try:
            pages = [await browser.new_page() for _ in urls]
            tasks = [
                _fetch_text(page, url, wait_until) for page, url in zip(pages, urls)
            ]
            results_raw = await asyncio.gather(*tasks, return_exceptions=True)
        finally:
            # Release the browser even if page creation fails part-way —
            # previously an early error leaked the context.
            await browser.close()

    results: list[PageText | BaseException] = []
    for url, result in zip(urls, results_raw):
        if isinstance(result, BaseException):
            results.append(result)
        else:
            results.append(PageText(url=url, text=result))

    return results
107
+
108
+
109
async def fetch_links_to_json(
    links: list[str],
    output_path: str,
    headless: bool = False,
    wait_until: WaitUntil = "load",
    max_content_length: int = 5000,
) -> None:
    """
    Fetch content from a list of links and save it to a JSON file.

    Args:
        links: List of URLs to fetch content from
        output_path: Path where the JSON file will be saved
        headless: Whether to run browser in headless mode
        wait_until: When to consider page loading complete
        max_content_length: Maximum number of characters to keep from each page content

    Returns:
        None (saves results to JSON file)
    """
    logger.info(f"📥 Fetching content from {len(links)} links...")

    # Fetch content from all links
    results = await fetch_texts(links, headless=headless, wait_until=wait_until)

    # Process results into the desired format, counting failures as we go.
    # (The old summary matched the prefix "Error fetching", which never
    # matched the actual failure sentinel, so it always reported 0 failures.)
    json_data = []
    failed = 0
    for i, (link, result) in enumerate(zip(links, results)):
        logger.info(f"  Processing {i + 1}/{len(links)}: {link}")

        if isinstance(result, BaseException):
            # Handle errors gracefully
            json_data.append({"link": link, "content": "Fail to fetch content..."})
            failed += 1
        else:
            # Successfully fetched content - apply length limit
            content = result["text"]
            if len(content) > max_content_length:
                content = (
                    content[:max_content_length]
                    + "... [content truncated due to length limit]"
                )
                logger.info(
                    f"✂️ Content truncated from {len(result['text'])} to {max_content_length} characters"
                )

            json_data.append({"link": link, "content": content})

    # Ensure output directory exists
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Save to JSON file
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)

    logger.info(f"💾 Saved content from {len(links)} links to {output_path}")

    successful = len(json_data) - failed
    logger.info(f"📊 Summary: {successful} successful, {failed} failed")
src/prompt/fetch/satellite_fetch.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import httpx
4
+ from geopy import Point
5
+ from geopy.distance import distance
6
+
7
+ logger = logging.getLogger("uvicorn.error")
8
+
9
+
10
def meter_offsets(lat: float, lon: float, extend: float) -> tuple[float, float]:
    """
    Convert a radial distance in meters into latitude/longitude degree
    offsets around the center point (lat, lon).

    Returns:
        tuple[float, float]: (lat_offset, lon_offset) in decimal degrees.
    """
    center = Point(lat, lon)
    # Bearing 0° walks due north, 90° due east.
    northward = distance(meters=extend).destination(center, bearing=0)
    eastward = distance(meters=extend).destination(center, bearing=90)
    return northward.latitude - lat, eastward.longitude - lon
20
+
21
+
22
def fetch_satellite_image(
    lat: float, lon: float, extend: float, output_path: str = "esri_sat.png"
) -> None:
    """
    Fetch a satellite PNG around (lat, lon) from Esri's World Imagery service.

    Parameters:
    - lat: Latitude of the center point (decimal degrees).
    - lon: Longitude of the center point (decimal degrees).
    - extend: Buffer distance from center in meters (radius).
    - output_path: File path to save the resulting PNG.

    Tries image sizes 1024, 512, 256, 128 px in order and repeats the whole
    sequence up to 3 times before giving up.
    """
    # Degree offsets for the requested radius, via geopy.
    lat_offset, lon_offset = meter_offsets(lat, lon, extend)

    # Bounding box in lon/lat (EPSG:4326) order expected by the export API.
    minx, maxx = lon - lon_offset, lon + lon_offset
    miny, maxy = lat - lat_offset, lat + lat_offset

    base_url = (
        "https://server.arcgisonline.com/ArcGIS/rest/services/"
        "World_Imagery/MapServer/export"
    )

    for attempt in range(3):
        logger.info(f"Attempt {attempt + 1}/3 to fetch satellite image...")

        # Same halving sequence as before (1024 down to 128), as a for-loop.
        for size in (1024, 512, 256, 128):
            params = {
                "bbox": f"{minx},{miny},{maxx},{maxy}",
                "bboxSR": "4326",
                "size": f"{size},{size}",
                "format": "png",
                "f": "image",
            }
            try:
                response = httpx.get(base_url, params=params, timeout=30.0)
                if response.status_code == 200:
                    with open(output_path, "wb") as f:
                        f.write(response.content)
                    logger.info(f"Saved Esri image to {output_path} ({size}x{size})")
                    return
                logger.info(
                    f"Failed at size {size} (status {response.status_code}), trying {size // 2}"
                )
            except Exception as e:
                logger.error(f"Network error at size {size}: {e}, trying {size // 2}")

        if attempt < 2:  # skip the message on the final attempt
            logger.info(f"Attempt {attempt + 1} failed for all sizes, retrying...")

    logger.warning("Unable to fetch Esri imagery: all retry attempts failed.")
src/prompt/preprocess/keyframe_extract.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from google.cloud import videointelligence_v1 as vi
7
+ from scipy.spatial.distance import cdist
8
+ from sklearn.metrics import silhouette_score
9
+
10
+ # Set up logger
11
+ logger = logging.getLogger("uvicorn.error")
12
+
13
+
14
def detect_shot_intervals_local(video_path: str) -> list[tuple[float, float]]:
    """
    Run Google Video Intelligence shot-change detection on a local video.

    Args:
        video_path (str): Path to the video file (uploaded inline).

    Returns:
        list[tuple[float, float]]: (start, end) of each shot in seconds,
        or [] when the API returns no annotation results.
    """
    logger.info(f"Detecting shot intervals for video: {video_path}")
    client = vi.VideoIntelligenceServiceClient()
    with open(video_path, "rb") as f:
        input_content = f.read()

    operation = client.annotate_video(
        request={
            "input_content": input_content,
            "features": [vi.Feature.SHOT_CHANGE_DETECTION],
        }
    )
    response = operation.result(timeout=300)
    if not response or not response.annotation_results:
        logger.error("No annotation_results found in video intelligence response.")
        return []

    intervals = []
    for shot in response.annotation_results[0].shot_annotations:
        start_offset = shot.start_time_offset
        end_offset = shot.end_time_offset
        intervals.append(
            (
                start_offset.seconds + start_offset.microseconds / 1e6,
                end_offset.seconds + end_offset.microseconds / 1e6,
            )
        )
    logger.info(f"Detected {len(intervals)} shot intervals.")
    return intervals
40
+
41
+
42
def color_histogram(img: np.ndarray) -> np.ndarray:
    """Return a flattened, normalized 8x8x8 HSV color histogram of *img* (BGR)."""
    hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hist = cv2.calcHist(
        [hsv_img], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256]
    )
    normalized = cv2.normalize(hist, hist)
    return normalized.flatten()
46
+
47
+
48
def sample_frames_per_shot(
    video_path: str, start: float, end: float, step: float = 1.0
) -> list[np.ndarray]:
    """
    Grab frames from *video_path* between *start* and *end* seconds, one
    every *step* seconds.

    Returns:
        list[np.ndarray]: Decoded BGR frames; sampling stops early on the
        first read failure.
    """
    capture = cv2.VideoCapture(video_path)
    sampled: list[np.ndarray] = []
    timestamp = start
    while timestamp < end:
        capture.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
        ok, frame = capture.read()
        if not ok:
            logger.warning(f"Failed to read frame at {timestamp:.2f}s")
            break
        sampled.append(frame)
        timestamp += step
    capture.release()
    return sampled
66
+
67
+
68
def kmeans_init(features: np.ndarray):
    """
    Randomly pick k = floor(sqrt(n)) rows (at least 1) as initial cluster
    centers and assign every row to its nearest center.

    Returns:
        tuple: (per-row cluster assignment, chosen center rows)
    """
    n_samples = features.shape[0]
    k = int(np.sqrt(n_samples)) or 1
    chosen = np.random.choice(n_samples, k, replace=False)
    centers = features[chosen]
    assignments = np.argmin(cdist(features, centers), axis=1)
    return assignments, centers
75
+
76
+
77
def kmeans_silhouette(features: np.ndarray):
    """
    Agglomerative refinement of a random k-means init, scored by silhouette.

    Starts from k = floor(sqrt(n)) random centers (via kmeans_init), then
    repeatedly merges the two closest centers down to k = 3, keeping the
    clustering with the best silhouette score seen along the way.

    Returns:
        tuple: (best_k, best_centers, center_indices) where center_indices
        are the row indices in *features* matching each kept center.
    """
    k = max(int(np.sqrt(len(features))), 2)
    best_k, best_score = k, -1
    clusters, centers = kmeans_init(features)
    best_centers = centers.copy()
    while k > 2:
        # Find the two closest centers (diagonal masked out) and merge them.
        d = cdist(centers, centers)
        np.fill_diagonal(d, np.inf)
        i, j = np.unravel_index(np.argmin(d), d.shape)
        clusters = np.where(clusters == j, i, clusters)
        # Shift labels above j down so labels stay contiguous after the merge.
        clusters = np.where(clusters > j, clusters - 1, clusters)
        new_centers = []
        for cid in range(k - 1):
            cluster_feats = features[clusters == cid]
            if cluster_feats.size == 0:
                continue
            # Medoid-style center: the member closest to the cluster mean.
            mean_vec = np.mean(cluster_feats, axis=0)
            idx_close = np.argmin(np.linalg.norm(cluster_feats - mean_vec, axis=1))
            new_centers.append(cluster_feats[idx_close])
        # NOTE(review): from here on `centers` is a plain list of rows, not an
        # ndarray; cdist accepts that as long as no cluster emptied — confirm.
        centers = new_centers
        k -= 1
        if len(np.unique(clusters)) > 1:
            score = silhouette_score(features, clusters)
            if score > best_score:
                best_score, best_k = score, k
                best_centers = centers.copy()
    # Map each kept center row back to its first exact match in features.
    center_indices = []
    for c in best_centers:
        matches = np.where((features == c).all(axis=1))[0]
        if matches.size > 0:
            center_indices.append(int(matches[0]))
    # logger.info(f"KMeans silhouette: best_k={best_k}, best_score={best_score:.4f}")
    return best_k, best_centers, center_indices
110
+
111
+
112
def redundancy_filter(
    video_path: str, indices: list[int], threshold: float
) -> list[int]:
    """
    Drop frames whose HSV histogram is too similar (cosine similarity above
    *threshold*) to any earlier candidate frame.

    Args:
        video_path (str): Video to read frames from.
        indices (list[int]): Candidate global frame numbers.
        threshold (float): Cosine-similarity cutoff.

    Returns:
        list[int]: Subset of *indices* judged non-redundant.
    """
    capture = cv2.VideoCapture(video_path)
    histograms = []
    for frame_no in indices:
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
        ok, frame = capture.read()
        if ok:
            histograms.append(color_histogram(frame))
    capture.release()

    kept: list[int] = []
    for pos, hist in enumerate(histograms):
        redundant = any(
            np.dot(hist, prev) / (np.linalg.norm(hist) * np.linalg.norm(prev))
            > threshold
            for prev in histograms[:pos]
        )
        if not redundant:
            kept.append(indices[pos])
    return kept
133
+
134
+
135
def extract_and_save_keyframes(
    video_path: str,
    output_dir: str,
    start_index: int = 0,
    step: float = 1.0,
    threshold: float = 0.7,
    k_min: int = 2,
    k_max: int = 8,
) -> int:
    """
    Extract representative keyframes per detected shot and save them as JPEGs.

    Args:
        video_path (str): Input video.
        output_dir (str): Directory receiving image_NNN.jpg files.
        start_index (int): First output index to use.
        step (float): Sampling interval inside each shot, in seconds.
        threshold (float): Cosine-similarity cutoff for redundancy removal.
        k_min (int): Minimum feature count before clustering is attempted.
        k_max (int): Currently unused — kept for interface compatibility.

    Returns:
        int: The next free output index (start_index + frames saved).
    """
    logger.info(f"Starting keyframe extraction for {video_path}")
    os.makedirs(output_dir, exist_ok=True)

    # FPS is needed to convert second offsets into frame numbers.
    cap_meta = cv2.VideoCapture(video_path)
    video_fps = cap_meta.get(cv2.CAP_PROP_FPS) or 1.0
    cap_meta.release()

    intervals = detect_shot_intervals_local(video_path)
    cap = cv2.VideoCapture(video_path)
    output_idx = start_index

    for shot_idx, (start, end) in enumerate(intervals):
        # Sample frames & extract features
        frames = sample_frames_per_shot(video_path, start, end, step)
        feats = (
            np.vstack([color_histogram(f) for f in frames])
            if frames
            else np.empty((0,))
        )

        # Determine intra-shot keyframe indices
        if feats.size < k_min or feats.ndim == 1:
            idxs = list(range(len(frames)))
        else:
            _, _, idxs = kmeans_silhouette(feats)

        # Map sample index -> global frame number. Sample i was taken at time
        # (start + i * step), so its frame number is that time * fps.
        # (The old mapping added i directly — i.e. treated one sample as one
        # frame — mislocating keyframes whenever fps != 1/step.)
        global_idxs = [int((start + i * step) * video_fps) for i in idxs]
        filtered = redundancy_filter(video_path, global_idxs, threshold)

        # Save each keyframe sequentially into output_dir
        for frame_no in filtered:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
            ret, frame = cap.read()
            if not ret:
                continue
            out_path = os.path.join(output_dir, f"image_{output_idx:03d}.jpg")
            cv2.imwrite(out_path, frame)
            output_idx += 1
        logger.info(
            f"Shot {shot_idx + 1}: saved {len(filtered)} keyframes. Total so far: {output_idx}"
        )

    cap.release()
    logger.info(f"Extraction complete. Total frames saved: {output_idx}")
    return output_idx
src/prompt/preprocess/video_transcribe.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import torchvision.transforms as T
12
+ import whisper
13
+ from moviepy import *
14
+ from open_clip import create_model_and_transforms
15
+
16
+ logger = logging.getLogger("uvicorn.error")
17
+
18
+
19
def extract_audio(video_path: str, output_dir: str) -> str:
    """
    Extract the audio track of a video and save it as a WAV file.

    Args:
        video_path (str): Path to input video file.
        output_dir (str): Directory to save the audio file.

    Returns:
        str: Path to the saved audio file, or "" on any failure (no audio
        track, extraction error, ...).
    """
    video_name = Path(video_path).stem
    audio_path = Path(output_dir) / f"{video_name}.wav"

    try:
        video = VideoFileClip(str(video_path))
        try:
            audio = video.audio
            if audio is not None:
                try:
                    audio.write_audiofile(str(audio_path), logger="bar")
                finally:
                    # Close the audio clip even when writing raises.
                    audio.close()
        finally:
            # Close the video clip in all paths so ffmpeg subprocesses and
            # file handles are not leaked on failure.
            video.close()

        if not audio_path.exists():
            raise RuntimeError("Audio file was not created")

        return str(audio_path)

    except Exception as e:
        logger.error(f"Error extracting audio: {str(e)}")
        return ""
+ return ""
50
+
51
+
52
def transcribe_audio(audio_path: str, model_name: str = "base") -> str:
    """
    Transcribe an audio file with Whisper.

    Args:
        audio_path (str): Path to the audio file.
        model_name (str): Whisper model name.

    Returns:
        str: Stripped transcription text.

    Raises:
        RuntimeError: When model loading or transcription fails.
    """
    try:
        whisper_model = whisper.load_model(model_name)
        output = whisper_model.transcribe(str(audio_path), fp16=False, verbose=False)
        return str(output.get("text", "")).strip()
    except Exception as e:
        raise RuntimeError(f"Error transcribing audio: {str(e)}")
+ raise RuntimeError(f"Error transcribing audio: {str(e)}")
70
+
71
+
72
def transcribe_video(
    video_path: str,
    output_dir: str = "g3/data/prompt_data/audio",
    model_name: str = "base",
):
    """
    Transcribe a video by extracting its audio and running Whisper on it.

    Args:
        video_path (str): Path to the video file.
        output_dir (str): Directory to save the audio and transcript files.
        model_name (str): Whisper model name.

    Returns:
        str | None: Path to the saved transcript text file, or None when
        audio extraction fails.
    """
    audio_path = extract_audio(video_path, output_dir)
    if not audio_path:
        logger.error("Audio extraction failed. No audio file created.")
        return None

    logger.info(f"Audio extracted to: {audio_path}")
    transcript_text = transcribe_audio(audio_path, model_name=model_name)

    transcript_path = Path(output_dir) / f"{Path(video_path).stem}_transcript.txt"
    with open(transcript_path, "w", encoding="utf-8") as f:
        f.write(transcript_text)
    logger.info(f"Transcript saved to: {transcript_path}")
    # Return the path as documented (previously the function returned None
    # despite the docstring promising the transcript path).
    return str(transcript_path)
+ logger.info(f"Transcript saved to: {transcript_path}")
100
+
101
+
102
def transcribe_video_directory(
    video_dir: str,
    output_dir: str = "g3/data/prompt_data/audio",
    model_name: str = "base",
):
    """
    Transcribe every supported video file found directly inside *video_dir*.

    Args:
        video_dir (str): Directory containing video files.
        output_dir (str): Directory to save the audio and transcript files.
        model_name (str): Whisper model name.

    Returns:
        None
    """
    supported = {".mp4", ".avi", ".mov", ".mkv"}
    os.makedirs(output_dir, exist_ok=True)

    videos = [
        path
        for path in Path(video_dir).glob("*")
        if path.is_file() and path.suffix.lower() in supported
    ]

    if not videos:
        logger.info(f"No video files found in directory: {video_dir}")

    for video_file in videos:
        logger.info(f"Processing video: {video_file}")
        transcribe_video(str(video_file), output_dir, model_name=model_name)
src/prompt/search/image_search.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import logging
4
+ import os
5
+ import time
6
+ from concurrent.futures import ThreadPoolExecutor, as_completed
7
+ from pathlib import Path
8
+ from threading import Lock
9
+
10
+ import requests
11
+ from google.cloud import vision
12
+ from requests.adapters import HTTPAdapter
13
+ from urllib3.util.retry import Retry
14
+
15
+ logger = logging.getLogger("uvicorn.error")
16
+
17
+ # GOOGLE CLOUD VISION API
18
+
19
+
20
def annotate(path: str) -> vision.WebDetection:
    """Returns web annotations given the path to an image.

    Args:
        path: path to the input image. May be a local file path, an
            http(s) URL, or a Google Cloud Storage ("gs:") URI.

    Returns:
        An WebDetection object with relevant information of the
        image from the internet (i.e., the annotations).
    """
    client = vision.ImageAnnotatorClient()

    # Remote images (http/https/gs) are referenced by URI and fetched
    # server-side by the Vision API.
    if path.startswith("http") or path.startswith("gs:"):
        image = vision.Image()
        image.source.image_uri = path

    else:
        # Local images are read and sent inline as raw bytes.
        with open(path, "rb") as image_file:
            content = image_file.read()

        image = vision.Image(content=content)

    # Single-image annotate call requesting only the WEB_DETECTION feature.
    response = client.annotate_image(
        {
            "image": image,
            "features": [{"type_": vision.Feature.Type.WEB_DETECTION}],
        }
    )
    return response.web_detection
49
+
50
+
51
def annotate_directory(directory: str) -> list[vision.WebDetection]:
    """
    Perform web detection on all image files in the given directory in batches of 16.

    Args:
        directory (str): Path to the directory containing image files.

    Returns:
        list[vision.WebDetection]: List of WebDetection objects for each image
        that was successfully read and annotated. Unreadable images and failed
        batches are skipped, so the list may be shorter than the input set.
    """
    client = vision.ImageAnnotatorClient()

    # Collect all image files first (non-recursive; filtered by extension).
    image_files = []
    for file_name in os.listdir(directory):
        file_path = os.path.join(directory, file_name)
        if os.path.isfile(file_path) and file_name.lower().endswith(
            (".jpg", ".jpeg", ".png", ".bmp", ".gif")
        ):
            image_files.append(file_path)

    all_web_detections = []
    batch_size = 16  # Google Vision API batch limit

    # Process images in batches of 16
    for i in range(0, len(image_files), batch_size):
        batch_files = image_files[i : i + batch_size]
        logger.info(
            f"Processing batch {i // batch_size + 1}/{(len(image_files) + batch_size - 1) // batch_size} ({len(batch_files)} images)..."
        )

        # Prepare batch requests; unreadable files become None placeholders
        # so positions within the batch stay aligned with batch_files.
        image_requests = []
        for file_path in batch_files:
            try:
                with open(file_path, "rb") as image_file:
                    content = image_file.read()
                image = vision.Image(content=content)
                image_requests.append(image)
            except Exception as e:
                logger.warning(f"⚠️ Failed to read image {file_path}: {e}")
                # Add a placeholder to maintain order
                image_requests.append(None)

        # Filter out None values and keep track of valid indices so API
        # responses can be mapped back to their original batch positions.
        valid_requests = []
        valid_indices = []
        for idx, request in enumerate(image_requests):
            if request is not None:
                valid_requests.append(request)
                valid_indices.append(idx)

        if not valid_requests:
            logger.warning(f"⚠️ No valid images in batch {i // batch_size + 1}")
            continue

        try:
            # Make batch API call (one request per readable image).
            responses = client.batch_annotate_images(
                requests=[
                    vision.AnnotateImageRequest(
                        image=image,
                        features=[
                            vision.Feature(type=vision.Feature.Type.WEB_DETECTION)
                        ],
                    )
                    for image in valid_requests
                ]
            ).responses

            # Process responses and maintain order: responses come back in
            # the order submitted, valid_indices maps them to batch slots.
            batch_detections: list[vision.WebDetection | None] = [None] * len(
                batch_files
            )
            for response_idx, global_idx in enumerate(valid_indices):
                if (
                    response_idx < len(responses)
                    and responses[response_idx].web_detection
                ):
                    batch_detections[global_idx] = responses[response_idx].web_detection

            # Add to results (filter out None values)
            all_web_detections.extend(
                [det for det in batch_detections if det is not None]
            )

        except Exception as e:
            # A failed batch is dropped entirely; later batches still run.
            logger.warning(f"⚠️ Batch {i // batch_size + 1} failed: {e}")
            continue

    logger.info(
        f"✅ Successfully processed {len(all_web_detections)} images out of {len(image_files)} total"
    )
    return all_web_detections
145
+
146
+
147
def parse_web_detection(annotations: vision.WebDetection) -> dict:
    """Returns detected features in the provided web annotations as a dict."""
    # Build each bucket separately so the returned dict always carries all
    # four keys, even when a category is empty.
    page_urls = []
    full_urls = []
    partial_urls = []
    entities = []

    for page in annotations.pages_with_matching_images or []:
        page_urls.append(page.url)
    for image in annotations.full_matching_images or []:
        full_urls.append(image.url)
    for image in annotations.partial_matching_images or []:
        partial_urls.append(image.url)
    for entity in annotations.web_entities or []:
        entities.append({"score": entity.score, "description": entity.description})

    return {
        "pages_with_matching_images": page_urls,
        "full_matching_images": full_urls,
        "partial_matching_images": partial_urls,
        "web_entities": entities,
    }
170
+
171
+
172
def get_image_links_vision(annotations: vision.WebDetection) -> list[str]:
    """Extracts image links from web detection annotations."""
    # Sources in decreasing order of confidence: pages hosting the image,
    # then full matches, then partial matches. The first non-empty source
    # wins, matching the original cascade of fallbacks.
    ranked_sources = (
        annotations.pages_with_matching_images,
        annotations.full_matching_images,
        annotations.partial_matching_images,
    )
    for source in ranked_sources:
        if source:
            return [item.url for item in source]
    return []
187
+
188
+
189
+ # SCRAPING DOG API
190
def upload_image_to_imgbb(image_path: str, api_key: str) -> str:
    """Upload image to imgbb with automatic retry on transient errors.

    Args:
        image_path (str): Path to the local image file to upload.
        api_key (str): imgbb API key.

    Returns:
        str: Public URL of the uploaded image.

    Raises:
        FileNotFoundError: If the image file does not exist.
        Exception: If reading fails, the upload fails after all retries,
            or imgbb reports an unsuccessful upload.
    """
    # Encode the image as base64, which is what the imgbb upload API expects.
    try:
        with open(image_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode("utf-8")
    except FileNotFoundError:
        raise FileNotFoundError(f"Image file not found: {image_path}")
    except Exception as e:
        raise Exception(f"Error reading image file: {e}")

    payload = {"key": api_key, "image": image_data}
    imgbb_url = "https://api.imgbb.com/1/upload"

    # Configure session with retry logic: up to 5 attempts with exponential
    # backoff on rate limiting (429) and transient server errors (5xx).
    session = requests.Session()
    retry_strategy = Retry(
        total=5,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["POST"],
        raise_on_status=False,  # let raise_for_status() below surface errors
    )
    adapter = HTTPAdapter(max_retries=retry_strategy)
    session.mount("https://", adapter)
    session.mount("http://", adapter)

    try:
        resp = session.post(imgbb_url, data=payload, timeout=30)
        resp.raise_for_status()
        result = resp.json()
        # imgbb wraps the outcome in a "success" flag; the hosted URL lives
        # under data.url on success.
        if result.get("success"):
            return result["data"]["url"]
        else:
            raise Exception(
                f"imgbb upload failed: {result.get('error', 'Unknown error')}"
            )
    except requests.exceptions.RequestException as e:
        raise Exception(f"Failed to upload after retries: {e}")
230
+
231
+
232
def search_with_scrapingdog_lens(
    image_path: str, imgbb_key: str, scrapingdog_key: str
) -> dict:
    """
    Uploads an image to imgbb, then queries ScrapingDog's Google Lens API with 3 retries.

    Args:
        image_path (str): Path to the local image file.
        imgbb_key (str): imgbb API key used to host the image publicly.
        scrapingdog_key (str): ScrapingDog API key.

    Returns:
        dict: Parsed Lens API JSON response, or {"lens_results": []} when the
        upload or every Lens attempt fails (callers never see an exception).
    """
    try:
        # Google Lens needs a publicly reachable URL, so host the image on
        # imgbb first.
        image_url = upload_image_to_imgbb(image_path, imgbb_key)
        logger.info(f"Image uploaded to ImgBB: {image_url}")

        lens_url = f"https://lens.google.com/uploadbyurl?url={image_url}"
        params = {
            "api_key": scrapingdog_key,
            "url": lens_url,
            "visual_matches": "true",
            "exact_matches": "true",
        }

        # Retry logic - 3 attempts
        for attempt in range(3):
            try:
                resp = requests.get(
                    "https://api.scrapingdog.com/google_lens", params=params, timeout=60
                )
                resp.raise_for_status()
                return resp.json()
            except requests.exceptions.RequestException as e:
                logger.warning(
                    f"⚠️ ScrapingDog attempt {attempt + 1}/3 failed for {os.path.basename(image_path)}: {e}"
                )
                if attempt < 2:  # Don't sleep on the last attempt
                    time.sleep(2)  # Wait 2 seconds before retrying
                continue

        # All retries failed
        logger.error(
            f"❌ All 3 ScrapingDog attempts failed for {os.path.basename(image_path)}"
        )
        return {"lens_results": []}

    except Exception as e:
        # Catch-all (e.g. imgbb upload failure) keeps the pipeline running.
        logger.warning(f"⚠️ ScrapingDog API unexpected error for {image_path}: {e}")
        return {"lens_results": []}
275
+
276
+
277
def get_image_links_scrapingdog(search_results: dict, n_results: int = 5) -> list[str]:
    """Get up to ``n_results`` image links from a ScrapingDog Lens response.

    Args:
        search_results (dict): Parsed JSON response from the Lens API.
        n_results (int): Maximum number of links to return.

    Returns:
        list[str]: The first ``n_results`` links found in ``lens_results``.
    """
    # Robustness fix: lens entries can lack a "link" key; skip those instead
    # of raising KeyError on the whole response.
    links: list[str] = []
    for result in search_results.get("lens_results", []):
        link = result.get("link")
        if link:
            links.append(link)
        if len(links) >= n_results:
            break
    return links
282
+
283
+
284
def process_scrapingdog_only(image_path: str) -> list[str]:
    """Process a single image with ScrapingDog API only.

    Reads the IMGBB_API_KEY and SCRAPINGDOG_API_KEY environment variables
    (KeyError from a missing variable is swallowed by the broad except and
    yields an empty result).

    Args:
        image_path (str): Path to the image file.

    Returns:
        list[str]: Up to 5 matching links, or [] on any failure.
    """
    try:
        scrapingdog_search_result = search_with_scrapingdog_lens(
            image_path=image_path,
            imgbb_key=os.environ["IMGBB_API_KEY"],
            scrapingdog_key=os.environ["SCRAPINGDOG_API_KEY"],
        )
        scrapingdog_result = get_image_links_scrapingdog(
            scrapingdog_search_result, n_results=5
        )

        # Serialize log output across the worker threads that call this.
        with print_lock:
            logger.info(
                f"✅ ScrapingDog completed for {os.path.basename(image_path)} - {len(scrapingdog_result)} links"
            )

        return scrapingdog_result
    except Exception as e:
        with print_lock:
            logger.error(
                f"❌ ScrapingDog error for {os.path.basename(image_path)}: {e}"
            )
        return []
308
+
309
+
310
+ # Thread-safe print lock
311
+ print_lock = Lock()
312
+
313
+
314
def process_single_image(image_path: str, imgbb_key: str, scrapingdog_key: str) -> dict:
    """
    Process a single image with both Vision API and ScrapingDog API.

    Args:
        image_path: Path to the image file
        imgbb_key: ImgBB API key
        scrapingdog_key: ScrapingDog API key

    Returns:
        Dictionary with keys "image_path", "vision_result" and
        "scrapingdog_result"; on failure both result lists are empty and an
        "error" key carries the exception message.
    """
    try:
        # Vision API processing: reverse-image web detection + link extraction.
        annotations = annotate(image_path)
        vision_result = get_image_links_vision(annotations)

        # ScrapingDog API processing: Google Lens search via public upload.
        scrapingdog_search_result = search_with_scrapingdog_lens(
            image_path=image_path, imgbb_key=imgbb_key, scrapingdog_key=scrapingdog_key
        )
        scrapingdog_result = get_image_links_scrapingdog(
            scrapingdog_search_result, n_results=5
        )

        result = {
            "image_path": os.path.basename(image_path),
            "vision_result": vision_result,
            "scrapingdog_result": scrapingdog_result,
        }

        # Serialize log output across worker threads.
        with print_lock:
            logger.info(f"✅ Completed processing {os.path.basename(image_path)}")

        return result

    except Exception as e:
        with print_lock:
            logger.error(f"❌ Error processing {os.path.basename(image_path)}: {e}")
        return {
            "image_path": os.path.basename(image_path),
            "vision_result": [],
            "scrapingdog_result": [],
            "error": str(e),
        }
360
+
361
+
362
def image_search_directory(
    directory: str,
    output_dir: str = "g3/data/prompt_data",
    filename: str = "metadata.json",
    imgbb_key: str = "YOUR_IMGBB_API_KEY",
    scrapingdog_key: str = "YOUR_SCRAPINGDOG_API_KEY",
    max_workers: int = 4,
    target_links: int = 20,
) -> None:
    """
    Perform web detection with a two-phase approach:
    1. Run Vision API on all images first using annotate_directory
    2. If total unique links < target_links, run ScrapingDog on images until target is reached

    Args:
        directory (str): Path to the directory containing image files.
        output_dir (str): Directory to save the JSON output.
        filename (str): Name of the JSON file to save the results.
        imgbb_key (str): ImgBB API key for image uploading.
        scrapingdog_key (str): ScrapingDog API key for lens search.
        max_workers (int): Maximum number of parallel workers.
        target_links (int): Target number of unique links to collect.

    Returns:
        None
    """
    # Links pointing at these domains are dropped from both phases.
    EXCLUDE_DOMAIN = [
        "youtube.com",
    ]
    # Collect candidate image files (non-recursive, filtered by extension).
    image_files = []
    for file_name in os.listdir(directory):
        file_path = os.path.join(directory, file_name)
        if os.path.isfile(file_path) and file_name.lower().endswith(
            (".jpg", ".jpeg", ".png", ".bmp", ".gif")
        ):
            image_files.append(file_path)

    if not image_files:
        logger.info("No image files found in the directory.")
        return

    logger.info(
        f"Found {len(image_files)} image files. Target: {target_links} unique links"
    )

    # Phase 1: Run Vision API on all images using annotate_directory
    logger.info("🔍 Phase 1: Running Vision API on all images...")
    all_links = set()
    vision_links_count = 0

    try:
        # Use the existing annotate_directory function for batch processing
        web_detections = annotate_directory(directory)

        # Extract links from all web detections, excluding unwanted domains;
        # the set deduplicates automatically.
        for detection in web_detections:
            links = get_image_links_vision(detection)
            links = [
                link
                for link in links
                if not any(domain in link for domain in EXCLUDE_DOMAIN)
            ]
            all_links.update(links)

        vision_links_count = len(all_links)
        logger.info(
            f"✅ Phase 1 complete: {vision_links_count} unique links from Vision API"
        )

    except Exception as e:
        logger.error(f"❌ Vision API processing failed: {e}")
        all_links = set()
        vision_links_count = 0

    # Phase 2: Run ScrapingDog only if Phase 1 fell short of the target.
    scrapingdog_links_count = 0

    if len(all_links) < target_links:
        needed_links = target_links - len(all_links)
        logger.info(
            f"🔍 Phase 2: Need {needed_links} more links. Running ScrapingDog..."
        )

        # Skip Phase 2 when the keys are still the placeholder defaults.
        if (
            imgbb_key == "YOUR_IMGBB_API_KEY"
            or scrapingdog_key == "YOUR_SCRAPINGDOG_API_KEY"
        ):
            logger.warning("⚠️ ScrapingDog API keys not available. Skipping Phase 2.")
        else:
            scrapingdog_completed = 0

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit ScrapingDog tasks for all images
                future_to_image = {
                    executor.submit(process_scrapingdog_only, image_path): image_path
                    for image_path in image_files
                }

                # Collect ScrapingDog results until we have enough links
                for future in as_completed(future_to_image):
                    image_path = future_to_image[future]
                    try:
                        result_links = future.result()
                        filtered_links = [
                            link
                            for link in result_links
                            if not any(
                                domain in link for domain in EXCLUDE_DOMAIN
                            )
                        ]
                        # Bug fix: count NEW unique links by measuring the set
                        # size before and after the update. The previous code
                        # compared against len(filtered_links), which produced
                        # an arbitrary (even negative) per-image count.
                        links_before = len(all_links)
                        all_links.update(filtered_links)
                        scrapingdog_links_count += len(all_links) - links_before
                        scrapingdog_completed += 1

                        with print_lock:
                            logger.info(
                                f"ScrapingDog Progress: {scrapingdog_completed}/{len(image_files)} images, "
                                f"{scrapingdog_links_count} new ScrapingDog links, {len(all_links)} total unique"
                            )

                        # Stop early if we have enough links
                        if len(all_links) >= target_links:
                            logger.info(
                                f"🎯 Target reached! {len(all_links)} >= {target_links} links"
                            )
                            # Cancel remaining futures
                            for remaining_future in future_to_image:
                                if not remaining_future.done():
                                    remaining_future.cancel()
                            break

                    except Exception as e:
                        with print_lock:
                            logger.error(
                                f"❌ Failed ScrapingDog for {os.path.basename(image_path)}: {e}"
                            )
                        scrapingdog_completed += 1

    # Prepare final results; cap the saved list at target_links.
    total_unique_links = len(all_links)
    all_links = list(all_links)[:target_links]
    results = {
        "all_links": all_links,
        "total_unique_links": total_unique_links,
        "target_achieved": total_unique_links >= target_links,
        "summary": {
            "images_processed": len(image_files),
            "vision_links": vision_links_count,
            "scrapingdog_links": scrapingdog_links_count,
            "total_unique_links": total_unique_links,
            "target_links": target_links,
        },
    }

    # Ensure the output directory exists
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Save results to JSON file
    out_path = Path(output_dir) / filename
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    logger.info(
        f"✅ Saved results to {out_path}\n"
        f"📊 Summary: {vision_links_count} Vision + {scrapingdog_links_count} ScrapingDog = {total_unique_links} total unique links"
    )
src/prompt/search/index_search.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from concurrent.futures import ThreadPoolExecutor, as_completed
5
+ from pathlib import Path
6
+ from threading import Lock
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ from PIL import Image
12
+
13
+ logger = logging.getLogger("uvicorn.error")
14
+
15
+ # Thread-safe lock for logging
16
+ print_lock = Lock()
17
+
18
+
19
def search_index(model, rgb_image, device, index, top_k=20):
    """
    Search FAISS index for similar and dissimilar coordinates using image embeddings.

    Args:
        model: Vision model used for embedding generation. Assumed to expose
            vision_processor, vision_model, vision_projection and the
            vision_projection_else_1/2 heads — TODO confirm against the G3
            model definition.
        rgb_image: PIL RGB Image.
        device: Device to run the model on (e.g., "cuda" or "cpu").
        index: FAISS index for searching.
        top_k (int): Number of top results to return.

    Returns:
        tuple: (D, I, D_reverse, I_reverse) - distances and indices for positive and negative embeddings.
    """
    # Preprocess to the model's expected pixel grid (224x224 per channel).
    image = model.vision_processor(images=rgb_image, return_tensors="pt")[
        "pixel_values"
    ].reshape(-1, 224, 224)
    image = image.unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        # Index 1 of the vision model output is used as the pooled feature
        # fed to the projection heads.
        vision_output = model.vision_model(image)[1]
        image_embeds = model.vision_projection(vision_output)
        # L2-normalize each embedding so FAISS inner-product/cosine search
        # behaves consistently across the three heads.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        image_text_embeds = model.vision_projection_else_1(
            model.vision_projection(vision_output)
        )
        image_text_embeds = image_text_embeds / image_text_embeds.norm(
            p=2, dim=-1, keepdim=True
        )

        image_location_embeds = model.vision_projection_else_2(
            model.vision_projection(vision_output)
        )
        image_location_embeds = image_location_embeds / image_location_embeds.norm(
            p=2, dim=-1, keepdim=True
        )

        # Concatenate the three normalized views into one query vector; the
        # index is assumed to have been built with the same layout.
        positive_image_embeds = torch.cat(
            [image_embeds, image_text_embeds, image_location_embeds], dim=1
        )

    positive_image_embeds = (
        positive_image_embeds.cpu().detach().numpy().astype(np.float32)
    )

    # Negated query retrieves the most dissimilar entries.
    negative_image_embeds = positive_image_embeds * (-1.0)

    # Search FAISS index
    D, I = index.search(positive_image_embeds, top_k)
    D_reverse, I_reverse = index.search(negative_image_embeds, top_k)
    return D, I, D_reverse, I_reverse
71
+
72
+
73
def get_gps_coordinates(I, I_reverse, database_csv_path):
    """
    Helper method to get GPS coordinates from database using FAISS indices.

    The CSV is streamed in chunks so the full database never has to be in
    memory. Coordinates are returned in the same order as the FAISS indices
    (i.e. ranked by similarity) and the scan stops early once every
    requested index has been resolved.

    Args:
        I: FAISS indices for positive embeddings, shape (1, top_k).
        I_reverse: FAISS indices for negative embeddings, shape (1, top_k).
        database_csv_path (str): Path to GPS coordinates database CSV with
            "LAT" and "LON" columns; FAISS indices are row positions.

    Returns:
        tuple: (candidates_gps, reverse_gps) - lists of (lat, lon) tuples
    """
    if I is None or I_reverse is None:
        return [], []

    candidate_indices = list(I[0])
    reverse_indices = list(I_reverse[0])

    # Resolve each wanted row index to its (lat, lon) pair. Keying a dict by
    # row index lets us restore FAISS ranking order afterwards; the previous
    # implementation appended in chunk order, which scrambled the ranking
    # whenever hits spanned multiple CSV chunks.
    wanted = set(candidate_indices) | set(reverse_indices)
    coords = {}

    try:
        for chunk in pd.read_csv(
            database_csv_path, chunksize=10000, usecols=["LAT", "LON"]
        ):
            # pandas keeps the global row index across chunks, so membership
            # against chunk.index identifies which wanted rows live here.
            found = wanted.intersection(chunk.index)
            for idx in found:
                coords[idx] = (
                    float(chunk.loc[idx, "LAT"]),
                    float(chunk.loc[idx, "LON"]),
                )
            wanted -= found
            if not wanted:
                # Early exit: every requested index has been resolved.
                break
    except Exception as e:
        logger.error(f"⚠️ Error loading GPS coordinates from database: {e}")

    candidates_gps = [coords[idx] for idx in candidate_indices if idx in coords]
    reverse_gps = [coords[idx] for idx in reverse_indices if idx in coords]
    return candidates_gps, reverse_gps
113
+
114
+
115
def save_results_to_json(candidates_gps: list, reverse_gps: list, output_path: str):
    """
    Save index-search results to a JSON file.

    Args:
        candidates_gps (list): (lat, lon) pairs for similar candidates.
        reverse_gps (list): (lat, lon) pairs for dissimilar candidates.
        output_path (str): Path to the output JSON file.
    """
    payload = {"candidates_gps": candidates_gps, "reverse_gps": reverse_gps}
    with open(output_path, "w") as json_file:
        json.dump(payload, json_file, indent=4)
126
+
127
+
128
def process_single_image(image_path, model, device, index, database_csv_path, top_k=20):
    """
    Process a single image for index search.

    Args:
        image_path: Path to the image file
        model: Vision model used for embedding generation
        device: Device to run the model on
        index: FAISS index for searching
        database_csv_path: Path to GPS coordinates database CSV
        top_k: Number of top results to return

    Returns:
        tuple: (candidates_gps, reverse_gps) for this image; ([], []) on any
        failure (the exception is logged, not re-raised, so one bad image
        cannot abort a directory run).
    """
    try:
        # Decode and normalize to RGB; FAISS lookup then maps the returned
        # row indices to GPS coordinates via the database CSV.
        rgb_image = Image.open(image_path).convert("RGB")
        D, I, D_reverse, I_reverse = search_index(
            model, rgb_image, device, index, top_k
        )
        candidates_gps, reverse_gps = get_gps_coordinates(
            I, I_reverse, database_csv_path
        )

        return candidates_gps, reverse_gps
    except Exception as e:
        # Serialize log output across the worker threads that call this.
        with print_lock:
            logger.error(f"❌ Error processing {os.path.basename(image_path)}: {e}")
        return [], []
162
+
163
+
164
def search_index_directory(
    model,
    device,
    index,
    image_dir,
    database_csv_path,
    top_k=20,
    max_elements=20,
    max_workers=4,
):
    """
    Perform FAISS index search for all images in a directory in parallel and gradually build a prioritized set of candidates.

    Results are merged round-robin by rank: every image's rank-0 hit is taken
    before any image's rank-1 hit, until max_elements unique coordinates are
    collected for each of the two sets.

    Args:
        model: Vision model used for embedding generation.
        device: Device to run the model on (e.g., "cuda" or "cpu").
        index: FAISS index for searching.
        image_dir (str): Path to the directory containing images.
        database_csv_path (str): Path to GPS coordinates database CSV.
        top_k (int): Number of top results to return for each image.
        max_elements (int): Maximum number of elements in the final candidates set.
        max_workers (int): Maximum number of parallel workers.

    Returns:
        tuple: (candidates_gps, reverse_gps) - lists of (lat, lon) tuples.
    """
    # Get all image paths (non-recursive, filtered by extension).
    image_paths = [
        Path(image_dir) / img
        for img in os.listdir(image_dir)
        if img.lower().endswith((".jpg", ".jpeg", ".png", ".bmp"))
    ]

    if not image_paths:
        logger.warning("No images found in directory")
        return [], []

    logger.info(
        f"🚀 Processing {len(image_paths)} images with {max_workers} parallel workers..."
    )

    all_candidates_gps = []
    all_reverse_gps = []
    completed_count = 0

    # Process images in parallel
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_path = {
            executor.submit(
                process_single_image,
                image_path,
                model,
                device,
                index,
                database_csv_path,
                top_k,
            ): image_path
            for image_path in image_paths
        }

        # Collect results as they complete (per-image lists arrive in
        # completion order, which is fine: the rank interleave below does
        # not depend on image order).
        for future in as_completed(future_to_path):
            image_path = future_to_path[future]
            try:
                candidates_gps, reverse_gps = future.result()
                all_candidates_gps.append(candidates_gps)
                all_reverse_gps.append(reverse_gps)
                completed_count += 1

                with print_lock:
                    logger.info(
                        f"Progress: {completed_count}/{len(image_paths)} images completed"
                    )

            except Exception as e:
                with print_lock:
                    logger.error(
                        f"❌ Failed to process {os.path.basename(image_path)}: {e}"
                    )
                # Add empty results for failed images
                all_candidates_gps.append([])
                all_reverse_gps.append([])
                completed_count += 1

    # Build prioritized sets from all results: outer loop walks ranks,
    # inner loop walks images, so the best hit of every image is considered
    # before any second-best hit. Sets deduplicate repeated coordinates.
    candidates_gps = set()
    reverse_gps = set()

    for priority in range(top_k):
        for image_candidates_gps, image_reverse_gps in zip(
            all_candidates_gps, all_reverse_gps
        ):
            if len(candidates_gps) < max_elements and priority < len(
                image_candidates_gps
            ):
                candidates_gps.add(image_candidates_gps[priority])
            if len(reverse_gps) < max_elements and priority < len(image_reverse_gps):
                reverse_gps.add(image_reverse_gps[priority])

        if len(candidates_gps) >= max_elements and len(reverse_gps) >= max_elements:
            break

    logger.info(
        f"🎯 Final results: {len(candidates_gps)} candidates, {len(reverse_gps)} reverse GPS coordinates"
    )

    return list(candidates_gps), list(reverse_gps)
src/prompt/search/text_search.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import time
5
+ from typing import Optional
6
+
7
+ import httpx
8
+ from dotenv import load_dotenv
9
+
10
+ logger = logging.getLogger("uvicorn.error")
11
+
12
+
13
def retry_request(func, max_retries=3, base_delay=2.0):
    """
    Call ``func`` repeatedly, backing off exponentially on transient errors.

    Timeouts and retryable 5xx HTTP status errors are retried up to
    ``max_retries`` times with a delay of ``base_delay * 2**attempt``
    seconds; any other exception (or a non-retryable status) is logged and
    re-raised immediately.

    Args:
        func: Zero-argument callable to invoke.
        max_retries: Maximum number of attempts.
        base_delay: Base delay for exponential backoff.

    Returns:
        Whatever ``func`` returns on the first successful call.

    Raises:
        The last exception if all retries fail.
    """
    for attempt in range(max_retries):
        out_of_retries = attempt >= max_retries - 1
        try:
            return func()
        except (httpx.ReadTimeout, httpx.ConnectTimeout, httpx.TimeoutException) as e:
            if out_of_retries:
                logger.error(
                    f"❌ Max retries ({max_retries}) exceeded for timeout error."
                )
                raise e
            delay = base_delay * (2**attempt)
            logger.warning(
                f"⚠️ Timeout error (attempt {attempt + 1}/{max_retries}). Retrying in {delay}s..."
            )
            time.sleep(delay)
        except httpx.HTTPStatusError as e:
            if e.response.status_code in [500, 502, 503, 504] and not out_of_retries:
                delay = base_delay * (2**attempt)
                logger.warning(
                    f"⚠️ Server error {e.response.status_code} (attempt {attempt + 1}/{max_retries}). Retrying in {delay}s..."
                )
                time.sleep(delay)
                continue
            logger.error(f"❌ HTTP error {e.response.status_code}: {e}")
            raise e
        except Exception as e:
            logger.error(f"❌ Unexpected error: {e}")
            raise e

    # Should never reach here
    raise RuntimeError("Retry logic failed")
61
+
62
+
63
def extension_from_content_type(content_type: str) -> str:
    """Map an HTTP Content-Type header value to an image file extension.

    Args:
        content_type: Raw header value, possibly carrying parameters
            (e.g. "image/jpeg; charset=utf-8").

    Returns:
        The file extension (without the dot) for the image type.

    Raises:
        ValueError: If the content type is not a supported image type.
    """
    # Define allowed image types
    allowed_types = {
        "image/png": "png",
        "image/jpeg": "jpg",
        "image/jpg": "jpg",
        "image/webp": "webp",
        "image/heic": "heic",
        "image/heif": "heif",
    }

    # Normalize content type (remove charset, etc.)
    content_type = content_type.split(";")[0].strip().lower()

    ext = allowed_types.get(content_type)
    if ext is None:
        raise ValueError(
            f"Content type '{content_type}' is not supported. Allowed types: {list(allowed_types.keys())}"
        )
    return ext
83
+
84
+
85
def text_search_image(
    query: str,
    num_images: int = 5,
    api_key: str | None = None,
    cx: str | None = None,
    output_dir: str = "g3/data/prompt_data/images",
    start_index: int = 0,
) -> list[str]:
    """
    Search Google Custom Search for images matching ``query`` and download them.

    Args:
        query (str): Search query string.
        num_images (int): Number of images to download.
        api_key (str | None): Google API key; required.
        cx (str | None): Custom Search Engine ID; required.
        output_dir (str): Directory to write downloaded images into.
        start_index (int): First numeric suffix used for saved filenames
            (files are named ``image_NNN.ext``).

    Returns:
        list[str]: Paths of the files actually downloaded (may be fewer than
        ``num_images`` if the API or downloads fail).

    Raises:
        ValueError: If ``api_key`` or ``cx`` is missing.
    """
    if not api_key or not cx:
        raise ValueError("GOOGLE_CLOUD_API_KEY or GOOGLE_CSE_CX not set.")

    os.makedirs(output_dir, exist_ok=True)
    downloaded_files: list[str] = []
    # Custom Search pagination is 1-based; each page yields at most 10 items.
    start: int = 1

    idx = start_index
    while len(downloaded_files) < num_images:
        params = {
            "q": query,
            "searchType": "image",
            "cx": cx,
            "key": api_key,
            "num": min(10, num_images - len(downloaded_files)),
            "start": start,
        }

        # Use retry logic for the API request
        try:
            response = retry_request(
                lambda: httpx.get(
                    "https://customsearch.googleapis.com/customsearch/v1",
                    params=params,
                    timeout=30.0,  # Increased timeout
                )
            )
            response.raise_for_status()
        except Exception as e:
            logger.error(f"❌ Failed to search for images after retries: {e}")
            break

        results = response.json().get("items", [])

        if not results:
            logger.info("No more results from API")
            break

        for item in results:
            img_url: str | None = item.get("link")
            if not img_url:
                continue
            try:
                # Use retry logic for image download; bind the URL as a
                # default so the lambda does not capture the loop variable.
                r = retry_request(lambda url=img_url: httpx.get(url, timeout=15.0))
                r.raise_for_status()
                content_type = r.headers.get("Content-Type", "")

                # Check if content type is supported before processing
                try:
                    ext = extension_from_content_type(content_type)
                except ValueError as e:
                    logger.info(f"Skipping {img_url}: {e}")
                    continue

                filename = os.path.join(output_dir, f"image_{idx:03d}.{ext}")
                with open(filename, "wb") as f:
                    f.write(r.content)
                downloaded_files.append(filename)
                idx += 1
                if len(downloaded_files) >= num_images:
                    break
            except httpx.HTTPError as e:
                # A failed download is skipped; the loop keeps collecting.
                logger.error(f"HTTP error downloading {img_url}: {e}")
            except Exception as e:
                logger.error(f"Failed to download {img_url}: {e}")

        start += 10

    return downloaded_files
163
+
164
+
165
def text_search_link(
    query: str,
    output_dir: str = "g3/data/prompt_data",
    filename: str = "text_search.json",
    num_results: int = 10,
    api_key: Optional[str] = None,
    cx: Optional[str] = None,
) -> str:
    """
    Search for web links using Google Custom Search API and save results to JSON file.

    Args:
        query (str): Search query string
        output_dir (str): Directory to save the results file
        filename (str): Name of the JSON file to save results
        num_results (int): Number of search results to retrieve (max 100)
        api_key (Optional[str]): Google API key, defaults to environment variable
        cx (Optional[str]): Custom Search Engine ID, defaults to environment variable

    Returns:
        str: Path to the saved JSON file

    Raises:
        ValueError: If API key or CX not provided
        httpx.HTTPError: If API request fails
    """
    if not api_key:
        api_key = os.getenv("GOOGLE_CLOUD_API_KEY")
    if not cx:
        cx = os.getenv("GOOGLE_CSE_CX")

    if not api_key or not cx:
        raise ValueError("GOOGLE_CLOUD_API_KEY or GOOGLE_CSE_CX not set.")

    # BUG FIX: create the output directory up front. Previously the
    # empty-query early-return path wrote the JSON file *before*
    # os.makedirs ran, raising FileNotFoundError when the directory
    # did not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    def _save_results(found: list[str]) -> str:
        # Serialize the query plus collected links to the output JSON file.
        search_results = {"query": query, "links": found}
        output_path = os.path.join(output_dir, filename)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(search_results, f, indent=2, ensure_ascii=False)
        logger.info(f"✅ Saved {len(found)} search results to: {output_path}")
        return output_path

    links = []
    start = 1

    # An empty query cannot be searched; persist an empty result set.
    if not query:
        return _save_results(links)

    # Google Custom Search API allows max 10 results per request
    while len(links) < num_results:
        remaining = num_results - len(links)
        current_num = min(10, remaining)

        params = {
            "q": query,
            "cx": cx,
            "key": api_key,
            "num": current_num,
            "start": start,
        }

        try:
            response = retry_request(
                lambda: httpx.get(
                    "https://customsearch.googleapis.com/customsearch/v1",
                    params=params,
                    timeout=30.0,
                )
            )
            response.raise_for_status()
            data = response.json()

            items = data.get("items", [])
            if not items:
                logger.info(
                    f"No more results available. Retrieved {len(links)} results."
                )
                break

            links.extend([item.get("link", "") for item in items if "link" in item])

            if len(links) >= num_results:
                break

        except httpx.HTTPError as e:
            logger.error(f"HTTP error during search: {e}")
            break
        except Exception as e:
            logger.error(f"Error during search: {e}")
            break

        start += 10

    # Ensure we only take the first num_results links
    return _save_results(links[:num_results])
src/prompt/template.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt for the initial geo-localization pass: the model analyzes the image(s)
# plus gathered web links and must reply with a strict JSON payload
# (latitude/longitude, location name, per-clue evidence entries).
# {prompt_data} is substituted via str.format(); doubled braces ({{ }}) are
# literal braces that survive formatting.
DIVERSIFICATION_PROMPT = """
You are an expert in geo-localization. Analyze the image and determine the most precise possible location—ideally identifying the exact building, landmark, or facility, not just the city.
Examine all provided content links in detail, using both textual and visual clues to support your conclusion.
Use only the provided links for evidence. Any additional links must directly support specific visual observations (e.g., satellite imagery or publicly available street-level photos of the same location).
Return your final answer as geographic coordinates.

{prompt_data}

Respond with **only** the following JSON structure (no extra text, markdown, or comments):

{{
    "latitude": float,
    "longitude": float,
    "location": string,
    "evidence": [
        {{
            "analysis": string,
            "references": [string, …]
        }}
    ]
}}

**Guidelines:**
- One entry per clue (visual and textual).
- Each object in the "evidence" list should explain a single textual or visual clue and be as many as possible. All image in the prompt follow the format: "image_{{idx:03d}}.jpg", starting from image_000.jpg.
- In the "references" list, each element must be a URL or an image file name (e.g., "image_000.jpg"). They are marked with indices like [1], [2], etc in order of appearance in "references" list. "Analysis" must use these indices to cite the corresponding references.
- The "analysis" field must describe the clue and cite reference in its corresponding "references" using bracketed indices like [1], [2], etc. The corresponding URLs or images for those references must be included in the "references" list for that object.
    + For contextual evidence, must cite textual/news URLs.
    + For visual clues, cite `image_{{idx:03d}}.jpg` in `references` and any satellite/map URLs as needed.
- MUST use given links to support the analysis.
- If you can’t identify a specific building, give the city‑center coordinates.
"""
33
+
34
# Prompt used to geocode a free-form location name into coordinates.
# {location} is substituted via str.format(); doubled braces ({{ }}) are
# literal braces in the expected JSON answer template.
LOCATION_PROMPT = """
Location: {location}

Your task is to determine the geographic coordinates (latitude and longitude) of the specified location by following these steps:

1. Attempt to find the exact GPS coordinates using reliable online sources such as maps or satellite imagery.

2. If the exact location is not available, find the coordinates of a nearby or adjacent place (e.g., a recognizable landmark, building, road, or intersection).

3. If no specific nearby location can be found, use the coordinates of the broader area (e.g., the center of Khan Younis or Gaza).

4. In the "references" list, each element must be a URL or an image file name (e.g., "image_000.jpg"). They are marked with indices like [1], [2], etc in order of appearance in "references" list. "Analysis" must use these indices to cite the corresponding references.

Return your answer in the following JSON format:

{{
    "latitude": float,
    "longitude": float,
    "analysis": "Describe how the coordinates were identified or approximated, including any visual or textual clues used.",
    "references": ["URL1", "URL2", ...]
}}

- The "analysis" must clearly explain the reasoning behind the chosen coordinates.
- The "references" list must include all URLs cited in the analysis.
- Do not include any text outside of the JSON structure.
"""
60
+
61
# Prompt for the verification pass over a previous prediction.
# {prompt_data}, {prediction} and {satellite_image_id} are real str.format()
# placeholders; doubled braces ({{ }}) — including {{idx:03d}} — are literal
# text that survives formatting.
VERIFICATION_PROMPT = """
You are an expert in multimedia verification. Analyze the provided content and decide if it’s authentic or fabricated. Support your conclusion with detailed, verifiable evidence.

{prompt_data}

Prediction to verify:
{prediction}

Guidelines:
1. Output only a JSON object with these fields:
{{
    "latitude": float,
    "longitude": float,
    "location": string,
    "evidence": [
        {{
            "analysis": string,
            "references": [string, …]
        }}
    ]
}}

2. Images are named “image_{{idx:03d}}.jpg”:
   - Images up to “image_{satellite_image_id}.jpg” were used to generate the prediction.
   - “image_{satellite_image_id}.jpg” is the satellite reference.
   - Images after that show the claimed location’s landmarks—use them only to confirm buildings or landmarks.

3. In the "references" field of response, each element must be a URL or an image file name (e.g., "image_000.jpg"). They are marked with indices like [1], [2], etc in order of appearance in "references" list. "Analysis" must use these indices to cite the corresponding references.

4. There must be both visual and contextual evidences. For each evidence entry:
   a. **Visual evidence**: cross‑check the original images against the satellite view.
      - When citing original images (those before `image_{satellite_image_id}.jpg`), **do not** list them alone: each must be accompanied by at least one supporting satellite image, street‑view photo, or map URL in the same reference list.
      - If confirmed, **rewrite and enrich** your analysis with additional visual details (textures, angles, shadows) and cite any new image or map references.
      - If it can’t be verified, **remove** that entry entirely.

   b. **Contextual evidence**: verify against the provided URLs.
      - If confirmed, **rewrite and expand** your analysis with deeper context (dates, sources, related events) and cite any new supporting links.
      - If it can’t be verified, **remove** that entry.

   c. Analyze but **do not** need cite transcript and metadata.

5. All evidence must directly support the predicted latitude/longitude. Do not include analysis or references unrelated to verifying that specific location.

6. Do **not** include any metadata (EXIF, timestamps, filenames) as evidence.

Return only the JSON—no extra text, markdown, or comments.
"""
src/setup.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from pathlib import Path
3
+
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ base_path = Path(__file__).parent
7
+
8
+
9
+ def setup(
10
+ local_path: Path,
11
+ repo_id: str,
12
+ filename: str,
13
+ subfolder: str | None = None,
14
+ repo_type: str | None = None,
15
+ ) -> None:
16
+ if not local_path.exists():
17
+ local_path.parent.mkdir(parents=True, exist_ok=True)
18
+
19
+ cached_path = hf_hub_download(
20
+ repo_id=repo_id,
21
+ subfolder=subfolder,
22
+ filename=filename,
23
+ repo_type=repo_type,
24
+ )
25
+ shutil.copy(cached_path, local_path)
26
+
27
+
28
+ if __name__ == "__main__":
29
+ checkpoint_path = (
30
+ base_path / "data/checkpoints/mercator_finetune_weight.pth"
31
+ ).resolve()
32
+ index_path = (base_path / "data/index/G3.index").resolve()
33
+ database_path = (base_path / "data/dataset/mp16/MP16_Pro_filtered.csv").resolve()
34
+
35
+ repo_id = "tduongvn/Checkpoints-ACMMM25"
36
+
37
+ setup(checkpoint_path, repo_id, "mercator_finetune_weight.pth")
38
+ setup(index_path, repo_id, "G3.index", "index")
39
+ setup(database_path, repo_id, "MP16_Pro_filtered.csv", "data/mp16")
src/utils.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
7
+
8
+ import numpy as np
9
+ import requests
10
+ import torch
11
+ import torch.nn as nn
12
+ from PIL import Image
13
+
14
# Set up logger
# Attach to uvicorn's error logger so messages surface in the server log.
logger = logging.getLogger("uvicorn.error")

# Generic type variable used by the retry helpers below.
T = TypeVar("T")

# Nominatim (OpenStreetMap) geocoding endpoint used by get_gps_from_location.
NOMINATIM_URL = "https://nominatim.openstreetmap.org/search"
# Nominatim's usage policy requires a descriptive User-Agent on every request.
DEFAULT_USER_AGENT = "keyframe_extraction_app"
21
+
22
+
23
def get_gps_from_location(
    location: str,
    language: str = "en",
    timeout: int = 10,
    user_agent: str = DEFAULT_USER_AGENT,
) -> Tuple[Optional[float], Optional[float]]:
    """
    Geocode a free-form location string via Nominatim (OpenStreetMap).

    Args:
        location (str): Location string (e.g., city, address)
        language (str): Language for results (default: 'en')
        timeout (int): Request timeout in seconds (default: 10)
        user_agent (str): User-Agent header (required by Nominatim)

    Returns:
        Tuple[Optional[float], Optional[float]]: (latitude, longitude),
        or (None, None) on failure
    """
    # Guard clause: reject non-string or blank input without touching the network.
    if not isinstance(location, str) or not location.strip():
        logger.warning("Invalid or empty location string provided.")
        return (None, None)

    query = {
        "q": location.strip(),
        "format": "json",
        "addressdetails": 1,
        "accept-language": language,
        "limit": 1,
    }
    request_headers = {"User-Agent": user_agent}

    try:
        resp = requests.get(
            NOMINATIM_URL, params=query, headers=request_headers, timeout=timeout
        )
        resp.raise_for_status()
        payload = resp.json()

        if not payload:
            logger.info(f"No results found for location: '{location}'")
            return (None, None)

        # Nominatim returns lat/lon as strings; convert to floats.
        return (float(payload[0]["lat"]), float(payload[0]["lon"]))

    except requests.RequestException as req_err:
        logger.error(f"Request error while geocoding '{location}': {req_err}")
    except (ValueError, KeyError, TypeError) as parse_err:
        logger.error(
            f"Failed to parse geocoding response for '{location}': {parse_err}"
        )

    return (None, None)
80
+
81
+
82
def calculate_similarity_scores(
    model: nn.Module,
    device: torch.device,
    predicted_coords: List[Tuple[float, float]],
    image_dir: Union[str, Path] = "images",
) -> np.ndarray:
    """
    Calculate similarity scores between images and predicted coordinates.

    Embeds every ``image_*.*`` file in ``image_dir`` with the model's vision
    tower and every (lat, lon) candidate with its location tower, then
    averages the per-image image↔coordinate similarities.

    Args:
        model: Model exposing ``vision_processor``, ``vision_model``,
            ``vision_projection``/``vision_projection_else_2``,
            ``location_encoder`` and ``location_projection_else``.
            NOTE(review): the exact projection pipeline and the meaning of
            ``vision_model(image)[1]`` (presumably the pooled output) are
            model-specific — confirm against the model class.
        device: Device tensors are moved to before inference.
        predicted_coords: List of (lat, lon) tuples to score.
        image_dir: Directory containing the candidate images.

    Returns:
        np.ndarray: Average similarity per coordinate, shape
        ``(len(predicted_coords),)``.

    Raises:
        ValueError: If ``image_dir`` does not exist.
    """
    all_similarities = []
    image_dir = Path(image_dir)

    if not image_dir.exists():
        raise ValueError(f"Image directory does not exist: {image_dir}")

    for image_file in image_dir.glob("image_*.*"):
        # Load image as PIL Image first
        pil_image = Image.open(image_file).convert("RGB")

        # Process the PIL image into the model's expected pixel tensor.
        # assumes the processor yields 224x224 inputs — TODO confirm.
        image = model.vision_processor(images=pil_image, return_tensors="pt")[
            "pixel_values"
        ].reshape(-1, 224, 224)
        image = image.unsqueeze(0).to(device)

        with torch.no_grad():
            # Index [1] selects the second output of the vision model
            # (presumably the pooled embedding — verify against the model).
            vision_output = model.vision_model(image)[1]

            image_embeds = model.vision_projection_else_2(
                model.vision_projection(vision_output)
            )
            # L2-normalize so the matmul below is a cosine similarity.
            image_embeds = image_embeds / image_embeds.norm(
                p=2, dim=-1, keepdim=True
            )  # b, 768

            # Process coordinates: encode all candidates in one batch.
            gps_batch = torch.tensor(predicted_coords, dtype=torch.float32).to(device)
            gps_input = gps_batch.clone().detach().unsqueeze(0)  # Add batch dimension
            b, c, _ = gps_input.shape
            gps_input = gps_input.reshape(b * c, 2)
            location_embeds = model.location_encoder(gps_input)
            location_embeds = model.location_projection_else(
                location_embeds.reshape(b * c, -1)
            )
            location_embeds = location_embeds / location_embeds.norm(
                p=2, dim=-1, keepdim=True
            )
            location_embeds = location_embeds.reshape(b, c, -1)  # b, c, 768

            # Cosine similarity of this image against every candidate coordinate.
            similarity = torch.matmul(
                image_embeds.unsqueeze(1), location_embeds.permute(0, 2, 1)
            )  # b, 1, c
            similarity = similarity.squeeze(1).cpu().detach().numpy()
            all_similarities.append(similarity[0])  # Remove batch dimension

    # Calculate average similarity across all images
    avg_similarities = np.mean(all_similarities, axis=0)
    return avg_similarities
147
+
148
+
149
def is_retryable_error(error: Exception) -> bool:
    """
    Determines if the given exception is retryable based on known patterns
    and exception types.

    Args:
        error (Exception): The exception to evaluate.

    Returns:
        bool: True if the error is considered retryable.
    """
    error_str = str(error).lower()

    # Known substrings that indicate retryable errors
    retryable_patterns = [
        "503",
        "500",
        "502",
        "504",
        "overloaded",
        "unavailable",
        "internal",
        "disconnected",
        "connection",
        "timeout",
        "remoteprotocolerror",
        "remote protocol error",
        "network",
        "socket",
        "ssl",
        "tls",
        "rate limit",
        "too many requests",
        "429",
        "service unavailable",
        "temporarily unavailable",
    ]

    if any(pattern in error_str for pattern in retryable_patterns):
        return True

    # Retryable exception type-name fragments.
    # BUG FIX: the previous exact-name match missed concrete subclasses —
    # e.g. the builtin TimeoutError ("timeouterror") never matched the
    # listed "timeout" entry. Matching by substring covers TimeoutError,
    # ConnectTimeout, ReadTimeout, RemoteProtocolError, etc., while still
    # accepting every type the old set accepted.
    retryable_type_fragments = (
        "connectionerror",
        "timeout",
        "httperror",
        "remoteclosederror",
        "remoteprotocolerror",
        "sslerror",
        "tlserror",
        "valueerror",
    )

    error_type = type(error).__name__.lower()
    return any(fragment in error_type for fragment in retryable_type_fragments)
205
+
206
+
207
async def handle_async_api_call_with_retry(
    api_call_func: Callable[[], Any],
    max_retries: int = 10,
    base_delay: float = 2.0,
    fallback_result: Optional[T] = None,
    error_context: str = "API call",
) -> T:
    """
    Run an asynchronous API call, retrying transient failures with
    exponential backoff.

    Args:
        api_call_func (Callable): An async function that returns any type (T).
        max_retries (int): Maximum retry attempts.
        base_delay (float): Initial delay for backoff (doubles each retry).
        fallback_result (Optional[T]): Optional result to return on failure.
        error_context (str): Contextual info for logging.

    Returns:
        T: Result from the API call or fallback.
    """
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            return await api_call_func()
        except Exception as error:
            logger.warning(
                f"{error_context} failed (attempt {attempt}/{max_retries}): {error}"
            )

            retryable = is_retryable_error(error)
            last_attempt = attempt == max_retries

            if retryable and not last_attempt:
                # Exponential backoff: base_delay * 2^(attempt - 1).
                delay = base_delay * (2 ** (attempt - 1))
                logger.info(f"Retrying in {delay:.1f}s...")
                await asyncio.sleep(delay)
                continue

            if not retryable:
                logger.error(f"Non-retryable error encountered: {error}")
            else:
                logger.error(f"Max retries reached for {error_context}. Giving up.")
            break

    # All attempts exhausted (or a non-retryable error): fall back if possible.
    if fallback_result is not None:
        logger.warning(f"Returning fallback result for {error_context}")
        return fallback_result

    logger.error(f"No fallback result provided for {error_context}.")
    raise RuntimeError(f"{error_context} failed with no result.")
260
+
261
def extract_and_parse_json(raw_text: str) -> Dict[str, Any]:
    """
    Extract and parse the first JSON object found in raw_text.
    Only returns a dict; falls back to {} on failure or if parsed value isn't a dict.

    Args:
        raw_text (str): Raw text (e.g., from an LLM response)

    Returns:
        Dict[str, Any]: Parsed JSON dict, or {} if none valid is found.
    """
    start = raw_text.find("{")
    end = raw_text.rfind("}")

    if start == -1 or end == -1 or end <= start:
        # BUG FIX: these logger calls previously passed a positional argument
        # with no %s placeholder in the message, which makes the logging
        # formatter fail instead of logging the detail.
        logger.error("⚠️ No JSON object found. Snippet: %s", raw_text[:200])
        return {}

    # Take the outermost {...} span and attempt to parse it.
    snippet = raw_text[start : end + 1]

    try:
        parsed = json.loads(snippet)
        if isinstance(parsed, dict):
            return parsed
        logger.error(
            "⚠️ JSON parsed but not a dict—got type: %s", type(parsed).__name__
        )
    except json.JSONDecodeError as e:
        logger.error("⚠️ JSON decoding error: %s", e)

    return {}
290
+
291
+
292
def image_to_base64(image_path: Path) -> str:
    """Return the contents of *image_path* encoded as a base64 string.

    Logs an error and returns an empty string when the file does not exist.
    """
    if not image_path.is_file():
        logger.error(f"No such image: {image_path}")
        return ""
    return base64.b64encode(image_path.read_bytes()).decode("utf-8")
298
+
299
+
300
def load_images_as_base64() -> Optional[list[str]]:
    """Load every prompt image as a base64 string, in deterministic order.

    Scans ``data/prompt_data/images`` next to this module.

    Returns:
        Base64 strings for each supported image, sorted by file name, or
        None when the directory is missing, empty, or holds no supported
        image files.
    """
    img_dir = Path(__file__).parent / "data" / "prompt_data" / "images"

    if not img_dir.exists() or not any(img_dir.iterdir()):
        return None

    base64_images: list[str] = []
    # BUG FIX: Path.iterdir() yields entries in arbitrary filesystem order,
    # but downstream prompts reference images by numeric index
    # (image_000.jpg, image_001.jpg, ...). Sort so the list order matches
    # the numbered file names.
    for file in sorted(img_dir.iterdir()):
        # NOTE(review): the image downloader also saves .webp/.heic files,
        # which this suffix filter skips — confirm whether that is intended.
        if file.is_file() and file.suffix.lower() in [".png", ".jpg", ".jpeg", ".gif"]:
            base64_images.append(base64.b64encode(file.read_bytes()).decode("utf-8"))
    return base64_images if base64_images else None