Spaces:
Sleeping
Sleeping
Commit ·
eff2be4
0
Parent(s):
init prj
Browse files- .gitignore +18 -0
- Dockerfile +19 -0
- LICENSE +201 -0
- README.md +166 -0
- app.py +141 -0
- docker-compose.yml +19 -0
- entrypoint.sh +24 -0
- openapi.json +232 -0
- requirements.txt +434 -0
- src/data_processor.py +488 -0
- src/g3/G3.py +134 -0
- src/g3/dataset.py +407 -0
- src/g3/hparams.yaml +57 -0
- src/g3/locationencoder.py +133 -0
- src/g3/nn/mlp.py +20 -0
- src/g3/nn/rff_mlp.py +38 -0
- src/g3/nn/siren.py +100 -0
- src/g3/pe/projection.py +54 -0
- src/g3/pe/projection_rff.py +65 -0
- src/g3/pe/spherical_harmonics.py +40 -0
- src/g3/pe/spherical_harmonics_closed_form.py +40 -0
- src/g3/pe/spherical_harmonics_generate_ylms.py +73 -0
- src/g3/pe/spherical_harmonics_ylm.py +0 -0
- src/g3/rff/functional.py +77 -0
- src/g3/rff/layers.py +86 -0
- src/g3_batch_prediction.py +568 -0
- src/prompt/__init__.py +6 -0
- src/prompt/factory.py +418 -0
- src/prompt/fetch/content_fetch.py +171 -0
- src/prompt/fetch/satellite_fetch.py +87 -0
- src/prompt/preprocess/keyframe_extract.py +195 -0
- src/prompt/preprocess/video_transcribe.py +132 -0
- src/prompt/search/image_search.py +527 -0
- src/prompt/search/index_search.py +271 -0
- src/prompt/search/text_search.py +271 -0
- src/prompt/template.py +107 -0
- src/setup.py +39 -0
- src/utils.py +312 -0
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Large files, to be downloaded via huggingface.
|
| 2 |
+
g3/index/G3.index
|
| 3 |
+
g3/checkpoints/mercator_finetune_weight.pth
|
| 4 |
+
g3/data/mp16/MP16_Pro_filtered.csv
|
| 5 |
+
index
|
| 6 |
+
checkpoints
|
| 7 |
+
data
|
| 8 |
+
|
| 9 |
+
# venv and dev stuff
|
| 10 |
+
linuxenv
|
| 11 |
+
myenv
|
| 12 |
+
.venv
|
| 13 |
+
.env
|
| 14 |
+
acmmm2025-grand-challenge-gg-credentials.json
|
| 15 |
+
cred.json
|
| 16 |
+
**/__pycache__/
|
| 17 |
+
pyproject.toml
|
| 18 |
+
uv.lock
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Runtime image for the G3 geolocation FastAPI service.
FROM python:3.12-slim-bullseye
WORKDIR /code

# ffmpeg: video keyframe extraction / transcription preprocessing.
# xvfb: virtual X display for the headful browser started in entrypoint.sh.
RUN apt-get update && apt-get install -y ffmpeg xvfb

# Install Python dependencies first so this layer is cached across src changes.
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
RUN playwright install chrome
RUN playwright install-deps
# NOTE(review): `playwright install` (no args) installs ALL default browsers on
# top of the chrome channel above — presumably only one of these three RUNs is
# needed; confirm and drop the others to shrink the image.
RUN playwright install

COPY ./src /code/src
# One-time setup (model/index downloads etc.) baked into the image.
RUN python /code/src/setup.py

COPY ./app.py /code/app.py
COPY ./entrypoint.sh /code/entrypoint.sh

RUN chmod +x /code/entrypoint.sh
ENTRYPOINT [ "/code/entrypoint.sh" ]
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# G3 Geolocation Service
|
| 2 |
+
|
| 3 |
+
This is a containerized geolocation service based on the paper "G3: An Effective and Adaptive Framework for Worldwide Geolocalization Using Large Multi-Modality Models". The service is augmented with multilayer verification for location and evidence.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
- Docker with GPU support
|
| 8 |
+
- NVIDIA Container Toolkit (for GPU access)
|
| 9 |
+
- Required API keys (see Environment Variables section)
|
| 10 |
+
|
| 11 |
+
## Quick Start
|
| 12 |
+
|
| 13 |
+
### 1. Prepare Environment File
|
| 14 |
+
|
| 15 |
+
Create a `.env` file with the following variables:
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
GOOGLE_CLOUD_API_KEY=your_google_cloud_api_key
|
| 19 |
+
GOOGLE_CSE_CX=your_google_custom_search_engine_id
|
| 20 |
+
SCRAPINGDOG_API_KEY=your_scrapingdog_api_key
|
| 21 |
+
IMGBB_API_KEY=your_imgbb_api_key
|
| 22 |
+
GOOGLE_APPLICATION_CREDENTIALS=/code/path/to/your/credentials.json
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### 2. Prepare Google Cloud Credentials
|
| 26 |
+
|
| 27 |
+
Ensure you have a Google Cloud service account JSON credentials file ready for copying to the container.
|
| 28 |
+
|
| 29 |
+
### 3. Build Docker Image
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
docker build -t g3-geolocation .
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### 4. Create Docker Container
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
docker create --name g3-container -p 80:80 --gpus=all --env-file .env g3-geolocation
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### 5. Copy Credentials to Container
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
docker cp /path/to/your/credentials.json g3-container:/code/
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### 6. Start Container
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
docker start g3-container
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Usage
|
| 54 |
+
|
| 55 |
+
Once the container is running, the service will be available at `http://localhost:80`.
|
| 56 |
+
|
| 57 |
+
### API Endpoints
|
| 58 |
+
|
| 59 |
+
- **POST** `/g3/predict` - Submit images/videos for geolocation prediction
|
| 60 |
+
- **GET** `/g3/openapi` - Get OpenAPI specification
|
| 61 |
+
|
| 62 |
+
### Example Request
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
curl -X POST "http://localhost:80/g3/predict" \
|
| 66 |
+
-H "Content-Type: multipart/form-data" \
|
| 67 |
+
-F "files=@your_image.jpg"
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Environment Variables
|
| 71 |
+
|
| 72 |
+
| Variable | Description | Required |
|
| 73 |
+
| -------------------------------- | ------------------------------------------ | -------- |
|
| 74 |
+
| `GOOGLE_CLOUD_API_KEY` | Google Cloud API key for Gemini and Custom Google Search API | Yes |
|
| 75 |
+
| `GOOGLE_CSE_CX` | Google Custom Search Engine ID | Yes |
|
| 76 |
+
| `SCRAPINGDOG_API_KEY` | ScrapingDog API key for web scraping | Yes |
|
| 77 |
+
| `IMGBB_API_KEY` | ImgBB API key for image hosting | Yes |
|
| 78 |
+
| `GOOGLE_APPLICATION_CREDENTIALS` | Path to Google Cloud credentials JSON file | Yes |
|
| 79 |
+
|
| 80 |
+
## API Keys Setup
|
| 81 |
+
|
| 82 |
+
### Google Cloud API Key
|
| 83 |
+
|
| 84 |
+
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
| 85 |
+
2. Enable Gemini API and Vision API
|
| 86 |
+
3. Create an API key in the Credentials section
|
| 87 |
+
|
| 88 |
+
### Google Custom Search Engine
|
| 89 |
+
|
| 90 |
+
1. Go to [Google Custom Search](https://cse.google.com/)
|
| 91 |
+
2. Create a new search engine
|
| 92 |
+
3. Copy the Search Engine ID (CX)
|
| 93 |
+
|
| 94 |
+
### ScrapingDog API Key
|
| 95 |
+
|
| 96 |
+
1. Sign up at [ScrapingDog](https://scrapingdog.com/)
|
| 97 |
+
2. Get your API key from the dashboard
|
| 98 |
+
|
| 99 |
+
### ImgBB API Key
|
| 100 |
+
|
| 101 |
+
1. Sign up at [ImgBB](https://imgbb.com/)
|
| 102 |
+
2. Get your API key from the API section
|
| 103 |
+
|
| 104 |
+
## Container Management
|
| 105 |
+
|
| 106 |
+
### View Logs
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
docker logs g3-container
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
### Stop Container
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
docker stop g3-container
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### Remove Container
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
docker rm g3-container
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### Remove Image
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
docker rmi g3-geolocation
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
## Troubleshooting
|
| 131 |
+
|
| 132 |
+
### GPU Access Issues
|
| 133 |
+
|
| 134 |
+
Ensure NVIDIA Container Toolkit is properly installed:
|
| 135 |
+
|
| 136 |
+
```bash
|
| 137 |
+
nvidia-smi
|
| 138 |
+
docker run --rm --gpus all nvidia/cuda:12.4.1-base-ubuntu22.04 nvidia-smi
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
### API Key Issues
|
| 142 |
+
|
| 143 |
+
- Verify all API keys are valid and have proper permissions
|
| 144 |
+
- Check that the credentials file is properly copied to the container
|
| 145 |
+
- Ensure the `GOOGLE_APPLICATION_CREDENTIALS` path matches the copied file location
|
| 146 |
+
|
| 147 |
+
### Memory Issues
|
| 148 |
+
|
| 149 |
+
If you encounter out-of-memory errors, consider:
|
| 150 |
+
|
| 151 |
+
- Reducing image sizes before upload
|
| 152 |
+
- Using a machine with more RAM/VRAM
|
| 153 |
+
- Adjusting batch processing parameters
|
| 154 |
+
|
| 155 |
+
## Citation
|
| 156 |
+
|
| 157 |
+
```bib
|
| 158 |
+
@article{jia2024g3,
|
| 159 |
+
title={G3: an effective and adaptive framework for worldwide geolocalization using large multi-modality models},
|
| 160 |
+
author={Jia, Pengyue and Liu, Yiding and Li, Xiaopeng and Zhao, Xiangyu and Wang, Yuhao and Du, Yantong and Han, Xiao and Wei, Xuetao and Wang, Shuaiqiang and Yin, Dawei},
|
| 161 |
+
journal={Advances in Neural Information Processing Systems},
|
| 162 |
+
volume={37},
|
| 163 |
+
pages={53198--53221},
|
| 164 |
+
year={2024}
|
| 165 |
+
}
|
| 166 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import uuid
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
from typing import Annotated, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from fastapi import FastAPI, File, HTTPException, UploadFile, status
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
from src.g3_batch_prediction import G3BatchPredictor
|
| 14 |
+
|
| 15 |
+
from src.utils import load_images_as_base64
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EvidenceResponse(BaseModel):
    """One piece of supporting evidence attached to a location prediction."""

    # Free-form explanation backing the predicted location.
    analysis: str = Field(
        description="A supporting analysis for the prediction.",
    )
    # Each item is either a URL or a base64-encoded JPEG.
    references: list[str] = Field(
        default=[],
        description="Links or base64-encoded JPEG supporting the analysis.",
    )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LocationPredictionResponse(BaseModel):
    """Predicted geolocation together with its supporting evidence."""

    latitude: float = Field(
        description="Latitude of the predicted location, in degree.",
    )
    longitude: float = Field(
        description="Longitude of the predicted location, in degree.",
    )
    location: str = Field(
        description="Textual description of the predicted location.",
    )
    evidence: list[EvidenceResponse] = Field(
        description="List of supporting analyses for the prediction.",
    )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class PredictionResponse(BaseModel):
    """Top-level response body returned by the /g3/predict endpoint."""

    prediction: Annotated[
        LocationPredictionResponse,
        Field(description="The location prediction and accompanying analysis."),
    ]
    transcript: Annotated[
        str | None,
        Field(description="The extracted and concatenated transcripts, if any."),
    ] = None
    # Declared in the same Annotated style as the sibling fields (the original
    # mixed `Optional[...] = Field(default=...)` into an otherwise Annotated model).
    media: Annotated[
        list[str] | None,
        Field(description="List of media files processed during prediction."),
    ] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
predictor: G3BatchPredictor
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load env config, export the OpenAPI schema to
    disk, and create/tear down the process-wide predictor instance."""
    load_dotenv()

    # Persist the generated schema next to the app; explicit UTF-8 so the
    # output does not depend on the platform's default encoding.
    with open("openapi.json", "wt", encoding="utf-8") as api_file:
        json.dump(app.openapi(), api_file, indent=4)

    # Single shared predictor; the heavy model load happens once at startup.
    global predictor
    predictor = G3BatchPredictor(device="cuda" if torch.cuda.is_available() else "cpu")

    yield

    # Release model resources (e.g. GPU memory) on shutdown.
    del predictor
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# FastAPI application object; the lifespan hook owns predictor setup/teardown.
app = FastAPI(
    lifespan=lifespan,
    title="G3",
    description=(
        "An endpoint to predict GPS coordinate from static image,"
        " using G3 Framework."
    ),
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@app.post(
    "/g3/predict",
    description="Provide location prediction.",
)
async def predict_endpoint(
    files: Annotated[
        list[UploadFile],
        File(description="Input images, videos and metadata json."),
    ],
) -> PredictionResponse:
    """Run the G3 pipeline on the uploaded media and return a prediction.

    Saves every uploaded file into the predictor's input directory, invokes
    the batch predictor, then collects the transcript and media artifacts
    produced during the run.

    Raises:
        HTTPException: 500 if any uploaded file cannot be written to disk.
    """
    # Write files to disk
    try:
        predictor.clear_directories()
        # Create the input directory once, not per file.
        os.makedirs(predictor.input_dir, exist_ok=True)
        for file in files:
            filename = file.filename if file.filename is not None else uuid.uuid4().hex
            # Keep only the base name so a crafted filename such as "../x"
            # cannot escape the input directory (path traversal).
            filepath = predictor.input_dir / os.path.basename(filename)
            with open(filepath, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to save file: {e}",
        ) from e

    # Get prediction
    response = await predictor.predict(model_name="gemini-2.5-pro")
    prediction = LocationPredictionResponse(
        latitude=response.latitude,
        longitude=response.longitude,
        location=response.location,
        evidence=[
            EvidenceResponse(analysis=ev.analysis, references=ev.references)
            for ev in response.evidence
        ],
    )
    # Get transcript if available
    transcript = predictor.get_transcript()

    # Get media files if available
    images_b64 = load_images_as_base64()

    return PredictionResponse(prediction=prediction, transcript=transcript, media=images_b64)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@app.get(
    "/g3/openapi",
    description="Provide the OpenAPI JSON describing this service's endpoints.",
)
async def openapi():
    """Return the generated OpenAPI schema for this service as JSON."""
    return app.openapi()
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
  web:
    build: .
    ports:
      # Host 8000 -> container 80 (the entrypoint runs FastAPI on port 80).
      - "8000:80"
    environment:
      - GOOGLE_APPLICATION_CREDENTIALS=/code/keys/credentials.json
    volumes:
      # NOTE(review): .env is both mounted here and loaded via env_file below —
      # presumably only one mechanism is needed; confirm which the app reads.
      - ./.env:/code/.env

      - ./keys:/code/keys

      # Allows editing the entrypoint without rebuilding the image.
      - ./entrypoint.sh:/code/entrypoint.sh

    env_file:
      - ./.env

    restart: unless-stopped
    # Requires the NVIDIA Container Toolkit on the host for GPU access.
    runtime: nvidia
|
entrypoint.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Container entrypoint: clean up stale Xvfb state left by an unclean shutdown,
# start a virtual X display (needed for headful browser automation), then
# exec the FastAPI server as PID 1.
set -euo pipefail

# --- cleanup any stale Xvfb lock/socket ---
if [ -e /tmp/.X99-lock ]; then
echo "[entrypoint] removing stale /tmp/.X99-lock" >&2
rm -f /tmp/.X99-lock
fi
if [ -e /tmp/.X11-unix/X99 ]; then
echo "[entrypoint] removing stale /tmp/.X11-unix/X99" >&2
rm -f /tmp/.X11-unix/X99
fi

# --- start the virtual display ---
echo "[entrypoint] starting Xvfb on :99" >&2
# NOTE(review): Xvfb is backgrounded with no readiness wait — presumably it is
# up before the first browser launch; confirm, or poll the X socket if flaky.
Xvfb :99 -screen 0 1920x1080x24 &

# --- point GUI apps at it ---
export DISPLAY=:99
echo "[entrypoint] DISPLAY set to $DISPLAY" >&2

# --- launch FastAPI ---
# exec replaces this shell so the server receives container signals directly.
echo "[entrypoint] exec fastapi" >&2
exec fastapi run app.py --port 80
|
openapi.json
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"openapi": "3.1.0",
|
| 3 |
+
"info": {
|
| 4 |
+
"title": "G3",
|
| 5 |
+
"description": "An endpoint to predict GPS coordinate from static image, using G3 Framework.",
|
| 6 |
+
"version": "0.1.0"
|
| 7 |
+
},
|
| 8 |
+
"paths": {
|
| 9 |
+
"/g3/predict": {
|
| 10 |
+
"post": {
|
| 11 |
+
"summary": "Predict Endpoint",
|
| 12 |
+
"description": "Provide location prediction.",
|
| 13 |
+
"operationId": "predict_endpoint_g3_predict_post",
|
| 14 |
+
"requestBody": {
|
| 15 |
+
"content": {
|
| 16 |
+
"multipart/form-data": {
|
| 17 |
+
"schema": {
|
| 18 |
+
"$ref": "#/components/schemas/Body_predict_endpoint_g3_predict_post"
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"required": true
|
| 23 |
+
},
|
| 24 |
+
"responses": {
|
| 25 |
+
"200": {
|
| 26 |
+
"description": "Successful Response",
|
| 27 |
+
"content": {
|
| 28 |
+
"application/json": {
|
| 29 |
+
"schema": {
|
| 30 |
+
"$ref": "#/components/schemas/PredictionResponse"
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
},
|
| 35 |
+
"422": {
|
| 36 |
+
"description": "Validation Error",
|
| 37 |
+
"content": {
|
| 38 |
+
"application/json": {
|
| 39 |
+
"schema": {
|
| 40 |
+
"$ref": "#/components/schemas/HTTPValidationError"
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
},
|
| 48 |
+
"/g3/openapi": {
|
| 49 |
+
"get": {
|
| 50 |
+
"summary": "Openapi",
|
| 51 |
+
"description": "Provide the OpenAPI JSON describing this service's endpoints.",
|
| 52 |
+
"operationId": "openapi_g3_openapi_get",
|
| 53 |
+
"responses": {
|
| 54 |
+
"200": {
|
| 55 |
+
"description": "Successful Response",
|
| 56 |
+
"content": {
|
| 57 |
+
"application/json": {
|
| 58 |
+
"schema": {}
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"components": {
|
| 67 |
+
"schemas": {
|
| 68 |
+
"Body_predict_endpoint_g3_predict_post": {
|
| 69 |
+
"properties": {
|
| 70 |
+
"files": {
|
| 71 |
+
"items": {
|
| 72 |
+
"type": "string",
|
| 73 |
+
"format": "binary"
|
| 74 |
+
},
|
| 75 |
+
"type": "array",
|
| 76 |
+
"title": "Files",
|
| 77 |
+
"description": "Input images, videos and metadata json."
|
| 78 |
+
}
|
| 79 |
+
},
|
| 80 |
+
"type": "object",
|
| 81 |
+
"required": [
|
| 82 |
+
"files"
|
| 83 |
+
],
|
| 84 |
+
"title": "Body_predict_endpoint_g3_predict_post"
|
| 85 |
+
},
|
| 86 |
+
"EvidenceResponse": {
|
| 87 |
+
"properties": {
|
| 88 |
+
"analysis": {
|
| 89 |
+
"type": "string",
|
| 90 |
+
"title": "Analysis",
|
| 91 |
+
"description": "A supporting analysis for the prediction."
|
| 92 |
+
},
|
| 93 |
+
"references": {
|
| 94 |
+
"items": {
|
| 95 |
+
"type": "string"
|
| 96 |
+
},
|
| 97 |
+
"type": "array",
|
| 98 |
+
"title": "References",
|
| 99 |
+
"description": "Links or base64-encoded JPEG supporting the analysis.",
|
| 100 |
+
"default": []
|
| 101 |
+
}
|
| 102 |
+
},
|
| 103 |
+
"type": "object",
|
| 104 |
+
"required": [
|
| 105 |
+
"analysis"
|
| 106 |
+
],
|
| 107 |
+
"title": "EvidenceResponse"
|
| 108 |
+
},
|
| 109 |
+
"HTTPValidationError": {
|
| 110 |
+
"properties": {
|
| 111 |
+
"detail": {
|
| 112 |
+
"items": {
|
| 113 |
+
"$ref": "#/components/schemas/ValidationError"
|
| 114 |
+
},
|
| 115 |
+
"type": "array",
|
| 116 |
+
"title": "Detail"
|
| 117 |
+
}
|
| 118 |
+
},
|
| 119 |
+
"type": "object",
|
| 120 |
+
"title": "HTTPValidationError"
|
| 121 |
+
},
|
| 122 |
+
"LocationPredictionResponse": {
|
| 123 |
+
"properties": {
|
| 124 |
+
"latitude": {
|
| 125 |
+
"type": "number",
|
| 126 |
+
"title": "Latitude",
|
| 127 |
+
"description": "Latitude of the predicted location, in degree."
|
| 128 |
+
},
|
| 129 |
+
"longitude": {
|
| 130 |
+
"type": "number",
|
| 131 |
+
"title": "Longitude",
|
| 132 |
+
"description": "Longitude of the predicted location, in degree."
|
| 133 |
+
},
|
| 134 |
+
"location": {
|
| 135 |
+
"type": "string",
|
| 136 |
+
"title": "Location",
|
| 137 |
+
"description": "Textual description of the predicted location."
|
| 138 |
+
},
|
| 139 |
+
"evidence": {
|
| 140 |
+
"items": {
|
| 141 |
+
"$ref": "#/components/schemas/EvidenceResponse"
|
| 142 |
+
},
|
| 143 |
+
"type": "array",
|
| 144 |
+
"title": "Evidence",
|
| 145 |
+
"description": "List of supporting analyses for the prediction."
|
| 146 |
+
}
|
| 147 |
+
},
|
| 148 |
+
"type": "object",
|
| 149 |
+
"required": [
|
| 150 |
+
"latitude",
|
| 151 |
+
"longitude",
|
| 152 |
+
"location",
|
| 153 |
+
"evidence"
|
| 154 |
+
],
|
| 155 |
+
"title": "LocationPredictionResponse"
|
| 156 |
+
},
|
| 157 |
+
"PredictionResponse": {
|
| 158 |
+
"properties": {
|
| 159 |
+
"prediction": {
|
| 160 |
+
"$ref": "#/components/schemas/LocationPredictionResponse",
|
| 161 |
+
"description": "The location prediction and accompanying analysis."
|
| 162 |
+
},
|
| 163 |
+
"transcript": {
|
| 164 |
+
"anyOf": [
|
| 165 |
+
{
|
| 166 |
+
"type": "string"
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"type": "null"
|
| 170 |
+
}
|
| 171 |
+
],
|
| 172 |
+
"title": "Transcript",
|
| 173 |
+
"description": "The extracted and concatenated transcripts, if any."
|
| 174 |
+
},
|
| 175 |
+
"media": {
|
| 176 |
+
"anyOf": [
|
| 177 |
+
{
|
| 178 |
+
"items": {
|
| 179 |
+
"type": "string"
|
| 180 |
+
},
|
| 181 |
+
"type": "array"
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"type": "null"
|
| 185 |
+
}
|
| 186 |
+
],
|
| 187 |
+
"title": "Media",
|
| 188 |
+
"description": "List of media files processed during prediction."
|
| 189 |
+
}
|
| 190 |
+
},
|
| 191 |
+
"type": "object",
|
| 192 |
+
"required": [
|
| 193 |
+
"prediction"
|
| 194 |
+
],
|
| 195 |
+
"title": "PredictionResponse"
|
| 196 |
+
},
|
| 197 |
+
"ValidationError": {
|
| 198 |
+
"properties": {
|
| 199 |
+
"loc": {
|
| 200 |
+
"items": {
|
| 201 |
+
"anyOf": [
|
| 202 |
+
{
|
| 203 |
+
"type": "string"
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"type": "integer"
|
| 207 |
+
}
|
| 208 |
+
]
|
| 209 |
+
},
|
| 210 |
+
"type": "array",
|
| 211 |
+
"title": "Location"
|
| 212 |
+
},
|
| 213 |
+
"msg": {
|
| 214 |
+
"type": "string",
|
| 215 |
+
"title": "Message"
|
| 216 |
+
},
|
| 217 |
+
"type": {
|
| 218 |
+
"type": "string",
|
| 219 |
+
"title": "Error Type"
|
| 220 |
+
}
|
| 221 |
+
},
|
| 222 |
+
"type": "object",
|
| 223 |
+
"required": [
|
| 224 |
+
"loc",
|
| 225 |
+
"msg",
|
| 226 |
+
"type"
|
| 227 |
+
],
|
| 228 |
+
"title": "ValidationError"
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv export --no-hashes --format requirements-txt
|
| 3 |
+
annotated-types==0.7.0
|
| 4 |
+
# via pydantic
|
| 5 |
+
anyio==4.9.0
|
| 6 |
+
# via
|
| 7 |
+
# google-genai
|
| 8 |
+
# httpx
|
| 9 |
+
# starlette
|
| 10 |
+
# watchfiles
|
| 11 |
+
cachetools==5.5.2
|
| 12 |
+
# via google-auth
|
| 13 |
+
certifi==2025.7.14
|
| 14 |
+
# via
|
| 15 |
+
# httpcore
|
| 16 |
+
# httpx
|
| 17 |
+
# pyproj
|
| 18 |
+
# requests
|
| 19 |
+
# sentry-sdk
|
| 20 |
+
charset-normalizer==3.4.2
|
| 21 |
+
# via requests
|
| 22 |
+
click==8.2.1
|
| 23 |
+
# via
|
| 24 |
+
# rich-toolkit
|
| 25 |
+
# typer
|
| 26 |
+
# uvicorn
|
| 27 |
+
colorama==0.4.6 ; sys_platform == 'win32'
|
| 28 |
+
# via
|
| 29 |
+
# click
|
| 30 |
+
# tqdm
|
| 31 |
+
# uvicorn
|
| 32 |
+
decorator==5.2.1
|
| 33 |
+
# via moviepy
|
| 34 |
+
dnspython==2.7.0
|
| 35 |
+
# via email-validator
|
| 36 |
+
einops==0.8.1
|
| 37 |
+
# via acmmm25-grand-challenge-geolocation
|
| 38 |
+
email-validator==2.2.0
|
| 39 |
+
# via
|
| 40 |
+
# fastapi
|
| 41 |
+
# pydantic
|
| 42 |
+
faiss-gpu-cu12==1.11.0
|
| 43 |
+
# via acmmm25-grand-challenge-geolocation
|
| 44 |
+
fastapi==0.116.1
|
| 45 |
+
# via acmmm25-grand-challenge-geolocation
|
| 46 |
+
fastapi-cli==0.0.8
|
| 47 |
+
# via fastapi
|
| 48 |
+
fastapi-cloud-cli==0.1.4
|
| 49 |
+
# via fastapi-cli
|
| 50 |
+
ffmpy==0.6.0
|
| 51 |
+
# via katna
|
| 52 |
+
filelock==3.18.0
|
| 53 |
+
# via
|
| 54 |
+
# huggingface-hub
|
| 55 |
+
# torch
|
| 56 |
+
# transformers
|
| 57 |
+
fsspec==2025.7.0
|
| 58 |
+
# via
|
| 59 |
+
# huggingface-hub
|
| 60 |
+
# torch
|
| 61 |
+
ftfy==6.3.1
|
| 62 |
+
# via open-clip-torch
|
| 63 |
+
geographiclib==2.0
|
| 64 |
+
# via geopy
|
| 65 |
+
geopy==2.4.1
|
| 66 |
+
# via acmmm25-grand-challenge-geolocation
|
| 67 |
+
google-api-core==2.25.1
|
| 68 |
+
# via
|
| 69 |
+
# google-cloud-videointelligence
|
| 70 |
+
# google-cloud-vision
|
| 71 |
+
google-auth==2.40.3
|
| 72 |
+
# via
|
| 73 |
+
# google-api-core
|
| 74 |
+
# google-cloud-videointelligence
|
| 75 |
+
# google-cloud-vision
|
| 76 |
+
# google-genai
|
| 77 |
+
google-cloud-videointelligence==2.16.2
|
| 78 |
+
# via acmmm25-grand-challenge-geolocation
|
| 79 |
+
google-cloud-vision==3.10.2
|
| 80 |
+
# via acmmm25-grand-challenge-geolocation
|
| 81 |
+
google-genai==1.26.0
|
| 82 |
+
# via acmmm25-grand-challenge-geolocation
|
| 83 |
+
googleapis-common-protos==1.70.0
|
| 84 |
+
# via
|
| 85 |
+
# google-api-core
|
| 86 |
+
# grpcio-status
|
| 87 |
+
greenlet==3.2.3
|
| 88 |
+
# via playwright
|
| 89 |
+
grpcio==1.73.1
|
| 90 |
+
# via
|
| 91 |
+
# google-api-core
|
| 92 |
+
# grpcio-status
|
| 93 |
+
grpcio-status==1.73.1
|
| 94 |
+
# via google-api-core
|
| 95 |
+
h11==0.16.0
|
| 96 |
+
# via
|
| 97 |
+
# httpcore
|
| 98 |
+
# uvicorn
|
| 99 |
+
hf-xet==1.1.5 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 100 |
+
# via huggingface-hub
|
| 101 |
+
httpcore==1.0.9
|
| 102 |
+
# via httpx
|
| 103 |
+
httptools==0.6.4
|
| 104 |
+
# via uvicorn
|
| 105 |
+
httpx==0.28.1
|
| 106 |
+
# via
|
| 107 |
+
# fastapi
|
| 108 |
+
# fastapi-cloud-cli
|
| 109 |
+
# google-genai
|
| 110 |
+
huggingface-hub==0.33.4
|
| 111 |
+
# via
|
| 112 |
+
# open-clip-torch
|
| 113 |
+
# timm
|
| 114 |
+
# tokenizers
|
| 115 |
+
# transformers
|
| 116 |
+
idna==3.10
|
| 117 |
+
# via
|
| 118 |
+
# anyio
|
| 119 |
+
# email-validator
|
| 120 |
+
# httpx
|
| 121 |
+
# requests
|
| 122 |
+
imageio==2.37.0
|
| 123 |
+
# via
|
| 124 |
+
# moviepy
|
| 125 |
+
# scikit-image
|
| 126 |
+
imageio-ffmpeg==0.6.0
|
| 127 |
+
# via
|
| 128 |
+
# katna
|
| 129 |
+
# moviepy
|
| 130 |
+
imutils==0.5.4
|
| 131 |
+
# via katna
|
| 132 |
+
jinja2==3.1.6
|
| 133 |
+
# via
|
| 134 |
+
# fastapi
|
| 135 |
+
# torch
|
| 136 |
+
joblib==1.5.1
|
| 137 |
+
# via scikit-learn
|
| 138 |
+
katna==0.9.2
|
| 139 |
+
# via acmmm25-grand-challenge-geolocation
|
| 140 |
+
lazy-loader==0.4
|
| 141 |
+
# via scikit-image
|
| 142 |
+
llvmlite==0.44.0
|
| 143 |
+
# via numba
|
| 144 |
+
markdown-it-py==3.0.0
|
| 145 |
+
# via rich
|
| 146 |
+
markupsafe==3.0.2
|
| 147 |
+
# via jinja2
|
| 148 |
+
mdurl==0.1.2
|
| 149 |
+
# via markdown-it-py
|
| 150 |
+
more-itertools==10.7.0
|
| 151 |
+
# via openai-whisper
|
| 152 |
+
moviepy==2.2.1
|
| 153 |
+
# via acmmm25-grand-challenge-geolocation
|
| 154 |
+
mpmath==1.3.0
|
| 155 |
+
# via sympy
|
| 156 |
+
networkx==3.5
|
| 157 |
+
# via
|
| 158 |
+
# scikit-image
|
| 159 |
+
# torch
|
| 160 |
+
numba==0.61.2
|
| 161 |
+
# via openai-whisper
|
| 162 |
+
numpy==1.26.4
|
| 163 |
+
# via
|
| 164 |
+
# faiss-gpu-cu12
|
| 165 |
+
# imageio
|
| 166 |
+
# katna
|
| 167 |
+
# moviepy
|
| 168 |
+
# numba
|
| 169 |
+
# openai-whisper
|
| 170 |
+
# opencv-contrib-python
|
| 171 |
+
# opencv-python
|
| 172 |
+
# pandas
|
| 173 |
+
# scikit-image
|
| 174 |
+
# scikit-learn
|
| 175 |
+
# scipy
|
| 176 |
+
# tifffile
|
| 177 |
+
# torchvision
|
| 178 |
+
# transformers
|
| 179 |
+
nvidia-cublas-cu12==12.6.4.1
|
| 180 |
+
# via
|
| 181 |
+
# faiss-gpu-cu12
|
| 182 |
+
# nvidia-cudnn-cu12
|
| 183 |
+
# nvidia-cusolver-cu12
|
| 184 |
+
# torch
|
| 185 |
+
nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 186 |
+
# via torch
|
| 187 |
+
nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 188 |
+
# via torch
|
| 189 |
+
nvidia-cuda-runtime-cu12==12.6.77
|
| 190 |
+
# via
|
| 191 |
+
# faiss-gpu-cu12
|
| 192 |
+
# torch
|
| 193 |
+
nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 194 |
+
# via torch
|
| 195 |
+
nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 196 |
+
# via torch
|
| 197 |
+
nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 198 |
+
# via torch
|
| 199 |
+
nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 200 |
+
# via torch
|
| 201 |
+
nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 202 |
+
# via torch
|
| 203 |
+
nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 204 |
+
# via
|
| 205 |
+
# nvidia-cusolver-cu12
|
| 206 |
+
# torch
|
| 207 |
+
nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 208 |
+
# via torch
|
| 209 |
+
nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 210 |
+
# via torch
|
| 211 |
+
nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 212 |
+
# via
|
| 213 |
+
# nvidia-cufft-cu12
|
| 214 |
+
# nvidia-cusolver-cu12
|
| 215 |
+
# nvidia-cusparse-cu12
|
| 216 |
+
# torch
|
| 217 |
+
nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 218 |
+
# via torch
|
| 219 |
+
open-clip-torch==2.32.0
|
| 220 |
+
# via acmmm25-grand-challenge-geolocation
|
| 221 |
+
openai-whisper==20250625
|
| 222 |
+
# via acmmm25-grand-challenge-geolocation
|
| 223 |
+
opencv-contrib-python==4.11.0.86
|
| 224 |
+
# via katna
|
| 225 |
+
opencv-python==4.11.0.86
|
| 226 |
+
# via acmmm25-grand-challenge-geolocation
|
| 227 |
+
packaging==25.0
|
| 228 |
+
# via
|
| 229 |
+
# faiss-gpu-cu12
|
| 230 |
+
# huggingface-hub
|
| 231 |
+
# lazy-loader
|
| 232 |
+
# scikit-image
|
| 233 |
+
# transformers
|
| 234 |
+
pandas==2.3.1
|
| 235 |
+
# via acmmm25-grand-challenge-geolocation
|
| 236 |
+
pillow==11.3.0
|
| 237 |
+
# via
|
| 238 |
+
# acmmm25-grand-challenge-geolocation
|
| 239 |
+
# imageio
|
| 240 |
+
# moviepy
|
| 241 |
+
# scikit-image
|
| 242 |
+
# torchvision
|
| 243 |
+
playwright==1.53.0
|
| 244 |
+
# via acmmm25-grand-challenge-geolocation
|
| 245 |
+
proglog==0.1.12
|
| 246 |
+
# via moviepy
|
| 247 |
+
proto-plus==1.26.1
|
| 248 |
+
# via
|
| 249 |
+
# google-api-core
|
| 250 |
+
# google-cloud-videointelligence
|
| 251 |
+
# google-cloud-vision
|
| 252 |
+
protobuf==6.31.1
|
| 253 |
+
# via
|
| 254 |
+
# google-api-core
|
| 255 |
+
# google-cloud-videointelligence
|
| 256 |
+
# google-cloud-vision
|
| 257 |
+
# googleapis-common-protos
|
| 258 |
+
# grpcio-status
|
| 259 |
+
# proto-plus
|
| 260 |
+
psutil==7.0.0
|
| 261 |
+
# via katna
|
| 262 |
+
pyasn1==0.6.1
|
| 263 |
+
# via
|
| 264 |
+
# pyasn1-modules
|
| 265 |
+
# rsa
|
| 266 |
+
pyasn1-modules==0.4.2
|
| 267 |
+
# via google-auth
|
| 268 |
+
pydantic==2.11.7
|
| 269 |
+
# via
|
| 270 |
+
# fastapi
|
| 271 |
+
# fastapi-cloud-cli
|
| 272 |
+
# google-genai
|
| 273 |
+
pydantic-core==2.33.2
|
| 274 |
+
# via pydantic
|
| 275 |
+
pyee==13.0.0
|
| 276 |
+
# via playwright
|
| 277 |
+
pygments==2.19.2
|
| 278 |
+
# via rich
|
| 279 |
+
pyproj==3.7.1
|
| 280 |
+
# via acmmm25-grand-challenge-geolocation
|
| 281 |
+
python-dateutil==2.9.0.post0
|
| 282 |
+
# via pandas
|
| 283 |
+
python-dotenv==1.1.1
|
| 284 |
+
# via
|
| 285 |
+
# acmmm25-grand-challenge-geolocation
|
| 286 |
+
# moviepy
|
| 287 |
+
# uvicorn
|
| 288 |
+
python-multipart==0.0.20
|
| 289 |
+
# via fastapi
|
| 290 |
+
pytz==2025.2
|
| 291 |
+
# via pandas
|
| 292 |
+
pyyaml==6.0.2
|
| 293 |
+
# via
|
| 294 |
+
# acmmm25-grand-challenge-geolocation
|
| 295 |
+
# huggingface-hub
|
| 296 |
+
# timm
|
| 297 |
+
# transformers
|
| 298 |
+
# uvicorn
|
| 299 |
+
regex==2024.11.6
|
| 300 |
+
# via
|
| 301 |
+
# open-clip-torch
|
| 302 |
+
# tiktoken
|
| 303 |
+
# transformers
|
| 304 |
+
requests==2.32.4
|
| 305 |
+
# via
|
| 306 |
+
# google-api-core
|
| 307 |
+
# google-genai
|
| 308 |
+
# huggingface-hub
|
| 309 |
+
# katna
|
| 310 |
+
# tiktoken
|
| 311 |
+
# transformers
|
| 312 |
+
rich==14.0.0
|
| 313 |
+
# via
|
| 314 |
+
# rich-toolkit
|
| 315 |
+
# typer
|
| 316 |
+
rich-toolkit==0.14.8
|
| 317 |
+
# via
|
| 318 |
+
# fastapi-cli
|
| 319 |
+
# fastapi-cloud-cli
|
| 320 |
+
rignore==0.6.2
|
| 321 |
+
# via fastapi-cloud-cli
|
| 322 |
+
rsa==4.9.1
|
| 323 |
+
# via google-auth
|
| 324 |
+
safetensors==0.5.3
|
| 325 |
+
# via
|
| 326 |
+
# open-clip-torch
|
| 327 |
+
# timm
|
| 328 |
+
# transformers
|
| 329 |
+
scikit-image==0.25.2
|
| 330 |
+
# via katna
|
| 331 |
+
scikit-learn==1.7.0
|
| 332 |
+
# via
|
| 333 |
+
# acmmm25-grand-challenge-geolocation
|
| 334 |
+
# katna
|
| 335 |
+
scipy==1.16.0
|
| 336 |
+
# via
|
| 337 |
+
# katna
|
| 338 |
+
# scikit-image
|
| 339 |
+
# scikit-learn
|
| 340 |
+
sentry-sdk==2.33.0
|
| 341 |
+
# via fastapi-cloud-cli
|
| 342 |
+
setuptools==80.9.0
|
| 343 |
+
# via
|
| 344 |
+
# torch
|
| 345 |
+
# triton
|
| 346 |
+
shellingham==1.5.4
|
| 347 |
+
# via typer
|
| 348 |
+
six==1.17.0
|
| 349 |
+
# via python-dateutil
|
| 350 |
+
sniffio==1.3.1
|
| 351 |
+
# via anyio
|
| 352 |
+
starlette==0.47.1
|
| 353 |
+
# via fastapi
|
| 354 |
+
sympy==1.14.0
|
| 355 |
+
# via torch
|
| 356 |
+
tenacity==8.5.0
|
| 357 |
+
# via google-genai
|
| 358 |
+
threadpoolctl==3.6.0
|
| 359 |
+
# via scikit-learn
|
| 360 |
+
tifffile==2025.6.11
|
| 361 |
+
# via scikit-image
|
| 362 |
+
tiktoken==0.9.0
|
| 363 |
+
# via openai-whisper
|
| 364 |
+
timm==1.0.17
|
| 365 |
+
# via open-clip-torch
|
| 366 |
+
tokenizers==0.21.2
|
| 367 |
+
# via transformers
|
| 368 |
+
torch==2.7.1
|
| 369 |
+
# via
|
| 370 |
+
# acmmm25-grand-challenge-geolocation
|
| 371 |
+
# open-clip-torch
|
| 372 |
+
# openai-whisper
|
| 373 |
+
# timm
|
| 374 |
+
# torchvision
|
| 375 |
+
torchvision==0.22.1
|
| 376 |
+
# via
|
| 377 |
+
# acmmm25-grand-challenge-geolocation
|
| 378 |
+
# open-clip-torch
|
| 379 |
+
# timm
|
| 380 |
+
tqdm==4.67.1
|
| 381 |
+
# via
|
| 382 |
+
# acmmm25-grand-challenge-geolocation
|
| 383 |
+
# huggingface-hub
|
| 384 |
+
# open-clip-torch
|
| 385 |
+
# openai-whisper
|
| 386 |
+
# proglog
|
| 387 |
+
# transformers
|
| 388 |
+
transformers==4.53.2
|
| 389 |
+
# via acmmm25-grand-challenge-geolocation
|
| 390 |
+
triton==3.3.1 ; (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'linux2'
|
| 391 |
+
# via
|
| 392 |
+
# openai-whisper
|
| 393 |
+
# torch
|
| 394 |
+
typer==0.16.0
|
| 395 |
+
# via
|
| 396 |
+
# fastapi-cli
|
| 397 |
+
# fastapi-cloud-cli
|
| 398 |
+
typing-extensions==4.14.1
|
| 399 |
+
# via
|
| 400 |
+
# anyio
|
| 401 |
+
# fastapi
|
| 402 |
+
# google-genai
|
| 403 |
+
# huggingface-hub
|
| 404 |
+
# pydantic
|
| 405 |
+
# pydantic-core
|
| 406 |
+
# pyee
|
| 407 |
+
# rich-toolkit
|
| 408 |
+
# starlette
|
| 409 |
+
# torch
|
| 410 |
+
# typer
|
| 411 |
+
# typing-inspection
|
| 412 |
+
typing-inspection==0.4.1
|
| 413 |
+
# via pydantic
|
| 414 |
+
tzdata==2025.2
|
| 415 |
+
# via pandas
|
| 416 |
+
urllib3==2.5.0
|
| 417 |
+
# via
|
| 418 |
+
# requests
|
| 419 |
+
# sentry-sdk
|
| 420 |
+
uvicorn==0.35.0
|
| 421 |
+
# via
|
| 422 |
+
# fastapi
|
| 423 |
+
# fastapi-cli
|
| 424 |
+
# fastapi-cloud-cli
|
| 425 |
+
uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
|
| 426 |
+
# via uvicorn
|
| 427 |
+
watchfiles==1.1.0
|
| 428 |
+
# via uvicorn
|
| 429 |
+
wcwidth==0.2.13
|
| 430 |
+
# via ftfy
|
| 431 |
+
websockets==15.0.1
|
| 432 |
+
# via
|
| 433 |
+
# google-genai
|
| 434 |
+
# uvicorn
|
src/data_processor.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import hashlib
|
| 6 |
+
import shutil
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import faiss
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from .prompt.fetch.content_fetch import fetch_links_to_json
|
| 15 |
+
from .prompt.fetch.satellite_fetch import fetch_satellite_image
|
| 16 |
+
from .prompt.preprocess.keyframe_extract import extract_and_save_keyframes
|
| 17 |
+
from .prompt.preprocess.video_transcribe import transcribe_video_directory
|
| 18 |
+
from .prompt.search.image_search import image_search_directory
|
| 19 |
+
from .prompt.search.index_search import save_results_to_json, search_index_directory
|
| 20 |
+
from .prompt.search.text_search import text_search_image, text_search_link
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger("uvicorn.error")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DataProcessor:
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
model: nn.Module,
|
| 29 |
+
input_dir: Path,
|
| 30 |
+
prompt_dir: Path,
|
| 31 |
+
cache_dir: Path,
|
| 32 |
+
image_dir: Path,
|
| 33 |
+
audio_dir: Path,
|
| 34 |
+
index_path: Path,
|
| 35 |
+
database_csv_path: Path,
|
| 36 |
+
device: torch.device,
|
| 37 |
+
):
|
| 38 |
+
self.input_dir = input_dir
|
| 39 |
+
self.prompt_dir = prompt_dir
|
| 40 |
+
self.cache_dir = cache_dir
|
| 41 |
+
self.image_dir = image_dir
|
| 42 |
+
self.audio_dir = audio_dir
|
| 43 |
+
self.model = model
|
| 44 |
+
self.device = device
|
| 45 |
+
self.database_csv_path = database_csv_path
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
self.index = faiss.read_index(str(index_path))
|
| 49 |
+
logger.info(f"✅ Successfully loaded FAISS index from: {index_path}")
|
| 50 |
+
except Exception as e:
|
| 51 |
+
raise RuntimeError(f"Failed to load FAISS index from {index_path}: {e}")
|
| 52 |
+
|
| 53 |
+
self.image_extension = {
|
| 54 |
+
".jpg",
|
| 55 |
+
".jpeg",
|
| 56 |
+
".png",
|
| 57 |
+
".bmp",
|
| 58 |
+
".tiff",
|
| 59 |
+
".tif",
|
| 60 |
+
".webp",
|
| 61 |
+
}
|
| 62 |
+
self.video_extension = {
|
| 63 |
+
".mp4",
|
| 64 |
+
".avi",
|
| 65 |
+
".mov",
|
| 66 |
+
".mkv",
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
def __extract_keyframes(self):
|
| 70 |
+
"""
|
| 71 |
+
Extract keyframes from all videos in the input directory.
|
| 72 |
+
Put all images and keyframes into the prompt directory.
|
| 73 |
+
"""
|
| 74 |
+
output_dir = self.image_dir
|
| 75 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 76 |
+
|
| 77 |
+
# Determine starting index based on existing files
|
| 78 |
+
current_files = list(output_dir.glob("image_*.*"))
|
| 79 |
+
idx = len(current_files)
|
| 80 |
+
|
| 81 |
+
# Process images
|
| 82 |
+
for file_name in os.listdir(self.input_dir):
|
| 83 |
+
file_path = os.path.join(self.input_dir, file_name)
|
| 84 |
+
if os.path.isfile(file_path) and file_name.lower().endswith(
|
| 85 |
+
tuple(self.image_extension)
|
| 86 |
+
):
|
| 87 |
+
out_path = output_dir / f"image_{idx:03d}.jpg"
|
| 88 |
+
Image.open(file_path).convert("RGB").save(out_path)
|
| 89 |
+
idx += 1
|
| 90 |
+
|
| 91 |
+
# Process videos
|
| 92 |
+
for file_name in os.listdir(self.input_dir):
|
| 93 |
+
file_path = os.path.join(self.input_dir, file_name)
|
| 94 |
+
if os.path.isfile(file_path) and file_name.lower().endswith(
|
| 95 |
+
tuple(self.video_extension)
|
| 96 |
+
):
|
| 97 |
+
if idx is None:
|
| 98 |
+
idx = 0
|
| 99 |
+
idx = extract_and_save_keyframes(
|
| 100 |
+
video_path=file_path, output_dir=str(output_dir), start_index=idx
|
| 101 |
+
)
|
| 102 |
+
logger.info(f"✅ Extracted keyframes and images to: {output_dir}")
|
| 103 |
+
|
| 104 |
+
def __transcribe_videos(self):
|
| 105 |
+
"""
|
| 106 |
+
Transcribe all videos in the input directory.
|
| 107 |
+
Save transcripts into the prompt directory.
|
| 108 |
+
"""
|
| 109 |
+
audio_dir = self.audio_dir
|
| 110 |
+
os.makedirs(audio_dir, exist_ok=True)
|
| 111 |
+
|
| 112 |
+
if audio_dir.is_dir() and any(audio_dir.iterdir()):
|
| 113 |
+
logger.info(f"🔄 Found existing transcripts in directory: {audio_dir}")
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
transcribe_video_directory(
|
| 117 |
+
video_dir=str(self.input_dir),
|
| 118 |
+
output_dir=str(audio_dir),
|
| 119 |
+
model_name="base", # Use the base Whisper model for transcription
|
| 120 |
+
)
|
| 121 |
+
logger.info(f"✅ Successfully transcribed videos to: {audio_dir}")
|
| 122 |
+
|
| 123 |
+
def __image_search(self):
|
| 124 |
+
"""
|
| 125 |
+
Perform image search on all images in the input directory.
|
| 126 |
+
Save search results into the prompt directory.
|
| 127 |
+
"""
|
| 128 |
+
image_dir = self.image_dir
|
| 129 |
+
|
| 130 |
+
if os.environ["IMGBB_API_KEY"] is None:
|
| 131 |
+
raise ValueError(
|
| 132 |
+
"IMGBB_API_KEY environment variable is not set or is None."
|
| 133 |
+
)
|
| 134 |
+
if os.environ["SCRAPINGDOG_API_KEY"] is None:
|
| 135 |
+
raise ValueError(
|
| 136 |
+
"SCRAPINGDOG_API_KEY environment variable is not set or is None."
|
| 137 |
+
)
|
| 138 |
+
image_search_directory(
|
| 139 |
+
directory=str(image_dir),
|
| 140 |
+
output_dir=str(self.prompt_dir),
|
| 141 |
+
filename="metadata.json",
|
| 142 |
+
imgbb_key=os.environ["IMGBB_API_KEY"],
|
| 143 |
+
scrapingdog_key=os.environ["SCRAPINGDOG_API_KEY"],
|
| 144 |
+
max_workers=4,
|
| 145 |
+
target_links=20,
|
| 146 |
+
)
|
| 147 |
+
logger.info(f"✅ Successfully performed image search on: {image_dir}")
|
| 148 |
+
|
| 149 |
+
def __text_search(self):
    """
    Perform a text search built from the extracted metadata and save the
    resulting links to ``text_search.json`` in the prompt directory.

    The query is assembled from the ``description`` and ``location``
    fields of ``metadata.json``; when that file is absent, an empty
    query is used.

    Raises:
        ValueError: If GOOGLE_CLOUD_API_KEY or GOOGLE_CSE_CX is not set.
    """
    # Fail fast with a clear message instead of a bare KeyError
    # (consistent with the env checks in prepare_location_images).
    api_key = os.environ.get("GOOGLE_CLOUD_API_KEY")
    if not api_key:
        raise ValueError(
            "GOOGLE_CLOUD_API_KEY environment variable is not set or is None."
        )
    cx = os.environ.get("GOOGLE_CSE_CX")
    if not cx:
        raise ValueError(
            "GOOGLE_CSE_CX environment variable is not set or is None."
        )

    query = ""
    metadata_file = self.prompt_dir / "metadata.json"
    if metadata_file.exists():
        with open(metadata_file, "r") as f:
            metadata = json.load(f)
        description = metadata.get("description", "")
        location = metadata.get("location", "")
        # NOTE(review): when both fields are empty this yields the query
        # "in" after strip(); confirm whether "" would be preferable.
        query = f"{description} in {location}".strip()

    text_search_link(
        query=query,
        output_dir=str(self.prompt_dir),
        filename="text_search.json",
        num_results=10,
        api_key=api_key,
        cx=cx,
    )
|
| 172 |
+
|
| 173 |
+
async def __fetch_related_link_content(
    self, image_prediction: bool = True, text_prediction: bool = True
):
    """
    Fetch the page content behind the image-search and text-search links
    found in the prompt directory, writing the results to
    ``image_search_content.json`` and ``text_search_content.json``.

    If the corresponding search-result file (``metadata.json`` /
    ``text_search.json``) does not exist yet, the respective search is
    run first.

    Args:
        image_prediction: When True, process links from the image search.
        text_prediction: When True, process links from the text search.
    """

    # Helper: fetch all given links and persist them to one JSON file.
    # Does nothing when the link set is empty.
    async def fetch_and_save_links(links, output_filename):
        if links:
            await fetch_links_to_json(
                links=list(links),
                output_path=str(self.prompt_dir / output_filename),
                max_content_length=5000,
            )
            logger.info(
                f"Fetched content for {len(links)} links into {output_filename}"
            )

    # Image links
    image_links = set()
    image_search_file = self.prompt_dir / "metadata.json"
    if image_prediction:
        # Lazily run the image search if its output is missing.
        if not image_search_file.exists():
            self.__image_search()
        with open(image_search_file, "r") as f:
            image_search_data = json.load(f)
        image_links.update(image_search_data.get("all_links", []))
        logger.info(f"Found {len(image_links)} image links to fetch content from.")
        await fetch_and_save_links(image_links, "image_search_content.json")

    # Text links
    text_links = set()
    text_search_file = self.prompt_dir / "text_search.json"
    if text_prediction:
        # Lazily run the text search if its output is missing.
        if not text_search_file.exists():
            self.__text_search()
        with open(text_search_file, "r") as f:
            text_search_data = json.load(f)
        # filter(None, ...) drops empty/None link entries.
        text_links.update(filter(None, text_search_data.get("links", [])))
        logger.info(f"Found {len(text_links)} text links to fetch content from.")
        await fetch_and_save_links(text_links, "text_search_content.json")

    if not image_links and not text_links:
        logger.info("No links found in image or text search results.")
|
| 217 |
+
|
| 218 |
+
def __index_search(self):
    """Run a FAISS nearest-neighbour search over the prompt images.

    Results (candidate GPS points and reverse-geocoded info) are written
    to ``index_search.json`` in the prompt directory. The search is
    skipped entirely when that file already exists.

    Raises:
        RuntimeError: If no FAISS index has been loaded.
        FileNotFoundError: If the GPS database CSV is missing.
    """
    if not self.index:
        raise RuntimeError("FAISS index is not loaded. Cannot perform index search.")

    output_path = self.prompt_dir / "index_search.json"
    if output_path.exists():
        # Previous results are reused verbatim.
        logger.info(
            f"Index search results already exist at {output_path}, skipping search."
        )
        return

    if not os.path.exists(self.database_csv_path):
        raise FileNotFoundError(f"Database CSV file not found: {self.database_csv_path}")

    search_results = search_index_directory(
        model=self.model,
        device=self.device,
        index=self.index,
        image_dir=str(self.image_dir),
        database_csv_path=str(self.database_csv_path),
        top_k=20,
        max_elements=20,
    )
    candidates_gps, reverse_gps = search_results

    save_results_to_json(candidates_gps, reverse_gps, str(output_path))
    logger.info(
        f"✅ Successfully performed index search. Results saved to: {output_path}"
    )
|
| 254 |
+
|
| 255 |
+
async def __fetch_satellite_image_async(
    self,
    latitude: float,
    longitude: float,
    zoom: int,
    output_path: Path,
) -> None:
    """Fetch a satellite image without blocking the event loop.

    The synchronous ``fetch_satellite_image`` helper is executed on a
    worker thread via :func:`asyncio.to_thread`.

    Args:
        latitude: Latitude of the location.
        longitude: Longitude of the location.
        zoom: Zoom level of the satellite image.
        output_path: Destination file for the downloaded image.
    """
    await asyncio.to_thread(
        fetch_satellite_image, latitude, longitude, zoom, str(output_path)
    )
|
| 280 |
+
|
| 281 |
+
async def __search_images_async(
    self,
    location: str,
    num_images: int,
    api_key: str | None,
    cse_cx: str | None,
    output_dir: Path,
    image_id_offset: int,
) -> list[str]:
    """Run a text-based image search on a worker thread.

    Wraps the synchronous ``text_search_image`` in
    :func:`asyncio.to_thread` so the event loop stays responsive.

    Args:
        location: Free-text location query.
        num_images: Number of images to download.
        api_key: Google Cloud API key.
        cse_cx: Google Custom Search Engine ID.
        output_dir: Directory that receives the downloaded images.
        image_id_offset: Starting index used for the image filenames.

    Returns:
        Whatever ``text_search_image`` returns for the query.
    """
    return await asyncio.to_thread(
        text_search_image,
        location,
        num_images,
        api_key,
        cse_cx,
        str(output_dir),
        image_id_offset,
    )
|
| 313 |
+
|
| 314 |
+
def __compute_sha256(self, filepath: Path) -> str:
    """Return the hex SHA-256 digest of *filepath*'s contents.

    Raises:
        ValueError: If *filepath* is not an existing regular file.
    """
    if not filepath.is_file():
        raise ValueError(f"File does not exist: {filepath}")

    digest = hashlib.sha256()
    with open(filepath, "rb") as stream:
        # Read in fixed-size chunks so large files never load fully into memory.
        while chunk := stream.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
|
| 326 |
+
|
| 327 |
+
def __compare_directories(self, dir1: Path, dir2: Path) -> bool:
    """Check whether two directories hold identical top-level files.

    Only regular files directly inside each directory are considered
    (subdirectories are not recursed into). The directories match when
    their file-name sets coincide and every shared file has the same
    SHA-256 digest.

    Args:
        dir1: First directory to compare.
        dir2: Second directory to compare.

    Returns:
        True when both paths are directories whose top-level files are
        identical in name and content, False otherwise.
    """
    if not (dir1.is_dir() and dir2.is_dir()):
        return False

    names1 = {entry.name for entry in dir1.iterdir() if entry.is_file()}
    names2 = {entry.name for entry in dir2.iterdir() if entry.is_file()}
    if names1 != names2:
        return False

    for filename in names1:
        left = dir1 / filename
        right = dir2 / filename

        # Entries that are no longer regular files are ignored.
        if not (left.is_file() and right.is_file()):
            continue

        if self.__compute_sha256(left) != self.__compute_sha256(right):
            return False  # Found mismatch
    return True  # All matching files are identical
|
| 363 |
+
|
| 364 |
+
def __copy_directory(self, src: Path, dest: Path):
    """
    Recursively copy all files from src to dest, wiping dest first.

    Args:
        src (Path): Source directory (must exist).
        dest (Path): Destination directory; its previous contents are
            deleted before copying.

    Raises:
        ValueError: If ``src`` is not a directory.
    """
    if not src.is_dir():
        raise ValueError(f"Source path is not a directory: {src}")

    # Delete everything in dest first
    if dest.exists():
        for item in dest.iterdir():
            if item.is_file() or item.is_symlink():
                item.unlink()
            elif item.is_dir():
                shutil.rmtree(item)

    # Ensure dest exists
    dest.mkdir(parents=True, exist_ok=True)

    for item in src.iterdir():
        if item.is_dir():
            # Recurse into subdirectories (each recursion wipes its own dest).
            self.__copy_directory(item, dest / item.name)
        else:
            dest_file = dest / item.name
            # NOTE(review): dest was just emptied, so dest_file.exists() is
            # normally False here; also __compare_directories returns False
            # for two file paths (its first guard requires directories), so
            # this condition effectively always copies. Confirm whether a
            # per-file hash comparison was intended instead.
            if not dest_file.exists() or not self.__compare_directories(
                item, dest_file
            ):
                shutil.copy2(item, dest_file)
|
| 391 |
+
|
| 392 |
+
async def preprocess_input_data(
    self,
    image_prediction: bool = True,
    text_prediction: bool = True,
):
    """
    Preprocess all input data:
    - Extract keyframes from videos.
    - Transcribe videos.
    - Fetch related link content from images and text.
    - Run the FAISS index search.
    Save images and extracted keyframes into the output directory.

    A cache of both the input and the produced prompt data is kept under
    ``self.cache_dir``; when the current input matches the cached input
    (top-level file hashes), the cached prompt data is restored and all
    processing is skipped.

    Args:
        image_prediction (bool): Forwarded to the link-content fetch;
            enables the image-search branch.
        text_prediction (bool): Forwarded to the link-content fetch;
            enables the text-search branch.
    """
    os.makedirs(self.prompt_dir, exist_ok=True)
    os.makedirs(self.cache_dir, exist_ok=True)

    cache_dir_input = self.cache_dir / "input_data"
    cache_dir_prompt = self.cache_dir / "prompt_data"
    # Cache hit: identical input → reuse previously computed prompt data.
    if self.__compare_directories(self.input_dir, cache_dir_input):
        logger.info("Input data already processed, skipping...")
        self.__copy_directory(cache_dir_prompt, self.prompt_dir)
        return
    else:
        logger.info("Processing input data...")

    # Seed prompt metadata from the FIRST .json file found in the input
    # directory (iteration order of os.listdir is arbitrary — assumes at
    # most one metadata JSON per input; TODO confirm).
    metadata_dest = self.prompt_dir / "metadata.json"
    if not metadata_dest.exists():
        for file in os.listdir(self.input_dir):
            if file.endswith(".json"):
                file_path = os.path.join(self.input_dir, file)
                with open(file_path, "r") as src_file:
                    with open(metadata_dest, "w") as dest_file:
                        dest_file.write(src_file.read())
                break

    # Pipeline stages; each one skips itself when its output exists.
    self.__extract_keyframes()
    self.__transcribe_videos()
    await self.__fetch_related_link_content(
        image_prediction=image_prediction, text_prediction=text_prediction
    )
    self.__index_search()

    logger.info("✅ Preprocessing completed")
    logger.info(f"Saving processed data to cache directory: {self.cache_dir}")
    self.__copy_directory(self.input_dir, cache_dir_input)
    self.__copy_directory(self.prompt_dir, cache_dir_prompt)
|
| 437 |
+
|
| 438 |
+
async def prepare_location_images(
    self,
    prediction: dict,
    image_prediction: bool = True,
    text_prediction: bool = True,
) -> int:
    """
    Prepare verification data from the prediction with parallel fetching:
    a satellite image of the predicted coordinates plus web images of the
    predicted location, all saved into ``self.image_dir`` using
    sequential ``image_NNN`` filenames.

    Args:
        prediction (dict): Prediction with "latitude", "longitude" and
            "location" keys (reason/metadata keys are not read here).
        image_prediction (bool): Whether to include original images in
            verification. NOTE(review): currently unused in this body —
            confirm intended.
        text_prediction (bool): Whether to include text-based
            verification. NOTE(review): currently unused in this body —
            confirm intended.

    Returns:
        int: Satellite image ID for reference in prompts.

    Raises:
        ValueError: If GOOGLE_CLOUD_API_KEY or GOOGLE_CSE_CX is unset.
    """
    image_dir = self.image_dir
    # Next free index after the existing image_* files.
    satellite_image_id = len(list(self.image_dir.glob("image_*.*")))

    # Execute both operations in parallel
    logger.info("🔄 Fetching satellite image and location images in parallel...")

    # Ensure required API keys are present
    if not os.environ.get("GOOGLE_CLOUD_API_KEY"):
        raise ValueError(
            "GOOGLE_CLOUD_API_KEY environment variable is not set or is None."
        )
    if not os.environ.get("GOOGLE_CSE_CX"):
        raise ValueError(
            "GOOGLE_CSE_CX environment variable is not set or is None."
        )

    await asyncio.gather(
        self.__fetch_satellite_image_async(
            prediction["latitude"],
            prediction["longitude"],
            # NOTE(review): zoom=200 is far beyond typical web-map zoom
            # levels — confirm the units expected by fetch_satellite_image.
            zoom=200,
            output_path=image_dir / f"image_{satellite_image_id:03d}.jpg",
        ),
        self.__search_images_async(
            location=prediction["location"],
            num_images=5,
            api_key=os.environ["GOOGLE_CLOUD_API_KEY"],
            cse_cx=os.environ["GOOGLE_CSE_CX"],
            output_dir=image_dir,
            # Web images are numbered after the satellite image.
            image_id_offset=satellite_image_id + 1,
        ),
    )
    logger.info("✅ Verification data preparation completed")
    return satellite_image_id
|
src/g3/G3.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import cast
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
|
| 5 |
+
|
| 6 |
+
from .locationencoder import LocationEncoder
|
| 7 |
+
|
| 8 |
+
class G3(torch.nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
device: str,
|
| 12 |
+
positional_encoding_type: str = "sh",
|
| 13 |
+
neural_network_type: str = "siren",
|
| 14 |
+
hparams: dict | None = None,
|
| 15 |
+
):
|
| 16 |
+
super(G3, self).__init__()
|
| 17 |
+
self.device = device
|
| 18 |
+
|
| 19 |
+
clip_model = cast(CLIPModel, CLIPModel.from_pretrained("openai/clip-vit-large-patch14"))
|
| 20 |
+
self.vision_model = clip_model.vision_model
|
| 21 |
+
self.text_model = clip_model.text_model
|
| 22 |
+
self.vision_processor = cast(CLIPImageProcessor, CLIPImageProcessor.from_pretrained(
|
| 23 |
+
"openai/clip-vit-large-patch14"
|
| 24 |
+
))
|
| 25 |
+
self.text_processor = cast(CLIPTokenizer, CLIPTokenizer.from_pretrained(
|
| 26 |
+
"openai/clip-vit-large-patch14"
|
| 27 |
+
))
|
| 28 |
+
self.vision_projection = clip_model.visual_projection
|
| 29 |
+
self.text_projection = clip_model.text_projection
|
| 30 |
+
|
| 31 |
+
self.logit_scale1 = nn.Parameter(torch.tensor(3.99))
|
| 32 |
+
self.logit_scale2 = nn.Parameter(torch.tensor(3.99))
|
| 33 |
+
self.logit_scale3 = nn.Parameter(torch.tensor(3.99))
|
| 34 |
+
|
| 35 |
+
self.location_encoder = LocationEncoder(
|
| 36 |
+
positional_encoding_type=positional_encoding_type.split("_")[0],
|
| 37 |
+
neural_network_type=neural_network_type,
|
| 38 |
+
hparams=hparams,
|
| 39 |
+
device=device,
|
| 40 |
+
) # output batch_size, 3, 512
|
| 41 |
+
self.vision_projection_else_1 = nn.Sequential(
|
| 42 |
+
nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768)
|
| 43 |
+
)
|
| 44 |
+
self.text_projection_else = nn.Sequential(
|
| 45 |
+
nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.vision_projection_else_2 = nn.Sequential(
|
| 49 |
+
nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768)
|
| 50 |
+
)
|
| 51 |
+
self.location_projection_else = nn.Sequential(
|
| 52 |
+
nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 768)
|
| 53 |
+
)
|
| 54 |
+
# output_dim = 512 if hparams is None else hparams["output_dim"]
|
| 55 |
+
# self.location_projection_else = nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Linear(output_dim, 768))
|
| 56 |
+
|
| 57 |
+
# freeze CLIP
|
| 58 |
+
self.vision_model.requires_grad_(False)
|
| 59 |
+
self.vision_projection.requires_grad_(False)
|
| 60 |
+
self.text_model.requires_grad_(False)
|
| 61 |
+
self.text_projection.requires_grad_(False)
|
| 62 |
+
|
| 63 |
+
def forward(self, images, texts, longitude, latitude):
|
| 64 |
+
vision_output = self.vision_model(images)[1]
|
| 65 |
+
text_output = self.text_model(**texts)[1]
|
| 66 |
+
image_embeds = self.vision_projection(vision_output)
|
| 67 |
+
text_embeds = self.text_projection(text_output) # batch_size, 512
|
| 68 |
+
this_batch_locations = torch.stack((latitude, longitude), dim=1)
|
| 69 |
+
location_embeds = self.location_encoder(this_batch_locations)
|
| 70 |
+
|
| 71 |
+
# phase _1
|
| 72 |
+
image_embeds_1 = self.vision_projection_else_1(image_embeds)
|
| 73 |
+
text_embeds_1 = self.text_projection_else(
|
| 74 |
+
text_embeds.reshape(text_embeds.shape[0], -1)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# normalized features
|
| 78 |
+
image_embeds_1 = image_embeds_1 / image_embeds_1.norm(p=2, dim=-1, keepdim=True)
|
| 79 |
+
text_embeds_1 = text_embeds_1 / text_embeds_1.norm(p=2, dim=-1, keepdim=True)
|
| 80 |
+
|
| 81 |
+
# image with texts
|
| 82 |
+
logit_scale = self.logit_scale1.exp()
|
| 83 |
+
logits_per_texts_with_images = (
|
| 84 |
+
torch.matmul(text_embeds_1, image_embeds_1.t()) * logit_scale
|
| 85 |
+
)
|
| 86 |
+
logits_per_images_with_texts = logits_per_texts_with_images.t()
|
| 87 |
+
loss_phase_1 = self.clip_loss(logits_per_texts_with_images)
|
| 88 |
+
|
| 89 |
+
# phase _2
|
| 90 |
+
image_embeds_2 = self.vision_projection_else_2(image_embeds)
|
| 91 |
+
location_embeds_2 = self.location_projection_else(
|
| 92 |
+
location_embeds.reshape(location_embeds.shape[0], -1)
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# normalized features
|
| 96 |
+
image_embeds_2 = image_embeds_2 / image_embeds_2.norm(p=2, dim=-1, keepdim=True)
|
| 97 |
+
location_embeds_2 = location_embeds_2 / location_embeds_2.norm(
|
| 98 |
+
p=2, dim=-1, keepdim=True
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# image with location
|
| 102 |
+
logit_scale = self.logit_scale2.exp()
|
| 103 |
+
logits_per_locations_with_images = (
|
| 104 |
+
torch.matmul(location_embeds_2, image_embeds_2.t()) * logit_scale
|
| 105 |
+
)
|
| 106 |
+
logits_per_images_with_locations = logits_per_locations_with_images.t()
|
| 107 |
+
loss_phase_2 = None
|
| 108 |
+
loss_phase_2 = self.clip_loss(logits_per_locations_with_images)
|
| 109 |
+
|
| 110 |
+
loss = loss_phase_1 + loss_phase_2
|
| 111 |
+
|
| 112 |
+
return {
|
| 113 |
+
"logits_per_texts_with_images": logits_per_texts_with_images,
|
| 114 |
+
"logits_per_images_with_texts": logits_per_images_with_texts,
|
| 115 |
+
"logits_per_locations_with_images": logits_per_locations_with_images,
|
| 116 |
+
"logits_per_images_with_locations": logits_per_images_with_locations,
|
| 117 |
+
"logits_per_locations_with_texts": None,
|
| 118 |
+
"logits_per_texts_with_locations": None,
|
| 119 |
+
"loss": loss,
|
| 120 |
+
"vision_output": vision_output,
|
| 121 |
+
"text_output": text_output,
|
| 122 |
+
"image_embeds": image_embeds,
|
| 123 |
+
"text_embeds": text_embeds,
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
def contrastive_loss(self, logits: torch.Tensor) -> torch.Tensor:
|
| 127 |
+
return nn.functional.cross_entropy(
|
| 128 |
+
logits, torch.arange(len(logits), device=logits.device)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def clip_loss(self, similarity: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
caption_loss = self.contrastive_loss(similarity)
|
| 133 |
+
image_loss = self.contrastive_loss(similarity.t())
|
| 134 |
+
return (caption_loss + image_loss) / 2.0
|
src/g3/dataset.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import tarfile
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Callable, Optional
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
import transformers
|
| 13 |
+
from PIL import Image, ImageFile
|
| 14 |
+
from torch.utils.data import DataLoader, get_worker_info
|
| 15 |
+
from torchvision.datasets import VisionDataset
|
| 16 |
+
from torchvision.io import ImageReadMode, read_image
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import (
|
| 19 |
+
CLIPImageProcessor,
|
| 20 |
+
CLIPModel,
|
| 21 |
+
CLIPTextModel,
|
| 22 |
+
CLIPTokenizer,
|
| 23 |
+
CLIPVisionModel,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow truncated images to be loaded
|
| 27 |
+
|
| 28 |
+
from io import BytesIO
|
| 29 |
+
from typing import Any, Dict, Iterator, Optional, Tuple
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torchvision.transforms as T
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from huggingface_hub import login
|
| 35 |
+
from PIL import Image
|
| 36 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
| 37 |
+
|
| 38 |
+
# Public API of this module.
# BUG FIX: the previous entry "mp16_collate" is not defined anywhere in
# this file (the collate factory is `make_mp16_collate`), so
# `from ... import *` raised AttributeError.
__all__ = [
    "MP16StreamingDataset",
    "make_mp16_collate",
]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MP16StreamingDataset(IterableDataset):
    """Stream **MP-16** samples from the HuggingFace Hub and yield a simple
    tuple per example::

        (image, text, longitude, latitude)

    * **image** – either a tensor (``C×H×W``) if *vision_processor* is set or if
      the fallback transform is used, otherwise a PIL image.
    * **text** – caption string (either provided by the dataset or generated
      from location fields).
    * **longitude**, **latitude** – floats.

    The class is an :class:`torch.utils.data.IterableDataset`, so wrap it in a
    :class:`~torch.utils.data.DataLoader` for batching.
    """

    def __init__(
        self,
        repo_id: str = "tduongvn/MP16-Pro-shards",
        split: str = "train",
        vision_processor: Optional[Any] = None,
        shuffle_buffer: int = 10_000,
        HF_TOKEN: Optional[str] = None,
    ) -> None:
        """
        Args:
            repo_id: HuggingFace Hub dataset repository to stream from.
            split: Dataset split name.
            vision_processor: Optional CLIP-style image processor; when
                set, images are returned as pixel-value tensors.
            shuffle_buffer: Buffer size for the streaming shuffle.
            HF_TOKEN: Optional Hub token used to authenticate.
        """
        super().__init__()
        self.repo_id = repo_id
        self.split = split
        self.vision_processor = vision_processor
        self.shuffle_buffer = shuffle_buffer
        self.HF_TOKEN = HF_TOKEN

        # Base transform when we *don't* have a fancy processor
        self.fallback_transform = T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomResizedCrop(size=224),
                T.ToTensor(),
            ]
        )

        # Prepare an initial dataset iterator for the main process
        self._base_iter = self._new_iterator()

    # ──────────────────────────────────────────────────────────────────────────
    # Internals ─┘

    def _new_iterator(self):
        # Fresh shuffled streaming iterator over the Hub dataset.
        # NOTE(review): login() is re-invoked on every call; harmless but
        # redundant after the first authentication.
        if self.HF_TOKEN is not None:
            login(token=self.HF_TOKEN)
        return (
            load_dataset(self.repo_id, split=self.split, streaming=True)
            .shuffle(buffer_size=self.shuffle_buffer)
            .__iter__()
        )

    def _decode_image(self, img_bytes):
        """bytes → PIL.Image or tensor (if processor is set)."""
        img = Image.open(BytesIO(img_bytes)).convert("RGB")
        if self.vision_processor is not None:
            return self.vision_processor(images=img, return_tensors="pt")[
                "pixel_values"
            ].squeeze(0)
        return self.fallback_transform(img)

    def _caption(self, ex_json: Dict[str, Any]) -> str:
        # Build a caption from whichever location fields are present.
        # NOTE(review): with no fields present this returns the bare
        # prefix "A street view photo taken in " — confirm acceptable.
        parts = [ex_json.get(k) for k in ("city", "state", "country") if ex_json.get(k)]
        return "A street view photo taken in " + ", ".join(parts)

    # ──────────────────────────────────────────────────────────────────────────
    # IterableDataset API ─┘

    def __iter__(self) -> Iterator[Tuple[Any, str, float, float]]:
        # Each DataLoader worker gets its own iterator to avoid state clashes.
        # NOTE(review): every worker opens its own iterator over the SAME
        # shuffled stream, so samples may repeat across workers — confirm
        # whether per-worker sharding is needed.
        worker = get_worker_info()
        iterator = self._new_iterator() if worker is not None else self._base_iter

        for ex in iterator:
            # Dataset structure: {'jpg': <PIL or bytes>, 'json': {...}, ...}
            img_field = ex["jpg"]
            if isinstance(img_field, Image.Image):
                img = img_field.convert("RGB")
                if self.vision_processor is not None:
                    img = self.vision_processor(images=img, return_tensors="pt")[
                        "pixel_values"
                    ].squeeze(0)
                else:
                    img = self.fallback_transform(img)
            else:  # bytes
                img = self._decode_image(img_field)

            meta = ex["json"] if "json" in ex else {}
            # NOTE(review): if neither "lon" nor "LON" exists this is
            # float(None) → TypeError; assumes every record carries
            # coordinates — confirm against the dataset schema.
            lon = float(meta.get("lon", meta.get("LON")))
            lat = float(meta.get("lat", meta.get("LAT")))
            text = meta.get("text") or self._caption(meta)

            yield img, text, lon, lat

    # No __len__ – this is a stream.
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 145 |
+
# Collate ─┘
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def make_mp16_collate(text_processor):
    """Build a DataLoader collate function for MP-16 samples.

    Args:
        text_processor: Tokenizer-like callable (e.g. a CLIP tokenizer)
            invoked with ``padding="longest"``, ``truncation=True`` and
            ``max_length=77`` to batch the captions.

    Returns:
        A ``collate(batch)`` callable that maps a list of
        ``(image, text, lon, lat)`` tuples to
        ``(images, token_out, lons, lats)``, where ``images`` is a
        stacked ``(B, C, H, W)`` tensor and ``lons``/``lats`` are
        float32 tensors.
    """

    def collate(batch):
        image_list, caption_list, lon_list, lat_list = zip(*batch)

        stacked_images = torch.stack(image_list)  # (B, C, H, W)

        token_out = text_processor(
            list(caption_list),
            padding="longest",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )

        return (
            stacked_images,
            token_out,
            torch.tensor(lon_list, dtype=torch.float32),
            torch.tensor(lat_list, dtype=torch.float32),
        )

    return collate
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class MP16Dataset(VisionDataset):
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
root_path="data/mp16/",
|
| 174 |
+
text_data_path="MP16_Pro_places365.csv",
|
| 175 |
+
image_data_path="mp-16-images.tar",
|
| 176 |
+
member_info_path="tar_index.pkl",
|
| 177 |
+
vision_processor=None,
|
| 178 |
+
text_processor=None,
|
| 179 |
+
):
|
| 180 |
+
super().__init__(self)
|
| 181 |
+
self.root_path = root_path
|
| 182 |
+
self.text_data_path = text_data_path
|
| 183 |
+
self.image_data_path = image_data_path
|
| 184 |
+
self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path))
|
| 185 |
+
self.text_data["IMG_ID"] = self.text_data["IMG_ID"].apply(
|
| 186 |
+
lambda x: x.replace("/", "_")
|
| 187 |
+
)
|
| 188 |
+
# self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images
|
| 189 |
+
print("read text data success")
|
| 190 |
+
worker = get_worker_info()
|
| 191 |
+
worker = worker.id if worker else None
|
| 192 |
+
self.tar_obj = {worker: tarfile.open(os.path.join(root_path, image_data_path))}
|
| 193 |
+
# self.tar = tarfile.open(os.path.join(root_path, image_data_path))
|
| 194 |
+
|
| 195 |
+
if os.path.exists(os.path.join(self.root_path, member_info_path)):
|
| 196 |
+
with open(os.path.join(self.root_path, member_info_path), "rb") as f:
|
| 197 |
+
self.tar_index = pickle.load(f)
|
| 198 |
+
all_image_names = list(self.tar_index.keys())
|
| 199 |
+
print("load tar index success")
|
| 200 |
+
else:
|
| 201 |
+
print("no exist tar index success, need building...")
|
| 202 |
+
self.tar_index = {}
|
| 203 |
+
all_image_names = []
|
| 204 |
+
for member in tqdm(self.tar_obj[worker]):
|
| 205 |
+
if member.name.endswith(".jpg") and member.size > 5120:
|
| 206 |
+
self.tar_index[member.name.split("/")[1]] = member
|
| 207 |
+
all_image_names.append(member.name.split("/")[1])
|
| 208 |
+
print("tar index buidling success")
|
| 209 |
+
with open(os.path.join(self.root_path, member_info_path), "wb") as f:
|
| 210 |
+
pickle.dump(self.tar_index, f)
|
| 211 |
+
all_image_names = set(all_image_names)
|
| 212 |
+
|
| 213 |
+
self.text_data = self.text_data[self.text_data["country"].notnull()]
|
| 214 |
+
self.text_data = self.text_data[self.text_data["IMG_ID"].isin(all_image_names)]
|
| 215 |
+
print("data columns: ", self.text_data.shape[0])
|
| 216 |
+
|
| 217 |
+
# location from str to float
|
| 218 |
+
self.text_data.loc[:, "LON"] = self.text_data["LON"].astype(float)
|
| 219 |
+
self.text_data.loc[:, "LAT"] = self.text_data["LAT"].astype(float)
|
| 220 |
+
print("location from str to float success")
|
| 221 |
+
|
| 222 |
+
# image transform
|
| 223 |
+
self.transform = T.Resize(size=(512, 512))
|
| 224 |
+
self.transform_totensor = T.ToTensor()
|
| 225 |
+
|
| 226 |
+
self.vision_processor = vision_processor
|
| 227 |
+
self.text_processor = text_processor
|
| 228 |
+
|
| 229 |
+
# Define the contrast transforms here
|
| 230 |
+
self.contrast_transforms = T.Compose(
|
| 231 |
+
[
|
| 232 |
+
T.RandomHorizontalFlip(),
|
| 233 |
+
T.RandomResizedCrop(size=224),
|
| 234 |
+
T.RandomApply(
|
| 235 |
+
[
|
| 236 |
+
T.ColorJitter(
|
| 237 |
+
brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1
|
| 238 |
+
)
|
| 239 |
+
],
|
| 240 |
+
p=0.8,
|
| 241 |
+
),
|
| 242 |
+
T.RandomGrayscale(p=0.2),
|
| 243 |
+
T.GaussianBlur(kernel_size=9),
|
| 244 |
+
T.ToTensor(),
|
| 245 |
+
# T.Normalize((0.5,), (0.5,))
|
| 246 |
+
]
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# self.text_data.to_csv('/data/mp-16/MP16_Pro_filtered.csv', index=False)
|
| 250 |
+
|
| 251 |
+
    def caption_generation(self, row):
        """Placeholder hook for generating a caption from a metadata row.

        Not implemented; the caption text is currently built inline in
        __getitem__ from the row's city/state/country columns.
        """
        pass
|
| 253 |
+
|
| 254 |
+
    def __getitem__(self, index):
        """Return one training sample: (image, caption text, longitude, latitude).

        The image is read directly out of the dataset tar archive via a
        per-worker tar handle, so nothing is extracted to disk.
        """
        image_path = self.text_data.iloc[index]["IMG_ID"]
        text = ""
        # Administrative labels for this row; several may be NaN.
        neighbourhood, city, county, state, region, country, continent = (
            self.text_data.iloc[index][
                [
                    "neighbourhood",
                    "city",
                    "county",
                    "state",
                    "region",
                    "country",
                    "continent",
                ]
            ]
        )
        # location_elements = [element for element in [neighbourhood, city, state, country] if element is not np.nan and str(element) != 'nan']
        # Keep only the non-NaN parts at city/state/country granularity.
        location_elements = [
            element
            for element in [city, state, country]
            if element is not np.nan and str(element) != "nan"
        ]
        text = "A street view photo taken in " + ", ".join(location_elements)

        longitude = self.text_data.iloc[index]["LON"]
        latitude = self.text_data.iloc[index]["LAT"]
        # read the image from self.tar
        # Each DataLoader worker keeps its own tar handle: tarfile objects
        # are not safe to share across worker processes.
        worker = get_worker_info()
        worker = worker.id if worker else None
        if worker not in self.tar_obj:
            self.tar_obj[worker] = tarfile.open(
                os.path.join(self.root_path, self.image_data_path)
            )
        image = self.tar_obj[worker].extractfile(self.tar_index[image_path])
        image = Image.open(image)

        if image.mode != "RGB":
            image = image.convert("RGB")

        if self.vision_processor:
            # Processor output is (1, 3, 224, 224); drop the leading batch
            # dim so the DataLoader can stack samples itself.
            # NOTE(review): assumes the processor resizes to 224x224 — confirm
            # against the vision_processor passed in by the caller.
            image = self.vision_processor(images=image, return_tensors="pt")[
                "pixel_values"
            ].reshape(3, 224, 224)

        return image, text, longitude, latitude
|
| 299 |
+
|
| 300 |
+
def __len__(self):
|
| 301 |
+
return len(self.text_data)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class im2gps3kDataset(VisionDataset):
    """im2gps3k evaluation dataset: images on disk plus a CSV of GPS labels.

    Each item is (image, IMG_ID string, longitude, latitude). The IMG_ID is
    returned in the ``text`` slot so downstream code can key results by file.
    """

    def __init__(
        self,
        root_path="./data/im2gps3k",
        text_data_path="im2gps3k_places365.csv",
        image_data_path="images/",
        vision_processor=None,
        text_processor=None,
    ):
        # NOTE(review): passes self as the base-class argument — looks like it
        # should be a root path per torchvision's VisionDataset; confirm.
        super().__init__(self)
        print("start loading im2gps...")
        self.root_path = root_path
        self.text_data_path = text_data_path
        self.image_data_path = image_data_path
        self.text_data = pd.read_csv(os.path.join(self.root_path, self.text_data_path))
        # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images
        print("read text data success")

        # location from str to float
        self.text_data.loc[:, "LAT"] = self.text_data["LAT"].astype(float)
        self.text_data.loc[:, "LON"] = self.text_data["LON"].astype(float)
        print("location from str to float success")

        self.vision_processor = vision_processor
        self.text_processor = text_processor

        # Ten-crop transform kept for optional TTA (see commented line below).
        self.tencrop = T.TenCrop(224)

    def __getitem__(self, index):
        """Load and return (image, IMG_ID, longitude, latitude) for *index*."""
        image_path = self.text_data.iloc[index]["IMG_ID"]
        text = image_path

        longitude = self.text_data.iloc[index]["LON"]
        latitude = self.text_data.iloc[index]["LAT"]

        image = Image.open(
            os.path.join(self.root_path, self.image_data_path, image_path)
        )

        if image.mode != "RGB":
            image = image.convert("RGB")

        # image = self.tencrop(image) # for tencrop

        if self.vision_processor:
            # -1 in reshape absorbs the channel (and any crop) dimension.
            image = self.vision_processor(images=image, return_tensors="pt")[
                "pixel_values"
            ].reshape(-1, 224, 224)

        return image, text, longitude, latitude

    def __len__(self):
        """Dataset size = number of CSV rows."""
        return len(self.text_data)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class yfcc4kDataset(VisionDataset):
    """YFCC4k evaluation dataset: images on disk plus a CSV of GPS labels.

    Each item is (image, IMG_ID string, longitude, latitude).
    """

    def __init__(
        self,
        root_path="./data/yfcc4k",
        text_data_path="yfcc4k_places365.csv",
        image_data_path="images/",
        vision_processor=None,
        text_processor=None,
    ):
        super().__init__(self)
        print("start loading yfcc4k...")
        self.root_path = root_path
        self.text_data_path = text_data_path
        self.image_data_path = image_data_path
        csv_path = os.path.join(self.root_path, self.text_data_path)
        self.text_data = pd.read_csv(csv_path)
        # self.text_data = self.text_data[self.text_data['IMG_ID'].str.endswith('.jpg')] # only keep jpg images
        print("read text data success")

        # location from str to float
        for coord_col in ("LAT", "LON"):
            self.text_data.loc[:, coord_col] = self.text_data[coord_col].astype(float)
        print("location from str to float success")

        self.vision_processor = vision_processor
        self.text_processor = text_processor

    def __getitem__(self, index):
        """Load and return (image, IMG_ID, longitude, latitude) for *index*."""
        record = self.text_data.iloc[index]
        image_path = record["IMG_ID"]
        text = image_path

        longitude = record["LON"]
        latitude = record["LAT"]

        full_path = os.path.join(self.root_path, self.image_data_path, image_path)
        image = Image.open(full_path)

        if image.mode != "RGB":
            image = image.convert("RGB")

        if self.vision_processor:
            pixel_values = self.vision_processor(images=image, return_tensors="pt")[
                "pixel_values"
            ]
            image = pixel_values.reshape(-1, 224, 224)

        return image, text, longitude, latitude

    def __len__(self):
        """Dataset size = number of CSV rows."""
        return self.text_data.shape[0]
|
src/g3/hparams.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sh_siren:
|
| 2 |
+
# legendre_polys: 30
|
| 3 |
+
# harmonics_calculation: analytic
|
| 4 |
+
# hidden_dim: 512
|
| 5 |
+
# num_layers: 3
|
| 6 |
+
# lr: 7.887855321604208e-05
|
| 7 |
+
# wd: 1.3475466222160537e-06
|
| 8 |
+
|
| 9 |
+
sh_siren:
|
| 10 |
+
legendre_polys: 40
|
| 11 |
+
harmonics_calculation: analytic
|
| 12 |
+
hidden_dim: 512
|
| 13 |
+
output_dim: 256
|
| 14 |
+
num_layers: 2
|
| 15 |
+
lr: 0.0001
|
| 16 |
+
wd: 0.01
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
projection_eep_rffmlp:
|
| 20 |
+
projection: eep
|
| 21 |
+
sigma:
|
| 22 |
+
- 1
|
| 23 |
+
- 16
|
| 24 |
+
- 256
|
| 25 |
+
hidden_dim: 1024
|
| 26 |
+
lr: 0.00003
|
| 27 |
+
wd: 0.000001
|
| 28 |
+
|
| 29 |
+
projection_mercator_rffmlp:
|
| 30 |
+
projection: mercator
|
| 31 |
+
sigma:
|
| 32 |
+
- 1
|
| 33 |
+
- 16
|
| 34 |
+
- 256
|
| 35 |
+
hidden_dim: 1024
|
| 36 |
+
lr: 0.00003
|
| 37 |
+
wd: 0.000001
|
| 38 |
+
|
| 39 |
+
projection_ecef_rffmlp:
|
| 40 |
+
projection: ecef
|
| 41 |
+
sigma:
|
| 42 |
+
- 1
|
| 43 |
+
- 16
|
| 44 |
+
- 256
|
| 45 |
+
hidden_dim: 1024
|
| 46 |
+
lr: 0.00003
|
| 47 |
+
wd: 0.000001
|
| 48 |
+
|
| 49 |
+
# NOTE: duplicate of the `projection_eep_rffmlp` block defined earlier in this
# file — YAML loaders keep only the last occurrence, so this copy wins.
projection_eep_rffmlp:
|
| 50 |
+
projection: eep
|
| 51 |
+
sigma:
|
| 52 |
+
- 1
|
| 53 |
+
- 16
|
| 54 |
+
- 256
|
| 55 |
+
hidden_dim: 1024
|
| 56 |
+
lr: 0.00003
|
| 57 |
+
wd: 0.000001
|
src/g3/locationencoder.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from .nn.mlp import MLP
|
| 4 |
+
from .nn.rff_mlp import RFFMLP
|
| 5 |
+
from .nn.siren import SirenNet
|
| 6 |
+
from .pe.projection import Projection
|
| 7 |
+
from .pe.projection_rff import ProjectionRFF
|
| 8 |
+
from .pe.spherical_harmonics import SphericalHarmonics
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_positional_encoding(positional_encoding_type, hparams, device="cuda"):
    """
    Returns a positional encoding module based on the specified encoding type.

    Args:
        positional_encoding_type (str): The type of positional encoding to use.
            Options are 'projectionrff', 'projection', 'sh'.
        hparams (dict): Hyper-parameters for the chosen encoder:
            'projection' and 'sigma' for 'projectionrff', 'projection' for
            'projection', 'legendre_polys' and 'harmonics_calculation' for 'sh'.
        device (str): Device the encoder should place its tensors on.

    Returns:
        nn.Module: The positional encoding module.

    Raises:
        ValueError: If ``positional_encoding_type`` is not supported.
    """
    # Fix: the old docstring documented nonexistent parameters
    # (encoding_type, input_dim, output_dim) and omitted the real options.
    if positional_encoding_type == "projectionrff":
        return ProjectionRFF(
            projection=hparams["projection"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    elif positional_encoding_type == "projection":
        return Projection(
            projection=hparams["projection"], hparams=hparams, device=device
        )
    elif positional_encoding_type == "sh":
        return SphericalHarmonics(
            legendre_polys=hparams["legendre_polys"],
            harmonics_calculation=hparams["harmonics_calculation"],
            hparams=hparams,
            device=device,
        )
    else:
        raise ValueError(f"Unsupported encoding type: {positional_encoding_type}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_neural_network(
    neural_network_type: str,
    input_dim: int,
    hparams: dict,
    device="cuda",
):
    """
    Returns a neural network module based on the specified network type.

    Args:
        neural_network_type (str): The type of neural network to use.
            Options are 'siren', 'mlp', 'rffmlp'.
        input_dim (int): The input dimension for the neural network.
        hparams (dict): Hyper-parameters for the chosen network:
            'hidden_dim' always; additionally 'output_dim' and 'num_layers'
            for 'siren', and 'sigma' for 'rffmlp'.
        device (str): Device for the network.

    Returns:
        nn.Module: The neural network module.

    Raises:
        ValueError: If ``neural_network_type`` is not supported.
    """
    # Fix: the old docstring listed only 'siren' although 'mlp' and
    # 'rffmlp' are supported, and omitted the required hparams keys.
    if neural_network_type == "siren":
        return SirenNet(
            input_dim=input_dim,
            output_dim=hparams["output_dim"],
            hidden_dim=hparams["hidden_dim"],
            num_layers=hparams["num_layers"],
            hparams=hparams,
            device=device,
        )
    elif neural_network_type == "mlp":
        return MLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            hparams=hparams,
            device=device,
        )
    elif neural_network_type == "rffmlp":
        return RFFMLP(
            input_dim=input_dim,
            hidden_dim=hparams["hidden_dim"],
            sigma=hparams["sigma"],
            hparams=hparams,
            device=device,
        )
    else:
        raise ValueError(f"Unsupported network type: {neural_network_type}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class LocationEncoder(nn.Module):
    """Encode (lat, lon) coordinates into a 512-d feature vector.

    The positional encoder produces one embedding per hierarchy level; each
    level has its own neural-network head and the head outputs are summed.
    """

    def __init__(
        self,
        positional_encoding_type: str = "sh",
        neural_network_type: str = "siren",
        hparams: dict | None = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = device

        # Fix: normalize hparams BEFORE building the encoder. The original
        # code defaulted hparams to {} only AFTER get_positional_encoding had
        # already indexed into it, so the default hparams=None crashed.
        if hparams is None:
            hparams = {}

        self.position_encoder = get_positional_encoding(
            positional_encoding_type=positional_encoding_type,
            hparams=hparams,
            device=device,
        )

        # One head per hierarchy level of the positional encoding.
        self.neural_network = nn.ModuleList(
            [
                get_neural_network(
                    neural_network_type, input_dim=dim, hparams=hparams, device=device
                )
                for dim in self.position_encoder.embedding_dim
            ]
        )

    def forward(self, x):
        """Map a (batch, 2) coordinate tensor to (batch, 512) features."""
        embedding = self.position_encoder(x)

        if embedding.ndim == 2:
            # If the embedding is (batch, n), add a leading hierarchy dim.
            embedding = embedding.unsqueeze(0)

        # NOTE(review): 512 assumes the heads output 512-d features (the
        # project-wide default) — confirm if output_dim is ever changed.
        location_features = torch.zeros(embedding.shape[1], 512).to(self.device)

        # Fix: renamed the loop variable (was `nn`), which shadowed the
        # torch.nn module import inside this method.
        for head, level_embedding in zip(self.neural_network, embedding):
            location_features += head(level_embedding)

        return location_features
|
src/g3/nn/mlp.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
class MLP(nn.Module):
    """Multi-layer perceptron (MLP) with ReLU activations and a linear head."""

    def __init__(self, input_dim=512, hidden_dim=1024, output_dim=512, hparams=None, device='cuda'):
        super(MLP, self).__init__()
        self.device = device
        # Three Linear+ReLU stages; the first maps input_dim -> hidden_dim,
        # the remaining two stay at hidden_dim.
        stages = []
        for stage_input in (input_dim, hidden_dim, hidden_dim):
            stages.append(nn.Linear(stage_input, hidden_dim))
            stages.append(nn.ReLU())
        self.capsule = nn.Sequential(*stages)
        self.head = nn.Sequential(nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        """Run *x* through the hidden stack, then the output head."""
        return self.head(self.capsule(x))
|
src/g3/nn/rff_mlp.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
from ..rff.layers import GaussianEncoding
|
| 4 |
+
|
| 5 |
+
class LocationEncoderCapsule(nn.Module):
    """One RFF-encoded MLP capsule operating at a single frequency scale sigma."""

    def __init__(self, input_dim=2, hidden_dim=1024, output_dim=512, sigma=2**0):
        super(LocationEncoderCapsule, self).__init__()
        # Random Fourier features lift the raw coordinates; encoded_size is
        # half of output_dim because the encoding emits [cos, sin] pairs.
        rff_encoding = GaussianEncoding(sigma=sigma, input_size=input_dim, encoded_size=int(output_dim/2))
        stack = [rff_encoding, nn.Linear(output_dim, hidden_dim), nn.ReLU()]
        for _ in range(2):
            stack.append(nn.Linear(hidden_dim, hidden_dim))
            stack.append(nn.ReLU())
        self.capsule = nn.Sequential(*stack)
        self.head = nn.Sequential(nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        """Encode coordinates *x* and return this capsule's feature vector."""
        return self.head(self.capsule(x))
|
| 22 |
+
|
| 23 |
+
class RFFMLP(nn.Module):
    """Sum of RFF-MLP capsules evaluated at multiple frequency scales.

    Each entry of *sigma* spawns one LocationEncoderCapsule; their outputs
    are summed into a single (batch, output_dim) feature tensor.
    """

    # Fix: sigma default was a mutable list; use a tuple.
    def __init__(self, input_dim=2, hidden_dim=1024, output_dim=512, sigma=(2**0, 2**4, 2**8), hparams=None, device='cuda'):
        super(RFFMLP, self).__init__()
        self.num_hierarchies = len(sigma)
        self.output_dim = output_dim
        self.device = device

        for i, s in enumerate(sigma):
            self.add_module('LocEnc' + str(i), LocationEncoderCapsule(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, sigma=s))

    def forward(self, input):
        """Return summed capsule features of shape (batch, output_dim)."""
        # Fix: accumulator size was hard-coded to 512, which broke any
        # non-default output_dim; use the configured output_dim instead.
        location_features = torch.zeros(input.shape[0], self.output_dim).to(self.device)

        for i in range(self.num_hierarchies):
            location_features += self._modules['LocEnc' + str(i)](input)
        return location_features
|
src/g3/nn/siren.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
# helpers
|
| 8 |
+
|
| 9 |
+
def exists(val):
    """Return True when *val* is not None (falsy non-None values count as present)."""
    if val is None:
        return False
    return True
|
| 11 |
+
|
| 12 |
+
def cast_tuple(val, repeat = 1):
    """Return *val* unchanged if it is a tuple; otherwise repeat it *repeat* times."""
    if isinstance(val, tuple):
        return val
    return (val,) * repeat
|
| 14 |
+
|
| 15 |
+
# sin activation
|
| 16 |
+
|
| 17 |
+
class Sine(nn.Module):
    """Sine activation sin(w0 * x), as used by SIREN layers."""

    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return torch.sin(scaled)
|
| 23 |
+
|
| 24 |
+
# siren layer
|
| 25 |
+
|
| 26 |
+
class Siren(nn.Module):
    """A single SIREN layer: linear map followed by a Sine (or custom) activation.

    Weight init follows the SIREN scheme: the first layer uses U(-1/dim, 1/dim),
    later layers use U(-sqrt(c/dim)/w0, sqrt(c/dim)/w0).
    """

    def __init__(self, input_dim, output_dim, w0 = 1., c = 6., is_first = False, use_bias = True, activation = None, dropout = False):
        super().__init__()
        self.input_dim = input_dim
        self.is_first = is_first
        self.output_dim = output_dim
        self.dropout = dropout

        # Parameters are created zeroed, then initialized in place below.
        weight = torch.zeros(output_dim, input_dim)
        bias = torch.zeros(output_dim) if use_bias else None
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation

    def init_(self, weight, bias, c, w0):
        """Uniformly initialize *weight* (and *bias*) in place with SIREN bounds."""
        dim = self.input_dim

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        if self.dropout:
            # F.dropout default p=0.5; active only when self.training is True.
            out = F.dropout(out, training=self.training)
        out = self.activation(out)
        return out
|
| 57 |
+
|
| 58 |
+
# siren network
|
| 59 |
+
|
| 60 |
+
class SirenNet(nn.Module):
    """Stack of Siren layers with an optional per-layer modulation signal.

    The first layer uses w0_initial (default 30, per the SIREN paper); all
    subsequent layers use w0. The last layer defaults to an identity
    activation unless final_activation is supplied.
    """

    def __init__(self, input_dim = 512, hidden_dim = 1024, output_dim = 512, num_layers = 3, w0 = 1., w0_initial = 30., use_bias = True, final_activation = None, degreeinput = False, dropout = False, hparams=None, device='cuda'):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.degreeinput = degreeinput
        self.device = device

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_input_dim = input_dim if is_first else hidden_dim

            self.layers.append(Siren(
                input_dim = layer_input_dim,
                output_dim = hidden_dim,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first,
                dropout = dropout
            ))

        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.last_layer = Siren(input_dim = hidden_dim, output_dim = output_dim, w0 = w0, use_bias = use_bias, activation = final_activation, dropout = False)

    def forward(self, x, mods = None):
        """Forward pass; *mods*, when given, scales each hidden activation.

        mods may be None, a single tensor, or a tuple with one entry per layer
        (cast_tuple broadcasts a single value to all layers).
        """

        # do some normalization to bring degrees in a -pi to pi range
        if self.degreeinput:
            x = torch.deg2rad(x) - torch.pi

        mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                # Broadcast the (d,) modulation across the batch dimension.
                x *= rearrange(mod, 'd -> () d')

        return self.last_layer(x)
|
src/g3/pe/projection.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import itertools
|
| 6 |
+
from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPModel
|
| 7 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
| 8 |
+
from pyproj import Proj, Transformer
|
| 9 |
+
|
| 10 |
+
SF = 66.50336
|
| 11 |
+
|
| 12 |
+
class Projection(nn.Module):
    """Positional encoding that map-projects (lat, lon) and normalizes.

    Supported projections: 'mercator' (EPSG:3857, 2-D), 'eep' (equal-earth,
    EPSG:8857, 2-D), 'ecef' (EPSG:4978, 3-D). Each is divided by a
    projection-specific constant so outputs land roughly in [-1, 1].
    """

    def __init__(self, projection="mercator", hparams=None, device='cuda'):
        super(Projection, self).__init__()
        self.device = device
        self.projection = projection.lower()

        # Source CRS: plain WGS84 lat/lon degrees.
        proj_wgs84 = Proj('epsg:4326')

        if self.projection == "mercator":
            proj_target = Proj('epsg:3857')
            # Half-circumference of the web-mercator plane (meters).
            self.normalizer = 20037508.3427892
            self.embedding_dim = [2]
        elif self.projection == "eep":
            proj_target = Proj('epsg:8857')
            self.normalizer = 180/SF
            self.embedding_dim = [2]
        elif self.projection == "ecef":
            proj_target = Proj('epsg:4978')
            self.normalizer = 6378137.0 # radius of Earth, not exact for ECEF but usable
            self.embedding_dim = [3]
        else:
            raise ValueError(f"Unsupported projection: {self.projection}")

        # always_xy=True: transform() takes/returns (lon, lat) i.e. (x, y).
        self.transformer = Transformer.from_proj(proj_wgs84, proj_target, always_xy=True)

    def forward(self, input):
        """Project a (batch, 2) [lat, lon] tensor; returns (batch, 2 or 3)."""
        lat = input[:, 0].float().detach().cpu().numpy()
        lon = input[:, 1].float().detach().cpu().numpy()
        # lon (batch), lat (batch)

        # Shape: (batch, 2) or (batch, 3) depending on projection
        if self.projection == "ecef":
            # ECEF needs an altitude; assume sea level (0 m).
            alt = np.zeros_like(lat)
            projected = self.transformer.transform(lon, lat, alt)
            location = list(zip(*projected)) # X, Y, Z
            location = torch.Tensor(location).to(self.device)
        else:
            # NOTE(review): 2-D branch stores (y, x) — northing before
            # easting; confirm downstream consumers expect this order.
            projected = self.transformer.transform(lon, lat)
            location = [[y, x] for x, y in zip(*projected)]
            location = torch.Tensor(location).to(self.device)

        location = location / self.normalizer
        return location
|
src/g3/pe/projection_rff.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import itertools
|
| 6 |
+
from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPModel
|
| 7 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
| 8 |
+
from ..rff.layers import GaussianEncoding
|
| 9 |
+
from pyproj import Proj, Transformer
|
| 10 |
+
|
| 11 |
+
SF = 66.50336
|
| 12 |
+
|
| 13 |
+
class ProjectionRFF(nn.Module):
    """Map-project (lat, lon), then lift with random Fourier features.

    One GaussianEncoding per entry of *sigma*; forward returns a stacked
    tensor of shape (num_hierarchies, batch, 512).
    """

    def __init__(self, projection="ecef", sigma=[2**0, 2**4, 2**8], hparams=None, device='cuda'):
        super(ProjectionRFF, self).__init__()
        self.device = device
        self.sigma = sigma
        self.num_hierarchies = len(self.sigma)
        self.projection = projection.lower()
        # Each GaussianEncoding with encoded_size=256 emits 512 features
        # (cos and sin halves).
        self.embedding_dim = [512] * self.num_hierarchies

        # Source CRS: WGS84 lat/lon degrees.
        proj_wgs84 = Proj('epsg:4326')
        if self.projection == "mercator":
            proj_target = Proj('epsg:3857')
            input_dim = 2
            self.normalizer = 20037508.3427892
        elif self.projection == "eep":
            proj_target = Proj('epsg:8857')
            input_dim = 2
            self.normalizer = 180/SF
        elif self.projection == "ecef":
            proj_target = Proj('epsg:4978')
            input_dim = 3
            self.normalizer = 6378137.0 # radius of Earth, not exact for ECEF but usable
        else:
            raise ValueError(f"Unsupported projection: {self.projection}")

        # always_xy=True: transform() takes/returns (lon, lat) i.e. (x, y).
        self.transformer = Transformer.from_proj(proj_wgs84, proj_target, always_xy=True)
        for i, s in enumerate(self.sigma):
            self.add_module('LocEnc' + str(i), GaussianEncoding(sigma=s, input_size=input_dim, encoded_size=256))

    def forward(self, input):
        """Project a (batch, 2) [lat, lon] tensor; returns (hierarchies, batch, 512)."""
        lat = input[:, 0].float().detach().cpu().numpy()
        lon = input[:, 1].float().detach().cpu().numpy()
        # lon (batch), lat (batch)

        # Shape: (batch, 2) or (batch, 3) depending on projection
        if self.projection == "ecef":
            # ECEF needs an altitude; assume sea level (0 m).
            alt = np.zeros_like(lat)
            projected = self.transformer.transform(lon, lat, alt)
            location = list(zip(*projected)) # X, Y, Z
            location = torch.Tensor(location).to(self.device)
        else:
            # NOTE(review): 2-D branch stores (y, x) — northing before
            # easting; confirm downstream consumers expect this order.
            projected = self.transformer.transform(lon, lat)
            location = [[y, x] for x, y in zip(*projected)]
            location = torch.Tensor(location).to(self.device)

        location = location / self.normalizer
        out = []

        for i in range(self.num_hierarchies):
            out.append(self._modules['LocEnc' + str(i)](location))

        location_features = torch.stack(out, dim=0) # (hierarchies, batch, 512)
        return location_features
|
src/g3/pe/spherical_harmonics.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from .spherical_harmonics_ylm import SH as SH_analytic
|
| 4 |
+
from .spherical_harmonics_closed_form import SH as SH_closed_form
|
| 5 |
+
|
| 6 |
+
class SphericalHarmonics(nn.Module):
    def __init__(self, legendre_polys: int = 20, harmonics_calculation="analytic", hparams=None, device='cuda'):
        """
        legendre_polys: determines the number of legendre polynomials.
                        more polynomials lead more fine-grained resolutions
        calculation of spherical harmonics:
            analytic uses pre-computed equations. This is exact, but works only up to degree 50,
            closed-form uses one equation but is computationally slower (especially for high degrees)

        Raises:
            ValueError: if harmonics_calculation is neither 'analytic' nor 'closed-form'.
        """
        super(SphericalHarmonics, self).__init__()
        self.device = device
        self.L, self.M = int(legendre_polys), int(legendre_polys)
        self.embedding_dim = [self.L * self.M]

        if harmonics_calculation == "closed-form":
            self.SH = SH_closed_form
        elif harmonics_calculation == "analytic":
            self.SH = SH_analytic
        else:
            # Fix: an unknown value previously left self.SH unset, which only
            # surfaced later as an opaque AttributeError inside forward().
            raise ValueError(
                f"Unsupported harmonics_calculation: {harmonics_calculation}"
            )

    def forward(self, lonlat):
        """Compute real spherical-harmonic features for degree coordinates.

        Args:
            lonlat: (batch, 2) tensor, column 0 = longitude, column 1 = latitude (degrees).

        Returns:
            (batch, L*L) detached tensor of Y_lm values.
        """
        lon, lat = lonlat[:, 0], lonlat[:, 1]  # lon: (batch), lat: (batch)

        # convert degree to rad; shift so phi is in [0, 2pi) and theta in [0, pi]
        phi = torch.deg2rad(lon + 180)
        theta = torch.deg2rad(lat + 90)

        Y = []  # (L * L, batch)
        for l in range(self.L):
            for m in range(-l, l + 1):
                y = self.SH(m, l, phi, theta)
                if isinstance(y, float):
                    # Some precomputed harmonics (e.g. Y_00) are constants;
                    # broadcast them to the batch shape.
                    y = y * torch.ones_like(phi)
                Y.append(y)

        return torch.stack(Y, dim=-1).detach()  # (batch, L * L)
|
src/g3/pe/spherical_harmonics_closed_form.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
####################### Spherical Harmonics utilities ########################
|
| 5 |
+
# Code copied from https://github.com/BachiLi/redner/blob/master/pyredner/utils.py
|
| 6 |
+
# Code adapted from "Spherical Harmonic Lighting: The Gritty Details", Robin Green
|
| 7 |
+
# http://silviojemma.com/public/papers/lighting/spherical-harmonic-lighting.pdf
|
| 8 |
+
def associated_legendre_polynomial(l, m, x):
    """Evaluate the associated Legendre polynomial P_l^m at tensor *x*.

    Standard upward recurrence from Robin Green, "Spherical Harmonic
    Lighting: The Gritty Details".
    """
    pmm = torch.ones_like(x)
    if m > 0:
        somx2 = torch.sqrt((1 - x) * (1 + x))
        fact = 1.0
        for _ in range(m):
            pmm = pmm * (-fact) * somx2
            fact += 2.0
    if l == m:
        return pmm
    pmmp1 = x * (2.0 * m + 1.0) * pmm
    if l == m + 1:
        return pmmp1
    pll = torch.zeros_like(x)
    for ll in range(m + 2, l + 1):
        pll = ((2.0 * ll - 1.0) * x * pmmp1 - (ll + m - 1.0) * pmm) / (ll - m)
        pmm, pmmp1 = pmmp1, pll
    return pll


def SH_renormalization(l, m):
    """Normalization constant K_l^m for real spherical harmonics."""
    numerator = (2.0 * l + 1.0) * math.factorial(l - m)
    denominator = 4 * math.pi * math.factorial(l + m)
    return math.sqrt(numerator / denominator)


def SH(m, l, phi, theta):
    """Real spherical harmonic Y_l^m evaluated on tensors phi/theta (radians)."""
    cos_theta = torch.cos(theta)
    if m == 0:
        return SH_renormalization(l, m) * associated_legendre_polynomial(l, m, cos_theta)
    if m > 0:
        return math.sqrt(2.0) * SH_renormalization(l, m) * \
               torch.cos(m * phi) * associated_legendre_polynomial(l, m, cos_theta)
    return math.sqrt(2.0) * SH_renormalization(l, -m) * \
           torch.sin(-m * phi) * associated_legendre_polynomial(l, -m, cos_theta)
|
src/g3/pe/spherical_harmonics_generate_ylms.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This function prints the source code for spherical_harmonics_ylms.py to console
|
| 3 |
+
|
| 4 |
+
spherical_harmonics pre-computes the analytical solutions to each real spherical harmonic with sympy
|
| 5 |
+
the script contains different functions for different degrees l and orders m
|
| 6 |
+
|
| 7 |
+
Marc Russwurm
|
| 8 |
+
"""
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
from sympy import assoc_legendre
|
| 13 |
+
from sympy import cos, sin, sqrt, pi, factorial, Abs
|
| 14 |
+
from sympy import Symbol
|
| 15 |
+
|
| 16 |
+
theta = Symbol("theta")
|
| 17 |
+
phi = Symbol("phi")
|
| 18 |
+
|
| 19 |
+
def calc_ylm(l, m):
|
| 20 |
+
"""
|
| 21 |
+
see last equation of https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
|
| 22 |
+
"""
|
| 23 |
+
if m < 0:
|
| 24 |
+
Plm = assoc_legendre(l, Abs(m), cos(theta))
|
| 25 |
+
Plm_bar = sqrt(((2 * l + 1) / (4 * pi)) * (factorial(l - Abs(m)) / factorial(l + Abs(m)))) * Plm
|
| 26 |
+
|
| 27 |
+
Ylm = (-1)**m * sqrt(2) * Plm_bar * sin(Abs(m) * phi)
|
| 28 |
+
elif m == 0:
|
| 29 |
+
Ylm = sqrt((2*l + 1) / (4 * pi)) * assoc_legendre(l, m, cos(theta))
|
| 30 |
+
else: # m > 0
|
| 31 |
+
Plm = assoc_legendre(l, m, cos(theta))
|
| 32 |
+
Plm_bar = sqrt(((2 * l + 1) / (4 * pi)) * (factorial(l - m) / factorial(l + m))) * Plm
|
| 33 |
+
|
| 34 |
+
Ylm = (-1)**m * sqrt(2) * Plm_bar * cos(m * phi)
|
| 35 |
+
return Ylm
|
| 36 |
+
|
| 37 |
+
def print_function(l, m):
|
| 38 |
+
fname = f"Yl{l}_m{m}".replace("-", "_minus_")
|
| 39 |
+
print()
|
| 40 |
+
print("@torch.jit.script")
|
| 41 |
+
print(f"def {fname}(theta, phi):")
|
| 42 |
+
print(" return " + str(calc_ylm(l, m).evalf()))
|
| 43 |
+
|
| 44 |
+
# max number of Legendre Polynomials
|
| 45 |
+
L = 101
|
| 46 |
+
|
| 47 |
+
head = """\"\"\"
|
| 48 |
+
analytic expressions of spherical harmonics generated with sympy file
|
| 49 |
+
Marc Russwurm generated """ + str(datetime.date(datetime.now())) + """
|
| 50 |
+
|
| 51 |
+
run
|
| 52 |
+
python """ + sys.argv[0] + """ > spherical_harmonics_ylm.py
|
| 53 |
+
|
| 54 |
+
to generate the source code
|
| 55 |
+
\"\"\"
|
| 56 |
+
|
| 57 |
+
import torch
|
| 58 |
+
from torch import cos, sin
|
| 59 |
+
|
| 60 |
+
def get_SH(m,l):
|
| 61 |
+
fname = f"Yl{l}_m{m}".replace("-","_minus_")
|
| 62 |
+
return globals()[fname]
|
| 63 |
+
|
| 64 |
+
def SH(m, l, phi, theta):
|
| 65 |
+
Ylm = get_SH(m,l)
|
| 66 |
+
return Ylm(theta, phi)
|
| 67 |
+
"""
|
| 68 |
+
print(head)
|
| 69 |
+
print()
|
| 70 |
+
|
| 71 |
+
for l in range(L):
|
| 72 |
+
for m in range(-l,l+1):
|
| 73 |
+
print_function(l,m)
|
src/g3/pe/spherical_harmonics_ylm.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/g3/rff/functional.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def sample_b(sigma: float, size: tuple) -> Tensor:
    r"""Draw a matrix of shape :attr:`size` from :math:`\mathcal{N}(0, \sigma^2)`.

    Args:
        sigma (float): standard deviation
        size (tuple): shape of the sampled matrix

    See :class:`~rff.layers.GaussianEncoding` for more details
    """
    noise = torch.randn(size)
    return noise * sigma
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@torch.jit.script
def gaussian_encoding(
        v: Tensor,
        b: Tensor) -> Tensor:
    r"""Random Fourier feature map :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{B} \mathbf{v}} , \sin{2 \pi \mathbf{B} \mathbf{v}})`.

    Args:
        v (Tensor): input of shape :math:`(N, *, \text{input_size})`
        b (Tensor): projection matrix of shape :math:`(\text{encoded_layer_size}, \text{input_size})`

    Returns:
        Tensor: encoding of shape :math:`(N, *, 2 \cdot \text{encoded_layer_size})`

    See :class:`~rff.layers.GaussianEncoding` for more details.
    """
    projected = 2 * np.pi * torch.matmul(v, b.t())
    return torch.cat((torch.cos(projected), torch.sin(projected)), dim=-1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@torch.jit.script
def basic_encoding(
        v: Tensor) -> Tensor:
    r"""Basic encoding :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{v}} , \sin{2 \pi \mathbf{v}})`.

    Args:
        v (Tensor): input of shape :math:`(N, *, \text{input_size})`

    Returns:
        Tensor: encoding of shape :math:`(N, *, 2 \cdot \text{input_size})`

    See :class:`~rff.layers.BasicEncoding` for more details.
    """
    scaled = 2 * np.pi * v
    return torch.cat((torch.cos(scaled), torch.sin(scaled)), dim=-1)
+
|
| 55 |
+
@torch.jit.script
def positional_encoding(
        v: Tensor,
        sigma: float,
        m: int) -> Tensor:
    r"""Positional encoding :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`
    for :math:`j \in \{0, \dots, m-1\}`.

    Args:
        v (Tensor): input of shape :math:`(N, *, \text{input_size})`
        sigma (float): constant chosen based upon the domain of :attr:`v`
        m (int): number of frequency bands

    Returns:
        Tensor: encoding of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`

    See :class:`~rff.layers.PositionalEncoding` for more details.
    """
    exponents = torch.arange(m, device=v.device)
    frequencies = 2 * np.pi * sigma ** (exponents / m)
    scaled = torch.unsqueeze(v, -1) * frequencies
    encoded = torch.cat((torch.cos(scaled), torch.sin(scaled)), dim=-1)
    return encoded.flatten(-2, -1)
|
src/g3/rff/layers.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from . import functional
|
| 6 |
+
|
| 7 |
+
class GaussianEncoding(nn.Module):
    """Layer mapping coordinates to random Fourier features."""

    def __init__(self, sigma: Optional[float] = None,
                 input_size: Optional[float] = None,
                 encoded_size: Optional[float] = None,
                 b: Optional[Tensor] = None):
        r"""
        Args:
            sigma (Optional[float]): standard deviation of the sampled matrix
            input_size (Optional[float]): number of input dimensions
            encoded_size (Optional[float]): number of dimensions ``b`` maps to
            b (Optional[Tensor], optional): a pre-sampled projection matrix
        Raises:
            ValueError:
                If ``b`` is given together with any of ``sigma``/``input_size``/
                ``encoded_size``, or if ``b`` is absent and any of them is missing.
        """
        super().__init__()
        hyperparams = (sigma, input_size, encoded_size)
        if b is None:
            if any(p is None for p in hyperparams):
                raise ValueError(
                    'Arguments "sigma," "input_size," and "encoded_size" are required.')

            b = functional.sample_b(sigma, (encoded_size, input_size))
        elif any(p is not None for p in hyperparams):
            raise ValueError('Only specify the "b" argument when using it.')
        # Frozen (non-trainable) projection matrix, registered as a parameter
        # so it moves with the module across devices.
        self.b = nn.parameter.Parameter(b, requires_grad=False)

    def forward(self, v: Tensor) -> Tensor:
        r"""Compute :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{B} \mathbf{v}} , \sin{2 \pi \mathbf{B} \mathbf{v}})`.

        Args:
            v (Tensor): input of shape :math:`(N, *, \text{input_size})`

        Returns:
            Tensor: random Fourier features of shape :math:`(N, *, 2 \cdot \text{encoded_size})`
        """
        return functional.gaussian_encoding(v, self.b)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class BasicEncoding(nn.Module):
    """Layer mapping coordinates with the basic (single-frequency) encoding."""

    def forward(self, v: Tensor) -> Tensor:
        r"""Compute :math:`\gamma(\mathbf{v}) = (\cos{2 \pi \mathbf{v}} , \sin{2 \pi \mathbf{v}})`.

        Args:
            v (Tensor): input of shape :math:`(N, *, \text{input_size})`

        Returns:
            Tensor: encoding of shape :math:`(N, *, 2 \cdot \text{input_size})`
        """
        return functional.basic_encoding(v)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class PositionalEncoding(nn.Module):
    """Layer mapping coordinates with a multi-frequency positional encoding."""

    def __init__(self, sigma: float, m: int):
        r"""
        Args:
            sigma (float): frequency constant
            m (int): number of frequencies to map to
        """
        super().__init__()
        self.sigma, self.m = sigma, m

    def forward(self, v: Tensor) -> Tensor:
        r"""Compute :math:`\gamma(\mathbf{v}) = (\dots, \cos{2 \pi \sigma^{(j/m)} \mathbf{v}} , \sin{2 \pi \sigma^{(j/m)} \mathbf{v}}, \dots)`.

        Args:
            v (Tensor): input of shape :math:`(N, *, \text{input_size})`

        Returns:
            Tensor: encoding of shape :math:`(N, *, 2 \cdot m \cdot \text{input_size})`
        """
        return functional.positional_encoding(v, self.sigma, self.m)
|
src/g3_batch_prediction.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import time
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import yaml
|
| 10 |
+
from google import genai
|
| 11 |
+
from google.genai import types
|
| 12 |
+
from pydantic import ValidationError
|
| 13 |
+
from tqdm.asyncio import tqdm as atqdm
|
| 14 |
+
|
| 15 |
+
from .data_processor import DataProcessor
|
| 16 |
+
from .g3.G3 import G3
|
| 17 |
+
from .prompt import (
|
| 18 |
+
Evidence,
|
| 19 |
+
GPSPrediction,
|
| 20 |
+
LocationPrediction,
|
| 21 |
+
diversification_prompt,
|
| 22 |
+
location_prompt,
|
| 23 |
+
verification_prompt,
|
| 24 |
+
)
|
| 25 |
+
from .utils import (
|
| 26 |
+
calculate_similarity_scores,
|
| 27 |
+
extract_and_parse_json,
|
| 28 |
+
get_gps_from_location,
|
| 29 |
+
handle_async_api_call_with_retry,
|
| 30 |
+
image_to_base64,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger("uvicorn.error")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class G3BatchPredictor:
|
| 37 |
+
"""
|
| 38 |
+
Batch prediction class for processing all images and videos in a directory.
|
| 39 |
+
|
| 40 |
+
This class:
|
| 41 |
+
1. Preprocesses all images and videos in a directory.
|
| 42 |
+
2. Extracts keyframes from videos and combines them with images.
|
| 43 |
+
3. Passes all keyframes and images to the Gemini model for prediction.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
device: str = "cuda",
|
| 49 |
+
input_dir: str = "data/input_data",
|
| 50 |
+
prompt_dir: str = "data/prompt_data",
|
| 51 |
+
cache_dir: str = "data/cache",
|
| 52 |
+
index_path: str = "data/index/G3.index",
|
| 53 |
+
hparams_path: str = "g3/hparams.yaml",
|
| 54 |
+
database_csv_path: str = "data/dataset/mp16/MP16_Pro_filtered.csv",
|
| 55 |
+
checkpoint_path: str = "data/checkpoints/mercator_finetune_weight.pth",
|
| 56 |
+
):
|
| 57 |
+
"""
|
| 58 |
+
Initialize the BatchKeyframePredictor.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
checkpoint_path (str): Path to G3 model checkpoint
|
| 62 |
+
device (str): Device to run model on ("cuda" or "cpu")
|
| 63 |
+
index_path (str): Path to FAISS index for RAG (required)
|
| 64 |
+
"""
|
| 65 |
+
self.device = torch.device(device)
|
| 66 |
+
self.base_path = Path(__file__).parent
|
| 67 |
+
self.checkpoint_path = self.base_path / checkpoint_path
|
| 68 |
+
|
| 69 |
+
self.input_dir = self.base_path / input_dir
|
| 70 |
+
self.prompt_dir = self.base_path / prompt_dir
|
| 71 |
+
self.cache_dir = self.base_path / cache_dir
|
| 72 |
+
self.image_dir = self.prompt_dir / "images"
|
| 73 |
+
self.audio_dir = self.prompt_dir / "audio"
|
| 74 |
+
|
| 75 |
+
os.makedirs(self.input_dir, exist_ok=True)
|
| 76 |
+
os.makedirs(self.prompt_dir, exist_ok=True)
|
| 77 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
| 78 |
+
os.makedirs(self.image_dir, exist_ok=True)
|
| 79 |
+
os.makedirs(self.audio_dir, exist_ok=True)
|
| 80 |
+
|
| 81 |
+
# Initialize G3 model
|
| 82 |
+
hparams = yaml.safe_load(open(self.base_path / hparams_path, "r"))
|
| 83 |
+
pe = "projection_mercator"
|
| 84 |
+
nn = "rffmlp"
|
| 85 |
+
|
| 86 |
+
self.model = G3(
|
| 87 |
+
device=device,
|
| 88 |
+
positional_encoding_type=pe,
|
| 89 |
+
neural_network_type=nn,
|
| 90 |
+
hparams=hparams[f"{pe}_{nn}"],
|
| 91 |
+
)
|
| 92 |
+
self.__load_checkpoint()
|
| 93 |
+
|
| 94 |
+
self.data_processor = DataProcessor(
|
| 95 |
+
model=self.model,
|
| 96 |
+
input_dir=self.input_dir,
|
| 97 |
+
prompt_dir=self.prompt_dir,
|
| 98 |
+
cache_dir=self.cache_dir,
|
| 99 |
+
image_dir=self.image_dir,
|
| 100 |
+
audio_dir=self.audio_dir,
|
| 101 |
+
index_path=self.base_path / index_path,
|
| 102 |
+
database_csv_path=self.base_path / database_csv_path,
|
| 103 |
+
device=self.device,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.image_extension = {
|
| 107 |
+
".jpg",
|
| 108 |
+
".jpeg",
|
| 109 |
+
".png",
|
| 110 |
+
".bmp",
|
| 111 |
+
".tiff",
|
| 112 |
+
".tif",
|
| 113 |
+
".webp",
|
| 114 |
+
}
|
| 115 |
+
self.video_extension = {
|
| 116 |
+
".mp4",
|
| 117 |
+
".avi",
|
| 118 |
+
".mov",
|
| 119 |
+
".mkv",
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
def __load_checkpoint(self):
|
| 123 |
+
"""
|
| 124 |
+
Load the G3 model checkpoint.
|
| 125 |
+
"""
|
| 126 |
+
if not os.path.exists(self.checkpoint_path):
|
| 127 |
+
raise FileNotFoundError(
|
| 128 |
+
f"Checkpoint file not found: {self.checkpoint_path}"
|
| 129 |
+
)
|
| 130 |
+
self.model.load_state_dict(
|
| 131 |
+
torch.load(self.checkpoint_path, map_location=self.device)
|
| 132 |
+
)
|
| 133 |
+
self.model.to(self.device)
|
| 134 |
+
self.model.eval()
|
| 135 |
+
logger.info(
|
| 136 |
+
f"✅ Successfully loaded G3 model checkpoint from: {self.checkpoint_path}"
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
async def llm_predict(
|
| 140 |
+
self,
|
| 141 |
+
model_name: str = "gemini-2.5-pro",
|
| 142 |
+
n_search: int | None = None,
|
| 143 |
+
n_coords: int | None = None,
|
| 144 |
+
image_prediction: bool = True,
|
| 145 |
+
text_prediction: bool = True,
|
| 146 |
+
) -> LocationPrediction:
|
| 147 |
+
"""
|
| 148 |
+
Generate a prediction using the Gemini LLM with Pydantic structured output.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
model_name: LLM model name to use
|
| 152 |
+
n_search: Number of search results to include
|
| 153 |
+
n_coords: Number of coordinates to include
|
| 154 |
+
image_prediction: Whether to use images in prediction
|
| 155 |
+
text_prediction: Whether to use text in prediction
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
dict: Parsed prediction response
|
| 159 |
+
"""
|
| 160 |
+
prompt = diversification_prompt(
|
| 161 |
+
prompt_dir=str(self.prompt_dir),
|
| 162 |
+
n_coords=n_coords,
|
| 163 |
+
n_search=n_search,
|
| 164 |
+
image_prediction=image_prediction,
|
| 165 |
+
text_prediction=text_prediction,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
images = []
|
| 169 |
+
if image_prediction:
|
| 170 |
+
image_dir = self.image_dir
|
| 171 |
+
if not image_dir.exists():
|
| 172 |
+
raise ValueError(f"Image directory does not exist: {image_dir}")
|
| 173 |
+
|
| 174 |
+
for image_file in image_dir.glob("*.jpg"):
|
| 175 |
+
with open(image_file, "rb") as f:
|
| 176 |
+
image = types.Part.from_bytes(data=f.read(), mime_type="image/jpeg")
|
| 177 |
+
images.append(image)
|
| 178 |
+
|
| 179 |
+
client = genai.Client(api_key=os.environ["GOOGLE_CLOUD_API_KEY"])
|
| 180 |
+
|
| 181 |
+
async def api_call():
|
| 182 |
+
loop = asyncio.get_event_loop()
|
| 183 |
+
response = await loop.run_in_executor(
|
| 184 |
+
None,
|
| 185 |
+
lambda: client.models.generate_content(
|
| 186 |
+
model=model_name,
|
| 187 |
+
contents=[*images, prompt],
|
| 188 |
+
config=types.GenerateContentConfig(
|
| 189 |
+
tools=[
|
| 190 |
+
types.Tool(url_context=types.UrlContext()),
|
| 191 |
+
],
|
| 192 |
+
temperature=0.1,
|
| 193 |
+
top_p=0.95,
|
| 194 |
+
),
|
| 195 |
+
),
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
raw_text = response.text.strip() if response.text is not None else ""
|
| 199 |
+
parsed_json = extract_and_parse_json(raw_text)
|
| 200 |
+
|
| 201 |
+
try:
|
| 202 |
+
validated = LocationPrediction.model_validate(parsed_json)
|
| 203 |
+
return validated
|
| 204 |
+
except (ValidationError, ValueError):
|
| 205 |
+
raise ValueError("Empty or invalid LLM response")
|
| 206 |
+
|
| 207 |
+
return await handle_async_api_call_with_retry(
|
| 208 |
+
api_call,
|
| 209 |
+
fallback_result=LocationPrediction(
|
| 210 |
+
latitude=0.0, longitude=0.0, location="", evidence=[]
|
| 211 |
+
),
|
| 212 |
+
error_context=f"LLM prediction with {model_name}",
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
async def diversification_predict(
|
| 216 |
+
self,
|
| 217 |
+
model_name: str = "gemini-2.5-flash",
|
| 218 |
+
image_prediction: bool = True,
|
| 219 |
+
text_prediction: bool = True,
|
| 220 |
+
) -> LocationPrediction:
|
| 221 |
+
"""
|
| 222 |
+
Diversification prediction without preprocessing (assumes preprocessing already done).
|
| 223 |
+
Runs different sample sizes in parallel for faster execution.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
model_name (str): LLM model name to use
|
| 227 |
+
image_prediction (bool): Whether to use images in prediction
|
| 228 |
+
text_prediction (bool): Whether to use text in prediction
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
dict: Best prediction with latitude, longitude, location, reason, and metadata
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
# Function to try a specific sample size with retry logic
|
| 235 |
+
async def try_sample_size(num_sample):
|
| 236 |
+
while True:
|
| 237 |
+
prediction = await self.llm_predict(
|
| 238 |
+
model_name=model_name,
|
| 239 |
+
n_search=num_sample,
|
| 240 |
+
n_coords=num_sample,
|
| 241 |
+
image_prediction=image_prediction,
|
| 242 |
+
text_prediction=text_prediction,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
if prediction:
|
| 246 |
+
coords = (prediction.latitude, prediction.longitude)
|
| 247 |
+
return (num_sample, coords, prediction)
|
| 248 |
+
else:
|
| 249 |
+
logger.info(
|
| 250 |
+
f"Invalid or empty prediction format with {num_sample} samples, retrying..."
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Run all sample sizes in parallel
|
| 254 |
+
num_samples = [10, 15, 20]
|
| 255 |
+
logger.info(
|
| 256 |
+
f"🚀 Running {len(num_samples)} sample sizes in parallel: {num_samples}"
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
tasks = [try_sample_size(num_sample) for num_sample in num_samples]
|
| 260 |
+
|
| 261 |
+
class LW:
|
| 262 |
+
def write(self, msg: str) -> int:
|
| 263 |
+
logger.info(msg)
|
| 264 |
+
return len(msg)
|
| 265 |
+
|
| 266 |
+
def flush(self):
|
| 267 |
+
pass
|
| 268 |
+
|
| 269 |
+
results = await atqdm.gather(
|
| 270 |
+
*tasks,
|
| 271 |
+
desc="🔄 Running diversification predictions",
|
| 272 |
+
file=LW(),
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# Build predictions dictionary from parallel results
|
| 276 |
+
predictions_dict = {}
|
| 277 |
+
for num_sample, coords, prediction in results:
|
| 278 |
+
predictions_dict[coords] = prediction
|
| 279 |
+
logger.info(f"✅ Collected prediction with {num_sample} samples: {coords}")
|
| 280 |
+
|
| 281 |
+
# Convert predictions to coordinate list for similarity scoring
|
| 282 |
+
predicted_coords = list(predictions_dict.keys())
|
| 283 |
+
logger.info(f"Predicted coordinates: {predicted_coords}")
|
| 284 |
+
|
| 285 |
+
if not predicted_coords:
|
| 286 |
+
raise ValueError("No valid predictions obtained from any sample size")
|
| 287 |
+
|
| 288 |
+
# Calculate similarity scores
|
| 289 |
+
avg_similarities = calculate_similarity_scores(
|
| 290 |
+
model=self.model,
|
| 291 |
+
device=self.device,
|
| 292 |
+
predicted_coords=predicted_coords,
|
| 293 |
+
image_dir=self.image_dir,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# Find best prediction
|
| 297 |
+
best_idx = np.argmax(avg_similarities)
|
| 298 |
+
best_coords = predicted_coords[best_idx]
|
| 299 |
+
best_prediction = predictions_dict[best_coords]
|
| 300 |
+
|
| 301 |
+
logger.info(f"🎯 Best prediction selected: {best_coords}")
|
| 302 |
+
logger.info(f" Similarity scores: {avg_similarities}")
|
| 303 |
+
logger.info(f" Best index: {best_idx}")
|
| 304 |
+
|
| 305 |
+
# print(json.dumps(best_prediction, indent=2)) # Commented out verbose output
|
| 306 |
+
|
| 307 |
+
return best_prediction
|
| 308 |
+
|
| 309 |
+
async def location_predict(
|
| 310 |
+
self,
|
| 311 |
+
model_name: str = "gemini-2.5-flash",
|
| 312 |
+
location: str = "specified location",
|
| 313 |
+
) -> GPSPrediction:
|
| 314 |
+
"""
|
| 315 |
+
Generate a location-based prediction using the Gemini LLM with centralized retry logic.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
model_name (str): LLM model name to use
|
| 319 |
+
location (str): Location to use in the prompt
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
dict: Parsed JSON prediction response
|
| 323 |
+
"""
|
| 324 |
+
if not location:
|
| 325 |
+
raise ValueError("Location must be specified for location-based prediction")
|
| 326 |
+
|
| 327 |
+
lat, lon = get_gps_from_location(location)
|
| 328 |
+
if lat is not None and lon is not None:
|
| 329 |
+
logger.info(
|
| 330 |
+
f"Using GPS coordinates for location '{location}': ({lat}, {lon})"
|
| 331 |
+
)
|
| 332 |
+
return GPSPrediction(
|
| 333 |
+
latitude=lat, longitude=lon, analysis="", references=[]
|
| 334 |
+
)
|
| 335 |
+
else:
|
| 336 |
+
prompt = location_prompt(location)
|
| 337 |
+
client = genai.Client(api_key=os.environ["GOOGLE_CLOUD_API_KEY"])
|
| 338 |
+
|
| 339 |
+
async def api_call():
|
| 340 |
+
# Run the synchronous API call in a thread executor to make it truly async
|
| 341 |
+
loop = asyncio.get_event_loop()
|
| 342 |
+
response = await loop.run_in_executor(
|
| 343 |
+
None,
|
| 344 |
+
lambda: client.models.generate_content(
|
| 345 |
+
model=model_name,
|
| 346 |
+
contents=[prompt],
|
| 347 |
+
config=types.GenerateContentConfig(
|
| 348 |
+
tools=[
|
| 349 |
+
types.Tool(google_search=types.GoogleSearch()),
|
| 350 |
+
],
|
| 351 |
+
temperature=0.1,
|
| 352 |
+
top_p=0.95,
|
| 353 |
+
),
|
| 354 |
+
),
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
raw_text = response.text.strip() if response.text is not None else ""
|
| 358 |
+
parsed_json = extract_and_parse_json(raw_text)
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
validated = GPSPrediction.model_validate(parsed_json)
|
| 362 |
+
return validated
|
| 363 |
+
except (ValidationError, ValueError):
|
| 364 |
+
raise ValueError("Empty or invalid LLM response")
|
| 365 |
+
|
| 366 |
+
return await handle_async_api_call_with_retry(
|
| 367 |
+
api_call,
|
| 368 |
+
fallback_result=GPSPrediction(
|
| 369 |
+
latitude=0.0, longitude=0.0, analysis="", references=[]
|
| 370 |
+
),
|
| 371 |
+
error_context=f"Location prediction for '{location}' with {model_name}",
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
async def verification_predict(
|
| 375 |
+
self,
|
| 376 |
+
prediction: LocationPrediction,
|
| 377 |
+
model_name: str = "gemini-2.5-flash",
|
| 378 |
+
image_prediction: bool = True,
|
| 379 |
+
text_prediction: bool = True,
|
| 380 |
+
) -> LocationPrediction:
|
| 381 |
+
"""
|
| 382 |
+
Generate verification prediction based on the provided prediction.
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
prediction (dict): Prediction dictionary with latitude, longitude, location, reason, and metadata
|
| 386 |
+
model_name (str): LLM model name to use for verification
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
dict: Verification prediction with latitude, longitude, location, reason, and evidence
|
| 390 |
+
"""
|
| 391 |
+
# Prepare verification data (now async)
|
| 392 |
+
satellite_image_id = await self.data_processor.prepare_location_images(
|
| 393 |
+
prediction=prediction.model_dump(),
|
| 394 |
+
image_prediction=image_prediction,
|
| 395 |
+
text_prediction=text_prediction,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
image_dir = self.image_dir
|
| 399 |
+
|
| 400 |
+
images = []
|
| 401 |
+
if image_prediction:
|
| 402 |
+
if not image_dir.exists():
|
| 403 |
+
raise ValueError(f"Image directory does not exist: {image_dir}")
|
| 404 |
+
|
| 405 |
+
for image_file in image_dir.glob("*.jpg"):
|
| 406 |
+
with open(image_file, "rb") as f:
|
| 407 |
+
image = types.Part.from_bytes(data=f.read(), mime_type="image/jpeg")
|
| 408 |
+
images.append(image)
|
| 409 |
+
|
| 410 |
+
# Prepare verification prompt
|
| 411 |
+
prompt = verification_prompt(
|
| 412 |
+
satellite_image_id=satellite_image_id,
|
| 413 |
+
prediction=prediction.model_dump(),
|
| 414 |
+
prompt_dir=str(self.prompt_dir),
|
| 415 |
+
image_prediction=image_prediction,
|
| 416 |
+
text_prediction=text_prediction,
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
client = genai.Client(api_key=os.environ["GOOGLE_CLOUD_API_KEY"])
|
| 420 |
+
|
| 421 |
+
async def api_call():
|
| 422 |
+
# Run the synchronous API call in a thread executor to make it truly async
|
| 423 |
+
loop = asyncio.get_event_loop()
|
| 424 |
+
response = await loop.run_in_executor(
|
| 425 |
+
None,
|
| 426 |
+
lambda: client.models.generate_content(
|
| 427 |
+
model=model_name,
|
| 428 |
+
contents=[*images, prompt],
|
| 429 |
+
config=types.GenerateContentConfig(
|
| 430 |
+
tools=[
|
| 431 |
+
types.Tool(url_context=types.UrlContext()),
|
| 432 |
+
],
|
| 433 |
+
temperature=0.1,
|
| 434 |
+
top_p=0.95,
|
| 435 |
+
),
|
| 436 |
+
),
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
raw_text = response.text.strip() if response.text is not None else ""
|
| 440 |
+
parsed_json = extract_and_parse_json(raw_text)
|
| 441 |
+
|
| 442 |
+
try:
|
| 443 |
+
validated = LocationPrediction.model_validate(parsed_json)
|
| 444 |
+
return validated
|
| 445 |
+
except (ValidationError, ValueError):
|
| 446 |
+
raise ValueError("Empty or invalid LLM response")
|
| 447 |
+
|
| 448 |
+
return await handle_async_api_call_with_retry(
|
| 449 |
+
api_call,
|
| 450 |
+
fallback_result=LocationPrediction(
|
| 451 |
+
latitude=0.0, longitude=0.0, location="", evidence=[]
|
| 452 |
+
),
|
| 453 |
+
error_context=f"Verification prediction with {model_name}",
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
    async def predict(
        self,
        model_name: str = "gemini-2.5-flash",
        image_prediction: bool = True,
        text_prediction: bool = True,
    ) -> LocationPrediction:
        """
        Run the complete multi-modal prediction pipeline.

        Pipeline: preprocess input data, run the (internally parallel)
        diversification prediction, refine coordinates via location
        prediction, merge evidence, then verify the combined result.

        NOTE(review): the original docstring claimed preprocessing was
        assumed done, but this method calls preprocess_input_data() itself —
        confirm that preprocessing is idempotent when invoked twice.

        Args:
            model_name (str): LLM model name to use
            image_prediction (bool): Whether to use images in prediction
            text_prediction (bool): Whether to use text in prediction

        Returns:
            LocationPrediction: Final prediction with latitude, longitude,
            location, and accumulated evidence.
        """
        logger.info(
            f"🚀 Starting multi-modal prediction pipeline with model: {model_name}"
        )
        await self.data_processor.preprocess_input_data()
        # Step 1: Run diversification prediction (this is already parallel internally)
        logger.info(
            f"\n🔄 Running diversification prediction for Image={image_prediction}, Text={text_prediction}..."
        )
        diversification_result = await self.diversification_predict(
            model_name=model_name,
            image_prediction=image_prediction,
            text_prediction=text_prediction,
        )

        # Step 2: Run location prediction, seeded by the diversified location
        location_prediction = await self.location_predict(
            model_name=model_name, location=diversification_result.location
        )

        logger.info("✅ Location prediction completed:")

        # Step 3: Update coordinates and evidence from location prediction.
        # model_copy() keeps the diversification result intact.
        result = diversification_result.model_copy()
        result.longitude = location_prediction.longitude
        result.latitude = location_prediction.latitude

        # Step 4: Normalize and append location evidence — fall back to a
        # placeholder Evidence when the location step gave no analysis/refs.
        if location_prediction.analysis and location_prediction.references:
            location_evidence = Evidence(
                analysis=location_prediction.analysis,
                references=location_prediction.references,
            )
        else:
            location_evidence = Evidence(
                analysis="No specific location analysis provided.",
                references=[],
            )

        # Append to result evidence
        result.evidence.append(location_evidence)

        # Step 5: Run verification prediction over the merged result
        logger.info(
            f"\n🔄 Running verification prediction for Image={image_prediction}, Text={text_prediction}..."
        )
        result = await self.verification_predict(
            prediction=result,
            model_name=model_name,
            image_prediction=image_prediction,
            text_prediction=text_prediction,
        )

        logger.info(
            f"\n🎯 Final prediction for Image={image_prediction}, Text={text_prediction}:"
        )
        # print(json.dumps(result, indent=2)) # Commented out verbose output

        return result
|
| 533 |
+
|
| 534 |
+
def get_response(self, prediction: LocationPrediction) -> LocationPrediction:
|
| 535 |
+
"""
|
| 536 |
+
Convert image references in the prediction to base64 strings.
|
| 537 |
+
"""
|
| 538 |
+
for evidence in prediction.evidence:
|
| 539 |
+
for i, ref in enumerate(evidence.references):
|
| 540 |
+
if ref.startswith("image"):
|
| 541 |
+
evidence.references[i] = image_to_base64(self.image_dir / ref)
|
| 542 |
+
return prediction
|
| 543 |
+
|
| 544 |
+
def get_transcript(self) -> str:
|
| 545 |
+
"""
|
| 546 |
+
Get the transcript from the transcript files in the audio directory.
|
| 547 |
+
"""
|
| 548 |
+
transcript = ""
|
| 549 |
+
for transcript_file in self.audio_dir.glob("*.txt"):
|
| 550 |
+
with open(transcript_file, "r", encoding="utf-8") as f:
|
| 551 |
+
logger.info(f"Reading transcript from {transcript_file.name}")
|
| 552 |
+
transcript_data = f.read().strip()
|
| 553 |
+
if transcript_data:
|
| 554 |
+
transcript += f"Transcript for {transcript_file.name}\n"
|
| 555 |
+
transcript += transcript_data
|
| 556 |
+
return transcript
|
| 557 |
+
|
| 558 |
+
def clear_directories(self):
|
| 559 |
+
"""
|
| 560 |
+
Clear the input and prompt directories.
|
| 561 |
+
"""
|
| 562 |
+
delete_dirs = [self.input_dir, self.prompt_dir]
|
| 563 |
+
for dir_path in delete_dirs:
|
| 564 |
+
if os.path.exists(dir_path):
|
| 565 |
+
shutil.rmtree(dir_path)
|
| 566 |
+
logger.info(f"Deleted folder: {dir_path}")
|
| 567 |
+
else:
|
| 568 |
+
logger.info(f"Folder does not exist: {dir_path}")
|
src/prompt/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .factory import diversification_prompt as diversification_prompt
|
| 2 |
+
from .factory import location_prompt as location_prompt
|
| 3 |
+
from .factory import verification_prompt as verification_prompt
|
| 4 |
+
from .factory import Evidence as Evidence
|
| 5 |
+
from .factory import GPSPrediction as GPSPrediction
|
| 6 |
+
from .factory import LocationPrediction as LocationPrediction
|
src/prompt/factory.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
|
| 6 |
+
from .template import DIVERSIFICATION_PROMPT, LOCATION_PROMPT, VERIFICATION_PROMPT
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Evidence(BaseModel):
    # One piece of supporting evidence for a location prediction.
    analysis: str
    # Source references (URLs or local image filenames). Pydantic deep-copies
    # field defaults per instance, so the mutable [] default is safe here.
    references: list[str] = []
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LocationPrediction(BaseModel):
    # Final structured output of the geolocation pipeline.
    latitude: float
    longitude: float
    # Human-readable place name/description.
    location: str
    # Accumulated evidence from the diversification/location/verification steps.
    evidence: list[Evidence]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class GPSPrediction(BaseModel):
    # Coordinate-only prediction with a single flat analysis/reference set
    # (unlike LocationPrediction, which nests evidence).
    latitude: float
    longitude: float
    analysis: str
    references: list[str]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def rag_prompt(index_search_json: str, n_coords: int | None = None) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Creates a formatted string with GPS coordinates for similar and dissimilar images.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
candidates_gps (list[tuple]): List of (lat, lon) tuples for similar images.
|
| 34 |
+
reverse_gps (list[tuple]): List of (lat, lon) tuples for dissimilar images.
|
| 35 |
+
n_coords (int, optional): Number of coords to include from each list. Defaults to all.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
str: Formatted string with coordinates for reference.
|
| 39 |
+
"""
|
| 40 |
+
if not os.path.exists(index_search_json):
|
| 41 |
+
return ""
|
| 42 |
+
|
| 43 |
+
with open(index_search_json, "r", encoding="utf-8") as file:
|
| 44 |
+
data = json.load(file)
|
| 45 |
+
|
| 46 |
+
candidates_gps = data.get("candidates_gps", [])
|
| 47 |
+
reverse_gps = data.get("reverse_gps", [])
|
| 48 |
+
|
| 49 |
+
if n_coords is not None:
|
| 50 |
+
candidates_gps = candidates_gps[: min(n_coords, len(candidates_gps))]
|
| 51 |
+
reverse_gps = reverse_gps[: min(n_coords, len(reverse_gps))]
|
| 52 |
+
else:
|
| 53 |
+
candidates_gps = candidates_gps
|
| 54 |
+
reverse_gps = reverse_gps
|
| 55 |
+
|
| 56 |
+
candidates_str = (
|
| 57 |
+
"[" + ", ".join(f"[{lat}, {lon}]" for (lat, lon) in candidates_gps) + "]"
|
| 58 |
+
)
|
| 59 |
+
reverse_str = "[" + ", ".join(f"[{lat}, {lon}]" for (lat, lon) in reverse_gps) + "]"
|
| 60 |
+
return f"For your reference, these are coordinates of some similar images: {candidates_str}, and these are coordinates of some dissimilar images: {reverse_str}."
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def metadata_prompt(metadata_file_path: str) -> str:
    """
    Reads a metadata JSON file and returns a formatted string combining all fields.

    Args:
        metadata_file_path (str): Path to the metadata JSON file

    Returns:
        str: Formatted string with all metadata fields combined, or "" when the
            file is missing, empty, unreadable, or has no recognized fields.
    """
    if not metadata_file_path or not os.path.exists(metadata_file_path):
        return ""

    try:
        with open(metadata_file_path, "r", encoding="utf-8") as file:
            metadata = json.load(file)

        if not metadata:
            return ""

        # (json key, human-readable label) pairs, in presentation order —
        # replaces six copy-pasted if-blocks with one data-driven pass.
        fields = [
            ("location", "Location"),
            ("violence level", "Violence level"),
            ("title", "Title"),
            ("social media link", "Social media link"),
            ("description", "Description"),
            ("category", "Category"),
        ]
        metadata_parts = [
            f"{label}: {metadata[key]}" for key, label in fields if metadata.get(key)
        ]

        if not metadata_parts:
            return ""

        return "Metadata for the image is: " + ". ".join(metadata_parts) + "."

    except Exception:
        # Best-effort: any read/parse failure yields an empty prompt fragment.
        return ""
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def search_prompt(search_candidates: list[str], n_search: int | None = None) -> str:
|
| 114 |
+
"""
|
| 115 |
+
Formats search candidate links into a prompt string.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
search_candidates (list[str]): List of candidate URLs from image search
|
| 119 |
+
n_search (int): Number of results to include (default: 5)
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
str: Formatted string with candidate links, each on a new line
|
| 123 |
+
|
| 124 |
+
Example:
|
| 125 |
+
>>> candidates = search_prompt(["https://example1.com", "https://example2.com"], n_search=3)
|
| 126 |
+
>>> print(candidates)
|
| 127 |
+
Similar image can be found in those links:
|
| 128 |
+
https://example1.com
|
| 129 |
+
https://example2.com
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
if not search_candidates or not isinstance(search_candidates, list):
|
| 133 |
+
return ""
|
| 134 |
+
|
| 135 |
+
EXCLUDE_DOMAINS = [
|
| 136 |
+
"x.com",
|
| 137 |
+
"twitter.com",
|
| 138 |
+
"linkedin.com",
|
| 139 |
+
"bbc.com",
|
| 140 |
+
"bbc.co.uk",
|
| 141 |
+
"instagram.com",
|
| 142 |
+
"tiktok.com",
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
for domain in EXCLUDE_DOMAINS:
|
| 146 |
+
search_candidates = [url for url in search_candidates if domain not in url]
|
| 147 |
+
|
| 148 |
+
if n_search is not None:
|
| 149 |
+
search_candidates = search_candidates[: min(n_search, len(search_candidates))]
|
| 150 |
+
|
| 151 |
+
try:
|
| 152 |
+
prompt = "\n".join(search_candidates)
|
| 153 |
+
return prompt
|
| 154 |
+
|
| 155 |
+
except Exception:
|
| 156 |
+
return ""
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def image_search_prompt(image_search_json: str, n_search: int | None = None) -> str:
    """
    Read reverse-image-search results from a JSON file and build a prompt
    listing pages that contain matching images.

    Args:
        image_search_json (str): Path to a JSON file holding a list of result
            dicts keyed by "pages_with_matching_images", "full_matching_images",
            and/or "partial_matching_images".
        n_search (int, optional): Cap on the number of links passed through to
            search_prompt().

    Returns:
        str: Prompt string, or "" when no matches of any kind were found.
    """
    pages_with_matching_images = set()
    full_matching_images = set()
    partial_matching_images = set()

    with open(image_search_json, "r", encoding="utf-8") as file:
        data_list = json.load(file)
        for json_data in data_list:
            # NOTE(review): the elif chain collects only the FIRST present key
            # of each dict — confirm upstream result dicts never carry more
            # than one of these keys.
            if "pages_with_matching_images" in json_data:
                pages_with_matching_images.update(
                    json_data["pages_with_matching_images"]
                )
            elif "full_matching_images" in json_data:
                full_matching_images.update(json_data["full_matching_images"])
            elif "partial_matching_images" in json_data:
                partial_matching_images.update(json_data["partial_matching_images"])

    if (
        not pages_with_matching_images
        and not full_matching_images
        and not partial_matching_images
    ):
        return ""

    # Only page matches are emitted; full/partial image URL lists are
    # intentionally disabled below.
    prompt = "Those are pages with matching images:\n"
    prompt += search_prompt(list(pages_with_matching_images), n_search=n_search)
    # prompt += "\n\nThose are full matching images:\n"
    # prompt += search_prompt(list(full_matching_images), n_search=n_search)
    # prompt += "\n\nThose are partial matching images:\n"
    # prompt += search_prompt(list(partial_matching_images), n_search=n_search)

    return prompt
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def search_content_prompt(search_content_json: str) -> str:
    """
    Load fetched search-content entries from a JSON file and render them as a
    prompt fragment.

    Args:
        search_content_json (str): Path to the JSON file with search content.

    Returns:
        str: Pretty-printed JSON of the content list, or "" when the file is
            missing, unreadable, empty, or not a list.
    """
    if not os.path.exists(search_content_json):
        return ""

    try:
        with open(search_content_json, "r", encoding="utf-8") as fh:
            payload = json.load(fh)

        if not isinstance(payload, list) or not payload:
            return ""

        return json.dumps(payload, indent=2)

    except Exception:
        return ""
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def transcript_prompt(audio_dir: str) -> str:
    """
    Gather every .txt transcript in *audio_dir* into one prompt fragment.

    Args:
        audio_dir (str): Path to the audio directory containing transcript files.

    Returns:
        str: Combined transcript content formatted as a prompt, or "" when the
            directory is missing or holds no transcript text.
    """
    if not os.path.exists(audio_dir):
        return ""

    pieces = []
    for entry in os.listdir(audio_dir):
        if not entry.endswith(".txt"):
            continue
        with open(os.path.join(audio_dir, entry), "r", encoding="utf-8") as fh:
            pieces.append(fh.read().strip())

    joined = "\n".join(pieces)
    if not joined:
        return ""
    return f"This is the transcript of the video: {joined}"
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def combine_prompt_data(
    prompt_dir: str,
    n_search: int | None = None,
    n_coords: int | None = None,
    image_prediction: bool = True,
    text_prediction: bool = True,
) -> str:
    """
    Combine all available prompt fragments from *prompt_dir* into one string.

    Fragments, in order: RAG coordinates (only when n_coords is given),
    metadata, image-search content, text-search content, and audio transcripts.

    Args:
        prompt_dir (str): Directory holding index_search.json, metadata.json,
            image_search_content.json, text_search_content.json, and audio/.
        n_search (int, optional): NOTE(review): accepted but currently unused —
            no fragment builder below receives it; confirm intent.
        n_coords (int, optional): Number of coordinates to include in the RAG
            fragment; when None the RAG fragment is skipped entirely.
        image_prediction (bool): Include image-search-derived fragments.
        text_prediction (bool): Include metadata and text-search fragments.

    Returns:
        str: Non-empty fragments joined with blank lines.
    """

    prompt_parts = []

    # 1. RAG prompt (optional — only when a coordinate budget is given)
    if n_coords is not None:
        rag_text = rag_prompt(os.path.join(prompt_dir, "index_search.json"), n_coords)
        prompt_parts.append(rag_text)

    # 2. Metadata prompt
    if text_prediction:
        metadata_text = metadata_prompt(os.path.join(prompt_dir, "metadata.json"))
        if metadata_text:
            prompt_parts.append(metadata_text)

    # 3. Search prompt
    if image_prediction:
        image_search_text = search_content_prompt(
            os.path.join(prompt_dir, "image_search_content.json")
        )
        if image_search_text:
            prompt_parts.append(image_search_text)

    if text_prediction:
        search_content_text = search_content_prompt(
            os.path.join(prompt_dir, "text_search_content.json")
        )
        if search_content_text:
            prompt_parts.append(search_content_text)

    # 4. Transcript prompt (always attempted, regardless of flags)
    transcript_text = transcript_prompt(os.path.join(prompt_dir, "audio"))
    if transcript_text:
        prompt_parts.append(transcript_text)

    # Combine all parts with double newlines for readability
    combined_prompt = "\n\n".join(part for part in prompt_parts if part.strip())

    return combined_prompt
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def diversification_prompt(
    prompt_dir: str,
    n_search: int | None = None,
    n_coords: int | None = None,
    image_prediction: bool = True,
    text_prediction: bool = True,
) -> str:
    """
    Build the diversification-stage prompt: all prompt fragments gathered from
    *prompt_dir*, inserted into the DIVERSIFICATION_PROMPT template.

    Args:
        prompt_dir (str): Directory containing the prompt fragment files.
        n_search (int, optional): Forwarded to combine_prompt_data.
        n_coords (int, optional): Number of coordinates to include in RAG.
        image_prediction (bool): Include image-derived fragments.
        text_prediction (bool): Include text-derived fragments.

    Returns:
        str: The filled-in diversification prompt.
    """

    prompt_data = combine_prompt_data(
        prompt_dir,
        n_search=n_search,
        n_coords=n_coords,
        image_prediction=image_prediction,
        text_prediction=text_prediction,
    )

    prompt = DIVERSIFICATION_PROMPT.strip().format(prompt_data=prompt_data)

    return prompt
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def location_prompt(location: str) -> str:
    """
    Build the location-refinement prompt for a given location string.

    Args:
        location (str): The location to include in the prompt.

    Returns:
        str: The LOCATION_PROMPT template filled with *location*, or "" when
            no location is provided.
    """
    if not location:
        return ""
    return LOCATION_PROMPT.strip().format(location=location)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def verification_prompt(
    satellite_image_id: int,
    prediction: dict,
    prompt_dir: str,
    n_search: int | None = None,
    n_coords: int | None = None,
    image_prediction: bool = True,
    text_prediction: bool = True,
) -> str:
    """
    Build the verification-stage prompt combining the gathered evidence with a
    prior prediction to be double-checked.

    Args:
        satellite_image_id (int): Satellite image index; rendered zero-padded
            to three digits (e.g. 7 -> "007") in the template.
        prediction (dict): The prediction to verify (serialized into the
            prompt as indented JSON).
        prompt_dir (str): Directory containing the prompt fragment files.
        n_search (int, optional): Forwarded to combine_prompt_data.
        n_coords (int, optional): Number of coordinates to include in RAG.
        image_prediction (bool): Include image-derived fragments.
        text_prediction (bool): Include text-derived fragments.

    Returns:
        str: Formatted verification prompt string.
    """
    prompt_data = combine_prompt_data(
        prompt_dir,
        n_search=n_search,
        n_coords=n_coords,
        image_prediction=image_prediction,
        text_prediction=text_prediction,
    )

    prompt = VERIFICATION_PROMPT.strip().format(
        prompt_data=prompt_data,
        prediction=json.dumps(prediction, indent=2),
        satellite_image_id=f"{satellite_image_id:03d}",
    )

    return prompt
|
src/prompt/fetch/content_fetch.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Literal, TypedDict
|
| 6 |
+
|
| 7 |
+
from playwright.async_api import Page, async_playwright
|
| 8 |
+
|
| 9 |
+
READABILITY_JS_URL = "https://unpkg.com/@mozilla/readability@0.4.4/Readability.js"
|
| 10 |
+
logger = logging.getLogger("uvicorn.error")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PageText(TypedDict):
    # Result of fetching one page: the URL requested and the extracted text.
    url: str
    text: str


# Playwright navigation wait states accepted by page.goto(wait_until=...).
WaitUntil = Literal["load", "domcontentloaded", "networkidle", "commit"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
async def _inject_readability(page: Page) -> None:
    """
    Inject Mozilla's Readability.js into *page* and pre-instantiate a parser.

    No-op when the document root is not HTML (script injection would fail
    there). Afterwards `window.__readability__` holds a Readability instance
    built from a clone of the current DOM, ready for .parse().
    """
    is_html = await page.evaluate("() => document.documentElement.nodeName === 'HTML'")
    if not is_html:
        return

    await page.add_script_tag(url=READABILITY_JS_URL)
    await page.add_script_tag(
        content="window.__readability__ = new Readability(document.cloneNode(true));"
    )
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
async def _fetch_text(page: Page, url: str, wait_until: WaitUntil) -> str:
    """
    Extract readable text from *url* using an already-open Playwright page.

    Extraction strategies, in order: Readability.js article extraction,
    Twitter/X tweet-text selectors, then the raw <body> inner text.

    Args:
        page: Open Playwright page to navigate.
        url: Page URL to fetch.
        wait_until: Playwright navigation wait condition.

    Returns:
        str: Best-effort extracted text (may be the full body text).
    """
    await page.goto(url, wait_until=wait_until)
    await page.wait_for_timeout(1000)

    # Attempt Readability.js parsing first.
    # BUGFIX: catch Exception, not BaseException — catching BaseException in
    # async code swallows asyncio.CancelledError, so task cancellation was
    # silently ignored and the fallbacks ran anyway.
    try:
        await _inject_readability(page)
        readability_text = await page.evaluate(
            "() => window.__readability__.parse()?.textContent"
        )
        if readability_text:
            return readability_text.strip()
    except Exception:
        pass

    # Fallback: Twitter/X specific tweet-text selector
    try:
        tweet_text = await page.locator(
            "article div[data-testid='tweetText']"
        ).all_inner_texts()
        if tweet_text:
            return "\n".join(tweet_text)
    except Exception:
        pass

    # Final fallback: full body text
    return await page.evaluate("() => document.body.innerText")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
async def fetch_text(
    url: str, headless: bool = False, wait_until: WaitUntil = "load"
) -> PageText:
    """
    Fetch the readable text of a single URL in a fresh Chrome context.

    Args:
        url: Page to fetch.
        headless: Whether to run the browser headless.
        wait_until: Playwright navigation wait condition.

    Returns:
        PageText with the url and the extracted text.

    NOTE(review): the context is not closed if _fetch_text raises — consider
    try/finally so browser processes are not leaked on failure.
    """
    async with async_playwright() as pw:
        browser = await pw.chromium.launch_persistent_context(
            user_data_dir="",  # empty string => temporary profile
            channel="chrome",
            headless=headless,
            no_viewport=True,
        )
        page = await browser.new_page()
        text = await _fetch_text(page, url, wait_until)
        await browser.close()

    return PageText(url=url, text=text)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
async def fetch_texts(
    urls: list[str], headless: bool = False, wait_until: WaitUntil = "load"
) -> list[PageText | BaseException]:
    """
    Fetch readable text from many URLs concurrently in one Chrome context
    (one tab per URL).

    Args:
        urls: Pages to fetch.
        headless: NOTE(review): currently IGNORED — the launch below hard-codes
            headless=True; confirm whether the parameter should be honored.
        wait_until: Playwright navigation wait condition.

    Returns:
        One entry per URL, in order: a PageText on success, or the exception
        raised for that URL (via gather(return_exceptions=True)).
    """
    async with async_playwright() as pw:
        browser = await pw.chromium.launch_persistent_context(
            user_data_dir="",
            channel="chrome",
            headless=True,
            no_viewport=True,
        )
        # browser = await pw.chromium.launch_persistent_context(
        #     user_data_dir="/tmp/playwright_profile",
        #     headless=True,
        #     no_viewport=True,
        # )
        pages = [await browser.new_page() for _ in urls]

        tasks = [_fetch_text(page, url, wait_until) for page, url in zip(pages, urls)]
        results_raw = await asyncio.gather(*tasks, return_exceptions=True)
        await browser.close()

    results: list[PageText | BaseException] = []
    for url, result in zip(urls, results_raw):
        if isinstance(result, BaseException):
            results.append(result)
        else:
            results.append(PageText(url=url, text=result))

    return results
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
async def fetch_links_to_json(
    links: list[str],
    output_path: str,
    headless: bool = False,
    wait_until: WaitUntil = "load",
    max_content_length: int = 5000,
) -> None:
    """
    Fetch content from a list of links and save to a JSON file.

    Args:
        links: List of URLs to fetch content from
        output_path: Path where the JSON file will be saved
        headless: Whether to run browser in headless mode
        wait_until: When to consider page loading complete
        max_content_length: Maximum number of characters to keep from each page content

    Returns:
        None (saves results to JSON file)
    """
    # Single source of truth for the failure marker — keeps the summary
    # counting below in sync with what is actually written to the file.
    failure_message = "Fail to fetch content..."

    logger.info(f"📥 Fetching content from {len(links)} links...")

    # Fetch content from all links
    results = await fetch_texts(links, headless=headless, wait_until=wait_until)

    # Process results into the desired format
    json_data = []
    for i, (link, result) in enumerate(zip(links, results)):
        logger.info(f" Processing {i + 1}/{len(links)}: {link}")

        if isinstance(result, BaseException):
            # Handle errors gracefully
            json_data.append({"link": link, "content": failure_message})
        else:
            # Successfully fetched content - apply length limit
            content = result["text"]
            if len(content) > max_content_length:
                content = (
                    content[:max_content_length]
                    + "... [content truncated due to length limit]"
                )
                logger.info(
                    f"✂️ Content truncated from {len(result['text'])} to {max_content_length} characters"
                )

            json_data.append({"link": link, "content": content})

    # Ensure output directory exists
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Save to JSON file
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)

    logger.info(f"💾 Saved content from {len(links)} links to {output_path}")

    # BUGFIX: the summary previously tested startswith("Error fetching"), a
    # string this function never produces, so `failed` was always reported 0.
    successful = sum(1 for item in json_data if item["content"] != failure_message)
    failed = len(json_data) - successful
    logger.info(f"📊 Summary: {successful} successful, {failed} failed")
|
src/prompt/fetch/satellite_fetch.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
from geopy import Point
|
| 5 |
+
from geopy.distance import distance
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger("uvicorn.error")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def meter_offsets(lat: float, lon: float, extend: float) -> tuple[float, float]:
    """Convert a radial distance in meters into (lat, lon) degree offsets.

    The center point is displaced geodesically `extend` meters due north and
    due east, and the coordinate differences give the per-axis offsets.
    """
    center = Point(lat, lon)
    # Bearing 0° moves north, bearing 90° moves east.
    northward = distance(meters=extend).destination(center, bearing=0)
    eastward = distance(meters=extend).destination(center, bearing=90)
    return northward.latitude - lat, eastward.longitude - lon
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def fetch_satellite_image(
    lat: float, lon: float, extend: float, output_path: str = "esri_sat.png"
) -> None:
    """Download a satellite PNG around (lat, lon) from Esri World Imagery.

    Parameters:
    - lat: Latitude of the center point (decimal degrees).
    - lon: Longitude of the center point (decimal degrees).
    - extend: Buffer distance from center in meters (radius).
    - output_path: File path to save the resulting PNG.

    Sweeps image sizes 1024, 512, ... down to 128 pixels, and repeats the
    whole sweep up to 3 times before giving up.
    """
    # Degree offsets for the requested radius, via geodesic displacement.
    lat_off, lon_off = meter_offsets(lat, lon, extend)
    # Bounding box as "minx,miny,maxx,maxy" in lon/lat order (EPSG:4326).
    bbox = f"{lon - lon_off},{lat - lat_off},{lon + lon_off},{lat + lat_off}"

    base_url = (
        "https://server.arcgisonline.com/ArcGIS/rest/services/"
        "World_Imagery/MapServer/export"
    )

    for attempt in range(3):
        logger.info(f"Attempt {attempt + 1}/3 to fetch satellite image...")

        size = 1024
        while size >= 128:
            query = {
                "bbox": bbox,
                "bboxSR": "4326",
                "size": f"{size},{size}",
                "format": "png",
                "f": "image",
            }
            try:
                resp = httpx.get(base_url, params=query, timeout=30.0)
                if resp.status_code == 200:
                    with open(output_path, "wb") as fh:
                        fh.write(resp.content)
                    logger.info(f"Saved Esri image to {output_path} ({size}x{size})")
                    return
                logger.info(
                    f"Failed at size {size} (status {resp.status_code}), trying {size // 2}"
                )
            except Exception as exc:
                logger.error(f"Network error at size {size}: {exc}, trying {size // 2}")
            # Halve the requested resolution and try again.
            size //= 2

        if attempt < 2:  # Don't print this message on the last attempt
            logger.info(f"Attempt {attempt + 1} failed for all sizes, retrying...")

    logger.warning("Unable to fetch Esri imagery: all retry attempts failed.")
|
src/prompt/preprocess/keyframe_extract.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from google.cloud import videointelligence_v1 as vi
|
| 7 |
+
from scipy.spatial.distance import cdist
|
| 8 |
+
from sklearn.metrics import silhouette_score
|
| 9 |
+
|
| 10 |
+
# Set up logger
|
| 11 |
+
logger = logging.getLogger("uvicorn.error")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def detect_shot_intervals_local(video_path: str) -> list[tuple[float, float]]:
    """Run Google Video Intelligence shot-change detection on a local file.

    Returns one (start_seconds, end_seconds) tuple per detected shot; an
    empty list when the API response carries no annotation results.
    """
    logger.info(f"Detecting shot intervals for video: {video_path}")
    client = vi.VideoIntelligenceServiceClient()
    with open(video_path, "rb") as fh:
        video_bytes = fh.read()

    operation = client.annotate_video(
        request={
            "input_content": video_bytes,
            "features": [vi.Feature.SHOT_CHANGE_DETECTION],
        }
    )
    # Block for up to 5 minutes on the long-running operation.
    response = operation.result(timeout=300)
    if not response or not response.annotation_results:
        logger.error("No annotation_results found in video intelligence response.")
        return []

    shots = response.annotation_results[0].shot_annotations
    # Offsets arrive as (seconds, microseconds); fold into float seconds.
    intervals = [
        (
            s.start_time_offset.seconds + s.start_time_offset.microseconds / 1e6,
            s.end_time_offset.seconds + s.end_time_offset.microseconds / 1e6,
        )
        for s in shots
    ]
    logger.info(f"Detected {len(intervals)} shot intervals.")
    return intervals
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def color_histogram(img: np.ndarray) -> np.ndarray:
    """Flattened, normalized 8x8x8 HSV color histogram of a BGR image (512 values)."""
    as_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    channels = [0, 1, 2]
    bins = [8, 8, 8]
    ranges = [0, 180, 0, 256, 0, 256]  # H spans 0-180 in OpenCV, S/V span 0-256
    hist = cv2.calcHist([as_hsv], channels, None, bins, ranges)
    return cv2.normalize(hist, hist).flatten()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def sample_frames_per_shot(
    video_path: str, start: float, end: float, step: float = 1.0
) -> list[np.ndarray]:
    """Grab one frame every `step` seconds from `start` (inclusive) up to `end`.

    Stops early if a frame cannot be decoded at a requested timestamp.
    """
    capture = cv2.VideoCapture(video_path)
    sampled = []
    t = start
    while t < end:
        # Seek by timestamp (milliseconds) rather than by frame number.
        capture.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
        ok, img = capture.read()
        if not ok:
            logger.warning(f"Failed to read frame at {t:.2f}s")
            break
        sampled.append(img)
        t += step
    capture.release()
    return sampled
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def kmeans_init(features: np.ndarray):
|
| 69 |
+
n, _ = features.shape
|
| 70 |
+
k = int(np.sqrt(n)) or 1
|
| 71 |
+
idx = np.random.choice(n, k, replace=False)
|
| 72 |
+
centers = features[idx]
|
| 73 |
+
clusters = np.argmin(cdist(features, centers), axis=1)
|
| 74 |
+
return clusters, centers
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def kmeans_silhouette(features: np.ndarray):
    """Agglomerative refinement of a random center initialization.

    Starts from the random assignment produced by `kmeans_init`, then
    repeatedly merges the two closest centers, scoring each intermediate
    clustering with the silhouette coefficient and remembering the best one.

    Returns:
        (best_k, best_centers, center_indices): the best cluster count seen,
        its center vectors, and the row indices into `features` whose rows
        exactly equal those centers.
    """
    # Initial cluster-count heuristic: sqrt(n), but at least 2.
    k = max(int(np.sqrt(len(features))), 2)
    best_k, best_score = k, -1
    clusters, centers = kmeans_init(features)
    # NOTE(review): kmeans_init chooses its own k (sqrt(n) or 1), which can
    # disagree with the `k` computed above for very small inputs — confirm.
    best_centers = centers.copy()
    while k > 2:
        # Locate the two closest centers; mask self-distances on the diagonal.
        d = cdist(centers, centers)
        np.fill_diagonal(d, np.inf)
        i, j = np.unravel_index(np.argmin(d), d.shape)
        # Merge cluster j into i, then shift labels above j down by one.
        clusters = np.where(clusters == j, i, clusters)
        clusters = np.where(clusters > j, clusters - 1, clusters)
        new_centers = []
        for cid in range(k - 1):
            cluster_feats = features[clusters == cid]
            if cluster_feats.size == 0:
                # NOTE(review): skipping an empty cluster shortens new_centers
                # without relabeling `clusters`, so labels and center rows can
                # drift out of sync — verify this case cannot occur in practice.
                continue
            # New center = the member closest to the cluster mean (a medoid).
            mean_vec = np.mean(cluster_feats, axis=0)
            idx_close = np.argmin(np.linalg.norm(cluster_feats - mean_vec, axis=1))
            new_centers.append(cluster_feats[idx_close])
        centers = new_centers
        k -= 1
        # Silhouette is only defined when at least two labels remain.
        if len(np.unique(clusters)) > 1:
            score = silhouette_score(features, clusters)
            if score > best_score:
                best_score, best_k = score, k
                best_centers = centers.copy()
    # Map each kept center vector back to the index of an identical feature row.
    center_indices = []
    for c in best_centers:
        matches = np.where((features == c).all(axis=1))[0]
        if matches.size > 0:
            center_indices.append(int(matches[0]))
    # logger.info(f"KMeans silhouette: best_k={best_k}, best_score={best_score:.4f}")
    return best_k, best_centers, center_indices
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def redundancy_filter(
    video_path: str, indices: list[int], threshold: float
) -> list[int]:
    """Drop frames whose HSV histogram is too similar to any earlier frame.

    Two frames are redundant when the cosine similarity of their color
    histograms exceeds `threshold`; the earlier one wins.
    """
    # Decode each requested frame and compute its histogram.
    capture = cv2.VideoCapture(video_path)
    histograms = []
    for frame_index in indices:
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ok, img = capture.read()
        if ok:
            histograms.append(color_histogram(img))
    capture.release()

    # NOTE(review): if any frame fails to decode, `histograms` is shorter than
    # `indices` and positions shift — confirm unreadable frames cannot occur.
    kept = []
    for pos, hist in enumerate(histograms):
        duplicate = any(
            np.dot(hist, prev) / (np.linalg.norm(hist) * np.linalg.norm(prev))
            > threshold
            for prev in histograms[:pos]
        )
        if not duplicate:
            kept.append(indices[pos])
    return kept
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def extract_and_save_keyframes(
    video_path: str,
    output_dir: str,
    start_index: int = 0,
    step: float = 1.0,
    threshold: float = 0.7,
    k_min: int = 2,
    k_max: int = 8,
) -> int:
    """Extract representative keyframes from a video and save them as JPEGs.

    Pipeline: detect shots (Google Video Intelligence), sample frames within
    each shot every `step` seconds, cluster their HSV histograms to pick
    candidates, drop near-duplicates above `threshold` cosine similarity,
    and write survivors to `output_dir` as image_NNN.jpg.

    Args:
        video_path: Input video file.
        output_dir: Directory for saved keyframes (created if missing).
        start_index: First number used in the image_NNN.jpg filenames.
        step: Sampling interval within a shot, in seconds.
        threshold: Cosine-similarity cutoff for the redundancy filter.
        k_min: Threshold compared against the feature matrix size below.
        k_max: Unused in this implementation.  # NOTE(review): confirm whether
            an upper cluster-count cap was intended.

    Returns:
        The next unused output index (start_index + number of frames saved).
    """
    logger.info(f"Starting keyframe extraction for {video_path}")
    os.makedirs(output_dir, exist_ok=True)

    # Get FPS to convert seconds to frame indices
    cap_meta = cv2.VideoCapture(video_path)
    video_fps = cap_meta.get(cv2.CAP_PROP_FPS) or 1.0  # guard against FPS == 0
    cap_meta.release()

    intervals = detect_shot_intervals_local(video_path)
    cap = cv2.VideoCapture(video_path)
    output_idx = start_index

    for shot_idx, (start, end) in enumerate(intervals):
        # logger.info(
        #     f"Processing shot {shot_idx + 1}/{len(intervals)}: {start:.2f}s to {end:.2f}s"
        # )

        # Sample frames & extract features
        frames = sample_frames_per_shot(video_path, start, end, step)
        feats = (
            np.vstack([color_histogram(f) for f in frames])
            if frames
            else np.empty((0,))
        )

        # Determine intra-shot keyframe indices
        # NOTE(review): feats.size counts all histogram elements (rows * 512),
        # not the number of sampled frames — confirm the k_min comparison is
        # intended to operate on element count.
        if feats.size < k_min or feats.ndim == 1:
            idxs = list(range(len(frames)))
        else:
            _, centers, cidxs = kmeans_silhouette(feats)
            idxs = cidxs

        # Map to global frame numbers and dedupe
        # NOTE(review): `i` indexes frames sampled every `step` seconds but is
        # added as a raw frame offset; for step != 1/fps the mapping may be
        # off — verify against sample_frames_per_shot.
        global_idxs = [int(start * video_fps) + i for i in idxs]
        filtered = redundancy_filter(video_path, global_idxs, threshold)

        # Save each keyframe sequentially into output_dir
        for frame_no in filtered:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
            ret, frame = cap.read()
            if not ret:
                continue
            out_path = os.path.join(output_dir, f"image_{output_idx:03d}.jpg")
            cv2.imwrite(out_path, frame)
            output_idx += 1
        logger.info(
            f"Shot {shot_idx + 1}: saved {len(filtered)} keyframes. Total so far: {output_idx}"
        )

    cap.release()
    logger.info(f"Extraction complete. Total frames saved: {output_idx}")
    return output_idx
|
src/prompt/preprocess/video_transcribe.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
import whisper
|
| 13 |
+
from moviepy import *
|
| 14 |
+
from open_clip import create_model_and_transforms
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger("uvicorn.error")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def extract_audio(video_path: str, output_dir: str) -> str:
    """
    Extract the audio track of a video into a WAV file.

    Args:
        video_path (str): Path to input video file.
        output_dir (str): Directory to save the audio file.

    Returns:
        str: Path to the saved audio file, or "" if extraction failed
        (including videos with no audio track).
    """
    target = Path(output_dir) / f"{Path(video_path).stem}.wav"

    try:
        clip = VideoFileClip(str(video_path))
        track = clip.audio

        # Silent videos have no audio object; skip writing in that case.
        if track is not None:
            track.write_audiofile(str(target), logger="bar")
            track.close()
        clip.close()

        if not target.exists():
            raise RuntimeError("Audio file was not created")

        return str(target)

    except Exception as e:
        logger.error(f"Error extracting audio: {str(e)}")
        return ""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def transcribe_audio(audio_path: str, model_name: str = "base") -> str:
    """
    Transcribe an audio file with OpenAI Whisper.

    Args:
        audio_path (str): Path to the audio file.
        model_name (str): Whisper model name.

    Returns:
        str: Stripped transcription text.

    Raises:
        RuntimeError: If model loading or transcription fails.
    """
    try:
        whisper_model = whisper.load_model(model_name)
        # fp16=False keeps CPU-only inference from warning/failing.
        output = whisper_model.transcribe(str(audio_path), fp16=False, verbose=False)
        return str(output.get("text", "")).strip()

    except Exception as e:
        raise RuntimeError(f"Error transcribing audio: {str(e)}")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def transcribe_video(
    video_path: str,
    output_dir: str = "g3/data/prompt_data/audio",
    model_name: str = "base",
):
    """
    Transcribe a video: extract its audio, run Whisper, write the transcript.

    The transcript is written to ``<output_dir>/<video_stem>_transcript.txt``.
    Returns None; aborts early (with an error log) if audio extraction fails.

    Args:
        video_path (str): Path to the video file.
        output_dir (str): Directory for the intermediate audio and transcript.
        model_name (str): Whisper model name.
    """
    audio_path = extract_audio(video_path, output_dir)
    if not audio_path:
        logger.error("Audio extraction failed. No audio file created.")
        return

    logger.info(f"Audio extracted to: {audio_path}")
    transcript_text = transcribe_audio(audio_path, model_name=model_name)

    transcript_path = Path(output_dir) / f"{Path(video_path).stem}_transcript.txt"
    with open(transcript_path, "w", encoding="utf-8") as fh:
        fh.write(transcript_text)
    logger.info(f"Transcript saved to: {transcript_path}")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def transcribe_video_directory(
    video_dir: str,
    output_dir: str = "g3/data/prompt_data/audio",
    model_name: str = "base",
):
    """
    Transcribe every video file found directly inside a directory.

    Args:
        video_dir (str): Directory containing video files.
        output_dir (str): Directory for audio and transcript files.
        model_name (str): Whisper model name.

    Returns:
        None
    """
    allowed_suffixes = {".mp4", ".avi", ".mov", ".mkv"}
    os.makedirs(output_dir, exist_ok=True)

    videos = [
        entry
        for entry in Path(video_dir).glob("*")
        if entry.is_file() and entry.suffix.lower() in allowed_suffixes
    ]

    if not videos:
        logger.info(f"No video files found in directory: {video_dir}")

    for video_file in videos:
        logger.info(f"Processing video: {video_file}")
        transcribe_video(str(video_file), output_dir, model_name=model_name)
|
src/prompt/search/image_search.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from threading import Lock
|
| 9 |
+
|
| 10 |
+
import requests
|
| 11 |
+
from google.cloud import vision
|
| 12 |
+
from requests.adapters import HTTPAdapter
|
| 13 |
+
from urllib3.util.retry import Retry
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("uvicorn.error")
|
| 16 |
+
|
| 17 |
+
# GOOGLE CLOUD VISION API
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def annotate(path: str) -> vision.WebDetection:
    """Return Google Vision web-detection annotations for one image.

    Args:
        path: Local file path, or an http(s)/gs: image URI.

    Returns:
        A WebDetection object describing where the image appears on the web.
    """
    client = vision.ImageAnnotatorClient()

    if path.startswith(("http", "gs:")):
        # Remote image: let the API fetch it by URI.
        img = vision.Image()
        img.source.image_uri = path
    else:
        # Local image: send the raw bytes.
        with open(path, "rb") as fh:
            img = vision.Image(content=fh.read())

    reply = client.annotate_image(
        {
            "image": img,
            "features": [{"type_": vision.Feature.Type.WEB_DETECTION}],
        }
    )
    return reply.web_detection
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def annotate_directory(directory: str) -> list[vision.WebDetection]:
    """
    Perform web detection on all image files in the given directory in batches of 16.

    Order is preserved within each batch; images that fail to read or whose
    response lacks a web_detection are silently dropped from the result.

    Args:
        directory (str): Path to the directory containing image files.

    Returns:
        list[vision.WebDetection]: List of WebDetection objects for each image.
    """
    client = vision.ImageAnnotatorClient()

    # Collect all image files first
    image_files = []
    for file_name in os.listdir(directory):
        file_path = os.path.join(directory, file_name)
        if os.path.isfile(file_path) and file_name.lower().endswith(
            (".jpg", ".jpeg", ".png", ".bmp", ".gif")
        ):
            image_files.append(file_path)

    all_web_detections = []
    batch_size = 16  # Google Vision API batch limit

    # Process images in batches of 16
    for i in range(0, len(image_files), batch_size):
        batch_files = image_files[i : i + batch_size]
        logger.info(
            f"Processing batch {i // batch_size + 1}/{(len(image_files) + batch_size - 1) // batch_size} ({len(batch_files)} images)..."
        )

        # Prepare batch requests
        image_requests = []
        for file_path in batch_files:
            try:
                with open(file_path, "rb") as image_file:
                    content = image_file.read()
                image = vision.Image(content=content)
                image_requests.append(image)
            except Exception as e:
                logger.warning(f"⚠️ Failed to read image {file_path}: {e}")
                # Add a placeholder to maintain order
                image_requests.append(None)

        # Filter out None values and keep track of valid indices
        # valid_indices maps each request back to its slot in batch_files.
        valid_requests = []
        valid_indices = []
        for idx, request in enumerate(image_requests):
            if request is not None:
                valid_requests.append(request)
                valid_indices.append(idx)

        if not valid_requests:
            logger.warning(f"⚠️ No valid images in batch {i // batch_size + 1}")
            continue

        try:
            # Make batch API call
            responses = client.batch_annotate_images(
                requests=[
                    vision.AnnotateImageRequest(
                        image=image,
                        features=[
                            vision.Feature(type=vision.Feature.Type.WEB_DETECTION)
                        ],
                    )
                    for image in valid_requests
                ]
            ).responses

            # Process responses and maintain order
            # Responses come back in request order; scatter them back to the
            # original batch slots via valid_indices.
            batch_detections: list[vision.WebDetection | None] = [None] * len(
                batch_files
            )
            for response_idx, global_idx in enumerate(valid_indices):
                if (
                    response_idx < len(responses)
                    and responses[response_idx].web_detection
                ):
                    batch_detections[global_idx] = responses[response_idx].web_detection

            # Add to results (filter out None values)
            all_web_detections.extend(
                [det for det in batch_detections if det is not None]
            )

        except Exception as e:
            # A failed batch is skipped entirely; its images get no detections.
            logger.warning(f"⚠️ Batch {i // batch_size + 1} failed: {e}")
            continue

    logger.info(
        f"✅ Successfully processed {len(all_web_detections)} images out of {len(image_files)} total"
    )
    return all_web_detections
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def parse_web_detection(annotations: vision.WebDetection) -> dict:
    """Returns detected features in the provided web annotations as a dict."""
    # Each `or []` keeps empty/absent repeated fields from being iterated.
    pages = [page.url for page in (annotations.pages_with_matching_images or [])]
    full = [img.url for img in (annotations.full_matching_images or [])]
    partial = [img.url for img in (annotations.partial_matching_images or [])]
    entities = [
        {"score": entity.score, "description": entity.description}
        for entity in (annotations.web_entities or [])
    ]
    return {
        "pages_with_matching_images": pages,
        "full_matching_images": full,
        "partial_matching_images": partial,
        "web_entities": entities,
    }
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_image_links_vision(annotations: vision.WebDetection) -> list[str]:
    """Extracts image links from web detection annotations."""
    # Preference order: pages hosting the image, then full matches,
    # then partial matches; the first non-empty source wins.
    for source in (
        annotations.pages_with_matching_images,
        annotations.full_matching_images,
        annotations.partial_matching_images,
    ):
        if source:
            return [item.url for item in source]
    return []
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# SCRAPING DOG API
|
| 190 |
+
def upload_image_to_imgbb(image_path: str, api_key: str) -> str:
    """Upload image to imgbb with automatic retry on transient errors."""

    # Read and base64-encode the image payload.
    try:
        with open(image_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("utf-8")
    except FileNotFoundError:
        raise FileNotFoundError(f"Image file not found: {image_path}")
    except Exception as e:
        raise Exception(f"Error reading image file: {e}")

    imgbb_url = "https://api.imgbb.com/1/upload"
    payload = {"key": api_key, "image": encoded}

    # Session that retries transient HTTP failures with exponential backoff.
    session = requests.Session()
    retries = Retry(
        total=5,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["POST"],
        raise_on_status=False,
    )
    for scheme in ("https://", "http://"):
        session.mount(scheme, HTTPAdapter(max_retries=retries))

    try:
        resp = session.post(imgbb_url, data=payload, timeout=30)
        resp.raise_for_status()
        body = resp.json()
    except requests.exceptions.RequestException as e:
        raise Exception(f"Failed to upload after retries: {e}")

    if body.get("success"):
        return body["data"]["url"]
    raise Exception(f"imgbb upload failed: {body.get('error', 'Unknown error')}")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def search_with_scrapingdog_lens(
    image_path: str, imgbb_key: str, scrapingdog_key: str
) -> dict:
    """
    Uploads an image to imgbb, then queries ScrapingDog's Google Lens API with 3 retries.

    Never raises: on any failure returns {"lens_results": []}.
    """
    try:
        hosted_url = upload_image_to_imgbb(image_path, imgbb_key)
        logger.info(f"Image uploaded to ImgBB: {hosted_url}")

        query = {
            "api_key": scrapingdog_key,
            "url": f"https://lens.google.com/uploadbyurl?url={hosted_url}",
            "visual_matches": "true",
            "exact_matches": "true",
        }

        # Up to 3 attempts against the Lens endpoint.
        for attempt in range(3):
            try:
                resp = requests.get(
                    "https://api.scrapingdog.com/google_lens", params=query, timeout=60
                )
                resp.raise_for_status()
                return resp.json()
            except requests.exceptions.RequestException as e:
                logger.warning(
                    f"⚠️ ScrapingDog attempt {attempt + 1}/3 failed for {os.path.basename(image_path)}: {e}"
                )
                if attempt < 2:  # Don't sleep on the last attempt
                    time.sleep(2)

        # All retries failed
        logger.error(
            f"❌ All 3 ScrapingDog attempts failed for {os.path.basename(image_path)}"
        )
        return {"lens_results": []}

    except Exception as e:
        logger.warning(f"⚠️ ScrapingDog API unexpected error for {image_path}: {e}")
        return {"lens_results": []}
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def get_image_links_scrapingdog(search_results: dict, n_results: int = 5) -> list[str]:
    """Extract up to ``n_results`` page links from a ScrapingDog Lens response.

    Args:
        search_results: Parsed JSON response from the ScrapingDog Google
            Lens API (expects a "lens_results" list of dicts).
        n_results: Maximum number of links to return.

    Returns:
        At most ``n_results`` link strings. Entries missing a "link" field
        are skipped instead of raising ``KeyError`` (the API does not
        guarantee the field on every result).
    """
    links = [
        result["link"]
        for result in search_results.get("lens_results", [])
        if result.get("link")
    ]
    return links[:n_results]
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def process_scrapingdog_only(image_path: str) -> list[str]:
    """Run only the ScrapingDog reverse-image pipeline for one image.

    Reads ``IMGBB_API_KEY`` and ``SCRAPINGDOG_API_KEY`` from the
    environment, performs the lens search, and returns up to five result
    links. Any failure is logged and yields an empty list.
    """
    try:
        search_response = search_with_scrapingdog_lens(
            image_path=image_path,
            imgbb_key=os.environ["IMGBB_API_KEY"],
            scrapingdog_key=os.environ["SCRAPINGDOG_API_KEY"],
        )
        links = get_image_links_scrapingdog(search_response, n_results=5)

        with print_lock:
            logger.info(
                f"✅ ScrapingDog completed for {os.path.basename(image_path)} - {len(links)} links"
            )
        return links
    except Exception as e:
        with print_lock:
            logger.error(
                f"❌ ScrapingDog error for {os.path.basename(image_path)}: {e}"
            )
        return []
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Module-level lock shared by the worker threads below so that multi-line
# log output from concurrent tasks does not interleave.
print_lock = Lock()
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def process_single_image(image_path: str, imgbb_key: str, scrapingdog_key: str) -> dict:
    """
    Process a single image with both Vision API and ScrapingDog API.

    Args:
        image_path: Path to the image file
        imgbb_key: ImgBB API key
        scrapingdog_key: ScrapingDog API key

    Returns:
        Dict with keys "image_path", "vision_result" and
        "scrapingdog_result"; on failure the link lists are empty and an
        "error" key carries the exception text.
    """
    name = os.path.basename(image_path)
    try:
        # First source: Google Vision web detection.
        vision_links = get_image_links_vision(annotate(image_path))

        # Second source: ScrapingDog Google-Lens lookup (via ImgBB upload).
        lens_response = search_with_scrapingdog_lens(
            image_path=image_path, imgbb_key=imgbb_key, scrapingdog_key=scrapingdog_key
        )
        lens_links = get_image_links_scrapingdog(lens_response, n_results=5)

        with print_lock:
            logger.info(f"✅ Completed processing {name}")

        return {
            "image_path": name,
            "vision_result": vision_links,
            "scrapingdog_result": lens_links,
        }
    except Exception as e:
        with print_lock:
            logger.error(f"❌ Error processing {name}: {e}")
        return {
            "image_path": name,
            "vision_result": [],
            "scrapingdog_result": [],
            "error": str(e),
        }
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def image_search_directory(
    directory: str,
    output_dir: str = "g3/data/prompt_data",
    filename: str = "metadata.json",
    imgbb_key: str = "YOUR_IMGBB_API_KEY",
    scrapingdog_key: str = "YOUR_SCRAPINGDOG_API_KEY",
    max_workers: int = 4,
    target_links: int = 20,
) -> None:
    """
    Perform web detection with a two-phase approach:
    1. Run Vision API on all images first using annotate_directory
    2. If total unique links < target_links, run ScrapingDog on images until target is reached

    Args:
        directory (str): Path to the directory containing image files.
        output_dir (str): Directory to save the JSON output.
        filename (str): Name of the JSON file to save the results.
        imgbb_key (str): ImgBB API key for image uploading.
        scrapingdog_key (str): ScrapingDog API key for lens search.
        max_workers (int): Maximum number of parallel workers.
        target_links (int): Target number of unique links to collect.

    Returns:
        None. Results are written to ``output_dir/filename`` as JSON.
    """
    # Links from these domains are never useful as evidence pages.
    EXCLUDE_DOMAIN = [
        "youtube.com",
    ]

    # Collect all image files in the directory (non-recursive).
    image_files = []
    for file_name in os.listdir(directory):
        file_path = os.path.join(directory, file_name)
        if os.path.isfile(file_path) and file_name.lower().endswith(
            (".jpg", ".jpeg", ".png", ".bmp", ".gif")
        ):
            image_files.append(file_path)

    if not image_files:
        logger.info("No image files found in the directory.")
        return

    logger.info(
        f"Found {len(image_files)} image files. Target: {target_links} unique links"
    )

    # Phase 1: Run Vision API on all images using annotate_directory.
    logger.info("🔍 Phase 1: Running Vision API on all images...")
    all_links = set()
    vision_links_count = 0

    try:
        web_detections = annotate_directory(directory)

        # Extract and domain-filter links from all web detections; the set
        # deduplicates across images automatically.
        for detection in web_detections:
            links = get_image_links_vision(detection)
            links = [
                link for link in links if not any(domain in link for domain in EXCLUDE_DOMAIN)
            ]
            all_links.update(links)

        vision_links_count = len(all_links)
        logger.info(
            f"✅ Phase 1 complete: {vision_links_count} unique links from Vision API"
        )

    except Exception as e:
        logger.error(f"❌ Vision API processing failed: {e}")
        all_links = set()
        vision_links_count = 0

    # Phase 2: Run ScrapingDog only when Vision alone did not reach target.
    scrapingdog_links_count = 0

    if len(all_links) < target_links:
        needed_links = target_links - len(all_links)
        logger.info(
            f"🔍 Phase 2: Need {needed_links} more links. Running ScrapingDog..."
        )

        # Skip Phase 2 entirely when the caller left the placeholder keys.
        if (
            imgbb_key == "YOUR_IMGBB_API_KEY"
            or scrapingdog_key == "YOUR_SCRAPINGDOG_API_KEY"
        ):
            logger.warning("⚠️ ScrapingDog API keys not available. Skipping Phase 2.")
        else:
            scrapingdog_completed = 0

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_image = {
                    executor.submit(process_scrapingdog_only, image_path): image_path
                    for image_path in image_files
                }

                # Collect ScrapingDog results until we have enough links.
                for future in as_completed(future_to_image):
                    image_path = future_to_image[future]
                    try:
                        result_links = future.result()
                        filtered_links = [
                            link for link in result_links if not any(domain in link for domain in EXCLUDE_DOMAIN)
                        ]
                        # BUG FIX: snapshot the set size *before* the update.
                        # The previous code used len(filtered_links) as the
                        # baseline, so the "new unique links" counter was
                        # meaningless (could even go negative).
                        before_update = len(all_links)
                        all_links.update(filtered_links)
                        scrapingdog_links_count += len(all_links) - before_update
                        scrapingdog_completed += 1

                        with print_lock:
                            logger.info(
                                f"ScrapingDog Progress: {scrapingdog_completed}/{len(image_files)} images, "
                                f"{scrapingdog_links_count} new ScrapingDog links, {len(all_links)} total unique"
                            )

                        # Stop early once the target is met; cancel whatever
                        # has not started yet (running futures keep going).
                        if len(all_links) >= target_links:
                            logger.info(
                                f"🎯 Target reached! {len(all_links)} >= {target_links} links"
                            )
                            for remaining_future in future_to_image:
                                if not remaining_future.done():
                                    remaining_future.cancel()
                            break

                    except Exception as e:
                        with print_lock:
                            logger.error(
                                f"❌ Failed ScrapingDog for {os.path.basename(image_path)}: {e}"
                            )
                        scrapingdog_completed += 1

    # Prepare final results.
    # NOTE(review): set iteration order is arbitrary, so the truncation to
    # target_links is nondeterministic — confirm downstream consumers do not
    # rely on link ordering.
    total_unique_links = len(all_links)
    all_links = list(all_links)[:target_links]
    results = {
        "all_links": all_links,
        "total_unique_links": total_unique_links,
        "target_achieved": total_unique_links >= target_links,
        "summary": {
            "images_processed": len(image_files),
            "vision_links": vision_links_count,
            "scrapingdog_links": scrapingdog_links_count,
            "total_unique_links": total_unique_links,
            "target_links": target_links,
        },
    }

    # Ensure the output directory exists, then persist the results.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    out_path = Path(output_dir) / filename
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    logger.info(
        f"✅ Saved results to {out_path}\n"
        f"📊 Summary: {vision_links_count} Vision + {scrapingdog_links_count} ScrapingDog = {total_unique_links} total unique links"
    )
|
src/prompt/search/index_search.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from threading import Lock
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger("uvicorn.error")
|
| 14 |
+
|
| 15 |
+
# Thread-safe lock for logging
|
| 16 |
+
print_lock = Lock()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def search_index(model, rgb_image, device, index, top_k=20):
    """
    Search FAISS index for similar and dissimilar coordinates using image embeddings.

    The image is embedded three ways — the raw vision projection plus the
    two extra projection heads — and the three L2-normalised vectors are
    concatenated into a single query. The negated query is searched as well
    to retrieve the most *dissimilar* database entries.

    Args:
        model: Vision model used for embedding generation. Must expose
            ``vision_processor``, ``vision_model``, ``vision_projection``,
            ``vision_projection_else_1`` and ``vision_projection_else_2``.
        rgb_image: PIL RGB Image.
        device: Device to run the model on (e.g., "cuda" or "cpu").
        index: FAISS index for searching. NOTE(review): its dimension must
            equal the concatenated embedding width — confirm against the
            index build script.
        top_k (int): Number of top results to return.

    Returns:
        tuple: (D, I, D_reverse, I_reverse) - distances and indices for positive and negative embeddings.
    """
    # logger.info("Searching FAISS index...")
    # Preprocess to pixel values; the reshape assumes 224x224 inputs from
    # the processor — TODO confirm for the configured checkpoint.
    image = model.vision_processor(images=rgb_image, return_tensors="pt")[
        "pixel_values"
    ].reshape(-1, 224, 224)
    image = image.unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        # [1] presumably selects the pooled output of the vision backbone —
        # verify against the model class.
        vision_output = model.vision_model(image)[1]
        image_embeds = model.vision_projection(vision_output)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # Auxiliary head 1: image-to-text alignment space.
        image_text_embeds = model.vision_projection_else_1(
            model.vision_projection(vision_output)
        )
        image_text_embeds = image_text_embeds / image_text_embeds.norm(
            p=2, dim=-1, keepdim=True
        )

        # Auxiliary head 2: image-to-location alignment space.
        image_location_embeds = model.vision_projection_else_2(
            model.vision_projection(vision_output)
        )
        image_location_embeds = image_location_embeds / image_location_embeds.norm(
            p=2, dim=-1, keepdim=True
        )

    # Concatenate the three normalised embeddings into the query vector and
    # move it to CPU float32, the layout FAISS expects.
    positive_image_embeds = torch.cat(
        [image_embeds, image_text_embeds, image_location_embeds], dim=1
    )
    positive_image_embeds = (
        positive_image_embeds.cpu().detach().numpy().astype(np.float32)
    )

    # Negating the query retrieves maximally dissimilar entries (for an
    # inner-product index; presumably that is the index metric — confirm).
    negative_image_embeds = positive_image_embeds * (-1.0)

    # Search FAISS index
    D, I = index.search(positive_image_embeds, top_k)
    D_reverse, I_reverse = index.search(negative_image_embeds, top_k)
    return D, I, D_reverse, I_reverse
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_gps_coordinates(I, I_reverse, database_csv_path):
    """
    Look up GPS coordinates for FAISS result indices in the database CSV.

    The CSV is streamed in chunks so the whole database never has to fit in
    memory. Unlike the previous implementation, the returned lists preserve
    FAISS ranking order even when the matching rows are spread across
    multiple chunks (rows were previously appended in chunk-encounter
    order, scrambling the ranking).

    Args:
        I: FAISS indices for positive embeddings (shape [1, top_k]) or None.
        I_reverse: FAISS indices for negative embeddings or None.
        database_csv_path (str): Path to a CSV with "LAT" and "LON"
            columns whose row position corresponds to the FAISS index.

    Returns:
        tuple: (candidates_gps, reverse_gps) - lists of (lat, lon) tuples.
    """
    if I is None or I_reverse is None:
        return [], []

    candidate_indices = I[0]
    reverse_indices = I_reverse[0]

    # Single streaming pass: collect coordinates for every wanted row index.
    wanted = set(candidate_indices) | set(reverse_indices)
    coords = {}

    try:
        for chunk in pd.read_csv(
            database_csv_path, chunksize=10000, usecols=["LAT", "LON"]
        ):
            # read_csv preserves the global RangeIndex across chunks, so row
            # positions match FAISS indices directly.
            for idx in wanted:
                if idx not in coords and idx in chunk.index:
                    coords[idx] = (
                        float(chunk.loc[idx, "LAT"]),
                        float(chunk.loc[idx, "LON"]),
                    )
    except Exception as e:
        logger.error(f"⚠️ Error loading GPS coordinates from database: {e}")

    # Rebuild the lists in FAISS ranking order, silently skipping indices
    # that were not found (out of range, or a read error mid-stream).
    candidates_gps = [coords[i] for i in candidate_indices if i in coords]
    reverse_gps = [coords[i] for i in reverse_indices if i in coords]
    return candidates_gps, reverse_gps
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def save_results_to_json(candidates_gps: list, reverse_gps: list, output_path: str):
    """Write candidate and reverse GPS lists to ``output_path`` as JSON.

    Args:
        candidates_gps: (lat, lon) pairs for the most similar matches.
        reverse_gps: (lat, lon) pairs for the most dissimilar matches.
        output_path: Destination path of the JSON file.
    """
    payload = {
        "candidates_gps": candidates_gps,
        "reverse_gps": reverse_gps,
    }
    with open(output_path, "w") as fh:
        fh.write(json.dumps(payload, indent=4))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def process_single_image(image_path, model, device, index, database_csv_path, top_k=20):
    """
    Run the full index-search pipeline for one image.

    Opens the image, embeds and searches the FAISS index, then resolves the
    resulting indices to GPS coordinates from the database CSV.

    Args:
        image_path: Path to the image file
        model: Vision model used for embedding generation
        device: Device to run the model on
        index: FAISS index for searching
        database_csv_path: Path to GPS coordinates database CSV
        top_k: Number of top results to return

    Returns:
        tuple: (candidates_gps, reverse_gps) for this image; both empty on
        any failure (the error is logged, not raised).
    """
    try:
        pil_image = Image.open(image_path).convert("RGB")
        _, pos_idx, _, neg_idx = search_index(
            model, pil_image, device, index, top_k
        )
        return get_gps_coordinates(pos_idx, neg_idx, database_csv_path)
    except Exception as e:
        with print_lock:
            logger.error(f"❌ Error processing {os.path.basename(image_path)}: {e}")
        return [], []
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def search_index_directory(
    model,
    device,
    index,
    image_dir,
    database_csv_path,
    top_k=20,
    max_elements=20,
    max_workers=4,
):
    """
    Perform FAISS index search for all images in a directory in parallel and gradually build a prioritized set of candidates.

    Candidates are merged round-robin by rank: every image's rank-0 hit is
    considered before any image's rank-1 hit, and so on, until both output
    sets hold ``max_elements`` entries.

    Args:
        model: Vision model used for embedding generation.
        device: Device to run the model on (e.g., "cuda" or "cpu").
        index: FAISS index for searching.
        image_dir (str): Path to the directory containing images.
        database_csv_path (str): Path to GPS coordinates database CSV.
        top_k (int): Number of top results to return for each image.
        max_elements (int): Maximum number of elements in the final candidates set.
        max_workers (int): Maximum number of parallel workers.

    Returns:
        tuple: (candidates_gps, reverse_gps) - lists of (lat, lon) tuples.
        NOTE(review): the lists come from sets, so their order is
        nondeterministic — confirm downstream consumers do not rely on it.
    """
    # Collect all image paths in the directory (non-recursive).
    image_paths = [
        Path(image_dir) / img
        for img in os.listdir(image_dir)
        if img.lower().endswith((".jpg", ".jpeg", ".png", ".bmp"))
    ]

    if not image_paths:
        logger.warning("No images found in directory")
        return [], []

    logger.info(
        f"🚀 Processing {len(image_paths)} images with {max_workers} parallel workers..."
    )

    all_candidates_gps = []
    all_reverse_gps = []
    completed_count = 0

    # Process images in parallel; each worker returns the per-image
    # (candidates, reverse) coordinate lists in FAISS rank order.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_path = {
            executor.submit(
                process_single_image,
                image_path,
                model,
                device,
                index,
                database_csv_path,
                top_k,
            ): image_path
            for image_path in image_paths
        }

        # Collect results as they complete
        for future in as_completed(future_to_path):
            image_path = future_to_path[future]
            try:
                candidates_gps, reverse_gps = future.result()
                all_candidates_gps.append(candidates_gps)
                all_reverse_gps.append(reverse_gps)
                completed_count += 1

                with print_lock:
                    logger.info(
                        f"Progress: {completed_count}/{len(image_paths)} images completed"
                    )

            except Exception as e:
                with print_lock:
                    logger.error(
                        f"❌ Failed to process {os.path.basename(image_path)}: {e}"
                    )
                # Add empty results for failed images so the zip below
                # stays aligned.
                all_candidates_gps.append([])
                all_reverse_gps.append([])
                completed_count += 1

    # Merge per-image results rank by rank (rank 0 across all images first),
    # deduplicating via sets and capping each set at max_elements.
    candidates_gps = set()
    reverse_gps = set()

    for priority in range(top_k):
        for image_candidates_gps, image_reverse_gps in zip(
            all_candidates_gps, all_reverse_gps
        ):
            if len(candidates_gps) < max_elements and priority < len(
                image_candidates_gps
            ):
                candidates_gps.add(image_candidates_gps[priority])
            if len(reverse_gps) < max_elements and priority < len(image_reverse_gps):
                reverse_gps.add(image_reverse_gps[priority])

        # Stop as soon as both sets are full.
        if len(candidates_gps) >= max_elements and len(reverse_gps) >= max_elements:
            break

    logger.info(
        f"🎯 Final results: {len(candidates_gps)} candidates, {len(reverse_gps)} reverse GPS coordinates"
    )

    return list(candidates_gps), list(reverse_gps)
|
src/prompt/search/text_search.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import httpx
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("uvicorn.error")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def retry_request(func, max_retries=3, base_delay=2.0):
    """
    Call ``func`` with exponential backoff on transient HTTP failures.

    Timeouts and 5xx server responses are retried up to ``max_retries``
    times with ``base_delay * 2**attempt`` second delays; all other errors
    (including non-5xx HTTP status errors) are logged and re-raised
    immediately.

    Args:
        func: Zero-argument callable performing the request.
        max_retries: Maximum number of attempts.
        base_delay: Base delay for exponential backoff, in seconds.

    Returns:
        Whatever ``func`` returns on the first successful attempt.

    Raises:
        The last exception if all retries fail.
    """
    for attempt_no in range(max_retries):
        on_final_attempt = attempt_no >= max_retries - 1
        backoff = base_delay * (2**attempt_no)
        try:
            return func()
        except (httpx.ReadTimeout, httpx.ConnectTimeout, httpx.TimeoutException) as exc:
            if on_final_attempt:
                logger.error(
                    f"❌ Max retries ({max_retries}) exceeded for timeout error."
                )
                raise exc
            logger.warning(
                f"⚠️ Timeout error (attempt {attempt_no + 1}/{max_retries}). Retrying in {backoff}s..."
            )
            time.sleep(backoff)
        except httpx.HTTPStatusError as exc:
            status = exc.response.status_code
            if status in (500, 502, 503, 504) and not on_final_attempt:
                logger.warning(
                    f"⚠️ Server error {status} (attempt {attempt_no + 1}/{max_retries}). Retrying in {backoff}s..."
                )
                time.sleep(backoff)
                continue
            logger.error(f"❌ HTTP error {status}: {exc}")
            raise exc
        except Exception as exc:
            logger.error(f"❌ Unexpected error: {exc}")
            raise exc

    # Unreachable: every loop iteration returns or raises.
    raise RuntimeError("Retry logic failed")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def extension_from_content_type(content_type: str) -> str:
    """Map an HTTP Content-Type header value to an image file extension.

    Parameters such as ``; charset=...`` are stripped and the media type is
    lowercased before lookup.

    Raises:
        ValueError: If the media type is not one of the supported image types.
    """
    allowed_types = {
        "image/png": "png",
        "image/jpeg": "jpg",
        "image/jpg": "jpg",
        "image/webp": "webp",
        "image/heic": "heic",
        "image/heif": "heif",
    }

    # Normalize: drop parameters (charset etc.), trim, lowercase.
    normalized = content_type.split(";")[0].strip().lower()

    ext = allowed_types.get(normalized)
    if ext is None:
        raise ValueError(
            f"Content type '{normalized}' is not supported. Allowed types: {list(allowed_types.keys())}"
        )
    return ext
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def text_search_image(
    query: str,
    num_images: int = 5,
    api_key: str | None = None,
    cx: str | None = None,
    output_dir: str = "g3/data/prompt_data/images",
    start_index: int = 0,
) -> list[str]:
    """Download up to ``num_images`` images for ``query`` via Google Custom Search.

    Pages through the Custom Search image API (10 results per request),
    downloading each hit whose Content-Type is a supported image format.
    Files are saved as ``image_{idx:03d}.{ext}`` starting at ``start_index``.

    Args:
        query: Search query string.
        num_images: Maximum number of images to download.
        api_key: Google Cloud API key (required).
        cx: Custom Search Engine ID (required).
        output_dir: Directory the image files are written to.
        start_index: First numeric suffix used for saved filenames.

    Returns:
        Paths of the files actually downloaded (may be fewer than
        ``num_images`` if results run out or downloads fail).

    Raises:
        ValueError: If ``api_key`` or ``cx`` is missing.
    """
    if not api_key or not cx:
        raise ValueError("GOOGLE_CLOUD_API_KEY or GOOGLE_CSE_CX not set.")

    os.makedirs(output_dir, exist_ok=True)
    downloaded_files: list[str] = []
    # API pagination cursor (1-based); independent of start_index, which
    # only affects output filenames.
    start: int = 1

    idx = start_index
    while len(downloaded_files) < num_images:
        params = {
            "q": query,
            "searchType": "image",
            "cx": cx,
            "key": api_key,
            "num": min(10, num_images - len(downloaded_files)),
            "start": start,
        }

        # Use retry logic for the API request
        try:
            response = retry_request(
                lambda: httpx.get(
                    "https://customsearch.googleapis.com/customsearch/v1",
                    params=params,
                    timeout=30.0,  # Increased timeout
                )
            )
            response.raise_for_status()
        except Exception as e:
            logger.error(f"❌ Failed to search for images after retries: {e}")
            break

        results = response.json().get("items", [])

        if not results:
            logger.info("No more results from API")
            break

        for item in results:
            img_url: str | None = item.get("link")
            if not img_url:
                continue
            try:
                # Use retry logic for image download. The default argument
                # (url=img_url) binds the current URL eagerly, avoiding the
                # late-binding-closure pitfall inside this loop.
                r = retry_request(lambda url=img_url: httpx.get(url, timeout=15.0))
                r.raise_for_status()
                content_type = r.headers.get("Content-Type", "")

                # Check if content type is supported before processing
                try:
                    ext = extension_from_content_type(content_type)
                except ValueError as e:
                    logger.info(f"Skipping {img_url}: {e}")
                    continue

                filename = os.path.join(output_dir, f"image_{idx:03d}.{ext}")
                with open(filename, "wb") as f:
                    f.write(r.content)
                downloaded_files.append(filename)
                idx += 1
                if len(downloaded_files) >= num_images:
                    break
            except httpx.HTTPError as e:
                logger.error(f"HTTP error downloading {img_url}: {e}")
            except Exception as e:
                logger.error(f"Failed to download {img_url}: {e}")

        # Advance to the next results page.
        start += 10

    return downloaded_files
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _write_search_results(query: str, links: list, output_dir: str, filename: str) -> str:
    """Serialise a query's link list to ``output_dir/filename`` and return the path.

    Ensures ``output_dir`` exists before writing.
    """
    os.makedirs(output_dir, exist_ok=True)
    search_results = {"query": query, "links": links}
    output_path = os.path.join(output_dir, filename)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(search_results, f, indent=2, ensure_ascii=False)
    logger.info(f"✅ Saved {len(links)} search results to: {output_path}")
    return output_path


def text_search_link(
    query: str,
    output_dir: str = "g3/data/prompt_data",
    filename: str = "text_search.json",
    num_results: int = 10,
    api_key: Optional[str] = None,
    cx: Optional[str] = None,
) -> str:
    """
    Search for web links using Google Custom Search API and save results to JSON file.

    Args:
        query (str): Search query string
        output_dir (str): Directory to save the results file
        filename (str): Name of the JSON file to save results
        num_results (int): Number of search results to retrieve (max 100)
        api_key (Optional[str]): Google API key, defaults to environment variable
        cx (Optional[str]): Custom Search Engine ID, defaults to environment variable

    Returns:
        str: Path to the saved JSON file

    Raises:
        ValueError: If API key or CX not provided
        httpx.HTTPError: If API request fails
    """
    if not api_key:
        api_key = os.getenv("GOOGLE_CLOUD_API_KEY")
    if not cx:
        cx = os.getenv("GOOGLE_CSE_CX")

    if not api_key or not cx:
        raise ValueError("GOOGLE_CLOUD_API_KEY or GOOGLE_CSE_CX not set.")

    links = []
    start = 1
    if not query:
        # Empty query: persist an empty result set immediately.
        # BUG FIX: this path previously wrote the file without ensuring
        # output_dir existed, raising FileNotFoundError on a fresh run.
        return _write_search_results(query, links, output_dir, filename)

    # Google Custom Search API allows max 10 results per request, so page
    # through until we have enough links or results run out.
    while len(links) < num_results:
        remaining = num_results - len(links)
        current_num = min(10, remaining)

        params = {
            "q": query,
            "cx": cx,
            "key": api_key,
            "num": current_num,
            "start": start,
        }

        try:
            response = retry_request(
                lambda: httpx.get(
                    "https://customsearch.googleapis.com/customsearch/v1",
                    params=params,
                    timeout=30.0,
                )
            )
            response.raise_for_status()
            data = response.json()

            items = data.get("items", [])
            if not items:
                logger.info(
                    f"No more results available. Retrieved {len(links)} results."
                )
                break

            links.extend([item.get("link", "") for item in items if "link" in item])

            if len(links) >= num_results:
                break

        except httpx.HTTPError as e:
            logger.error(f"HTTP error during search: {e}")
            break
        except Exception as e:
            logger.error(f"Error during search: {e}")
            break

        start += 10

    # Ensure we only take the first num_results links
    links = links[:num_results]

    return _write_search_results(query, links, output_dir, filename)
|
src/prompt/template.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompt templates for the geo-localization pipeline.

All templates are ``str.format`` templates: single-brace fields such as
``{prompt_data}`` are substituted by the caller, while doubled braces
(``{{`` / ``}}``) render as literal braces in the embedded JSON examples.
"""

# Initial geo-localization prompt: asks the model to pinpoint the image
# location (ideally a specific building/landmark) and to answer with a
# strict-JSON object containing coordinates plus per-clue evidence.
# Placeholder: {prompt_data}.
DIVERSIFICATION_PROMPT = """
You are an expert in geo-localization. Analyze the image and determine the most precise possible location—ideally identifying the exact building, landmark, or facility, not just the city.
Examine all provided content links in detail, using both textual and visual clues to support your conclusion.
Use only the provided links for evidence. Any additional links must directly support specific visual observations (e.g., satellite imagery or publicly available street-level photos of the same location).
Return your final answer as geographic coordinates.

{prompt_data}

Respond with **only** the following JSON structure (no extra text, markdown, or comments):

{{
"latitude": float,
"longitude": float,
"location": string,
"evidence": [
{{
"analysis": string,
"references": [string, …]
}}
]
}}

**Guidelines:**
- One entry per clue (visual and textual).
- Each object in the "evidence" list should explain a single textual or visual clue and be as many as possible. All image in the prompt follow the format: "image_{{idx:03d}}.jpg", starting from image_000.jpg.
- In the "references" list, each element must be a URL or an image file name (e.g., "image_000.jpg"). They are marked with indices like [1], [2], etc in order of appearance in "references" list. "Analysis" must use these indices to cite the corresponding references.
- The "analysis" field must describe the clue and cite reference in its corresponding "references" using bracketed indices like [1], [2], etc. The corresponding URLs or images for those references must be included in the "references" list for that object.
+ For contextual evidence, must cite textual/news URLs.
+ For visual clues, cite `image_{{idx:03d}}.jpg` in `references` and any satellite/map URLs as needed.
- MUST use given links to support the analysis.
- If you can’t identify a specific building, give the city‑center coordinates.
"""

# Geocoding prompt for a free-text location name; expects a single JSON
# object with coordinates, an analysis string, and cited references.
# Placeholder: {location}.
LOCATION_PROMPT = """
Location: {location}

Your task is to determine the geographic coordinates (latitude and longitude) of the specified location by following these steps:

1. Attempt to find the exact GPS coordinates using reliable online sources such as maps or satellite imagery.

2. If the exact location is not available, find the coordinates of a nearby or adjacent place (e.g., a recognizable landmark, building, road, or intersection).

3. If no specific nearby location can be found, use the coordinates of the broader area (e.g., the center of Khan Younis or Gaza).

4. In the "references" list, each element must be a URL or an image file name (e.g., "image_000.jpg"). They are marked with indices like [1], [2], etc in order of appearance in "references" list. "Analysis" must use these indices to cite the corresponding references.

Return your answer in the following JSON format:

{{
"latitude": float,
"longitude": float,
"analysis": "Describe how the coordinates were identified or approximated, including any visual or textual clues used.",
"references": ["URL1", "URL2", ...]
}}

- The "analysis" must clearly explain the reasoning behind the chosen coordinates.
- The "references" list must include all URLs cited in the analysis.
- Do not include any text outside of the JSON structure.
"""

# Verification pass over an earlier prediction: re-checks every piece of
# evidence against satellite/landmark imagery and the provided URLs, and
# prunes anything that cannot be confirmed.
# Placeholders: {prompt_data}, {prediction}, {satellite_image_id}.
VERIFICATION_PROMPT = """
You are an expert in multimedia verification. Analyze the provided content and decide if it’s authentic or fabricated. Support your conclusion with detailed, verifiable evidence.

{prompt_data}

Prediction to verify:
{prediction}

Guidelines:
1. Output only a JSON object with these fields:
{{
"latitude": float,
"longitude": float,
"location": string,
"evidence": [
{{
"analysis": string,
"references": [string, …]
}}
]
}}

2. Images are named “image_{{idx:03d}}.jpg”:
- Images up to “image_{satellite_image_id}.jpg” were used to generate the prediction.
- “image_{satellite_image_id}.jpg” is the satellite reference.
- Images after that show the claimed location’s landmarks—use them only to confirm buildings or landmarks.

3. In the "references" field of response, each element must be a URL or an image file name (e.g., "image_000.jpg"). They are marked with indices like [1], [2], etc in order of appearance in "references" list. "Analysis" must use these indices to cite the corresponding references.

4. There must be both visual and contextual evidences. For each evidence entry:
a. **Visual evidence**: cross‑check the original images against the satellite view.
- When citing original images (those before `image_{satellite_image_id}.jpg`), **do not** list them alone: each must be accompanied by at least one supporting satellite image, street‑view photo, or map URL in the same reference list.
- If confirmed, **rewrite and enrich** your analysis with additional visual details (textures, angles, shadows) and cite any new image or map references.
- If it can’t be verified, **remove** that entry entirely.

b. **Contextual evidence**: verify against the provided URLs.
- If confirmed, **rewrite and expand** your analysis with deeper context (dates, sources, related events) and cite any new supporting links.
- If it can’t be verified, **remove** that entry.

c. Analyze but **do not** need cite transcript and metadata.

5. All evidence must directly support the predicted latitude/longitude. Do not include analysis or references unrelated to verifying that specific location.

6. Do **not** include any metadata (EXIF, timestamps, filenames) as evidence.

Return only the JSON—no extra text, markdown, or comments.
"""
|
src/setup.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import shutil
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
|
| 6 |
+
base_path = Path(__file__).parent
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def setup(
|
| 10 |
+
local_path: Path,
|
| 11 |
+
repo_id: str,
|
| 12 |
+
filename: str,
|
| 13 |
+
subfolder: str | None = None,
|
| 14 |
+
repo_type: str | None = None,
|
| 15 |
+
) -> None:
|
| 16 |
+
if not local_path.exists():
|
| 17 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 18 |
+
|
| 19 |
+
cached_path = hf_hub_download(
|
| 20 |
+
repo_id=repo_id,
|
| 21 |
+
subfolder=subfolder,
|
| 22 |
+
filename=filename,
|
| 23 |
+
repo_type=repo_type,
|
| 24 |
+
)
|
| 25 |
+
shutil.copy(cached_path, local_path)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
    # Target locations for the model weights, FAISS index, and MP16 metadata,
    # resolved relative to this file so the script works from any CWD.
    checkpoint_path = (
        base_path / "data/checkpoints/mercator_finetune_weight.pth"
    ).resolve()
    index_path = (base_path / "data/index/G3.index").resolve()
    database_path = (base_path / "data/dataset/mp16/MP16_Pro_filtered.csv").resolve()

    # All three artifacts live in the same Hugging Face repository.
    repo_id = "tduongvn/Checkpoints-ACMMM25"

    # Download anything that is missing (see setup(); existing files are kept).
    setup(checkpoint_path, repo_id, "mercator_finetune_weight.pth")
    setup(index_path, repo_id, "G3.index", "index")
    setup(database_path, repo_id, "MP16_Pro_filtered.csv", "data/mp16")
|
src/utils.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import base64
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import requests
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
# Module-level logger; "uvicorn.error" routes messages through the ASGI
# server's error log handler.
logger = logging.getLogger("uvicorn.error")

T = TypeVar("T")

# Nominatim (OpenStreetMap) geocoding endpoint and the User-Agent string
# its usage policy requires clients to send.
NOMINATIM_URL = "https://nominatim.openstreetmap.org/search"
DEFAULT_USER_AGENT = "keyframe_extraction_app"


def get_gps_from_location(
    location: str,
    language: str = "en",
    timeout: int = 10,
    user_agent: str = DEFAULT_USER_AGENT,
) -> Tuple[Optional[float], Optional[float]]:
    """Geocode a free-form location string via Nominatim (OpenStreetMap).

    Args:
        location (str): Location string (e.g., city, address)
        language (str): Language for results (default: 'en')
        timeout (int): Request timeout in seconds (default: 10)
        user_agent (str): User-Agent header (required by Nominatim)

    Returns:
        Tuple[Optional[float], Optional[float]]: (latitude, longitude), or (None, None) on failure
    """
    # Reject non-string or blank input up front; never hit the network.
    if not (isinstance(location, str) and location.strip()):
        logger.warning("Invalid or empty location string provided.")
        return (None, None)

    query_params = {
        "q": location.strip(),
        "format": "json",
        "addressdetails": 1,
        "accept-language": language,
        "limit": 1,
    }
    request_headers = {"User-Agent": user_agent}

    try:
        response = requests.get(
            NOMINATIM_URL, params=query_params, headers=request_headers, timeout=timeout
        )
        response.raise_for_status()
        matches = response.json()

        if matches:
            best = matches[0]
            return (float(best["lat"]), float(best["lon"]))

        logger.info(f"No results found for location: '{location}'")
    except requests.RequestException as req_err:
        logger.error(f"Request error while geocoding '{location}': {req_err}")
    except (ValueError, KeyError, TypeError) as parse_err:
        logger.error(
            f"Failed to parse geocoding response for '{location}': {parse_err}"
        )

    # All failure paths funnel here.
    return (None, None)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def calculate_similarity_scores(
    model: nn.Module,
    device: torch.device,
    predicted_coords: List[Tuple[float, float]],
    image_dir: Union[str, Path] = "images",
) -> np.ndarray:
    """
    Calculate similarity scores between images and predicted coordinates.

    Embeds every ``image_*.*`` file in ``image_dir`` and every (lat, lon)
    pair with the model's vision and location encoders, takes the cosine
    similarity (dot product of L2-normalized embeddings) per image, and
    averages the scores across images.

    Args:
        model: G3-style model exposing ``vision_processor``, ``vision_model``,
            ``vision_projection``/``vision_projection_else_2``,
            ``location_encoder`` and ``location_projection_else``
            # NOTE(review): attribute contract inferred from usage here —
            # confirm against the model definition.
        device: Device on which inference is run.
        predicted_coords: List of (lat, lon) tuples to score.
        image_dir: Directory containing the candidate images
            (files matching ``image_*.*``).

    Returns:
        np.ndarray: Average similarity score per coordinate, shape (len(predicted_coords),).

    Raises:
        ValueError: If ``image_dir`` does not exist.
    """
    all_similarities = []
    image_dir = Path(image_dir)

    if not image_dir.exists():
        raise ValueError(f"Image directory does not exist: {image_dir}")

    for image_file in image_dir.glob("image_*.*"):
        # Load image as PIL Image first
        pil_image = Image.open(image_file).convert("RGB")

        # Process the PIL image
        # assumes the processor emits 224x224 pixel values — TODO confirm
        image = model.vision_processor(images=pil_image, return_tensors="pt")[
            "pixel_values"
        ].reshape(-1, 224, 224)
        image = image.unsqueeze(0).to(device)

        with torch.no_grad():
            # Index [1] of the vision model output is used as the pooled
            # representation (presumably pooler_output — verify).
            vision_output = model.vision_model(image)[1]

            image_embeds = model.vision_projection_else_2(
                model.vision_projection(vision_output)
            )
            # L2-normalize so the dot product below is cosine similarity.
            image_embeds = image_embeds / image_embeds.norm(
                p=2, dim=-1, keepdim=True
            )  # b, 768

            # Process coordinates
            gps_batch = torch.tensor(predicted_coords, dtype=torch.float32).to(device)
            gps_input = gps_batch.clone().detach().unsqueeze(0)  # Add batch dimension
            b, c, _ = gps_input.shape
            # Flatten (batch, candidates, 2) -> (batch*candidates, 2) for the encoder.
            gps_input = gps_input.reshape(b * c, 2)
            location_embeds = model.location_encoder(gps_input)
            location_embeds = model.location_projection_else(
                location_embeds.reshape(b * c, -1)
            )
            location_embeds = location_embeds / location_embeds.norm(
                p=2, dim=-1, keepdim=True
            )
            location_embeds = location_embeds.reshape(b, c, -1)  # b, c, 768

            # Cosine similarity of this image against every candidate coordinate.
            similarity = torch.matmul(
                image_embeds.unsqueeze(1), location_embeds.permute(0, 2, 1)
            )  # b, 1, c
            similarity = similarity.squeeze(1).cpu().detach().numpy()
            all_similarities.append(similarity[0])  # Remove batch dimension

    # Calculate average similarity across all images
    avg_similarities = np.mean(all_similarities, axis=0)
    return avg_similarities
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def is_retryable_error(error: Exception) -> bool:
    """Report whether *error* looks transient enough to retry.

    The decision is heuristic: the stringified error is scanned for
    substrings that usually indicate a transient server or network
    failure, and the exception's type name is compared against a set
    of known retryable exception classes.

    Args:
        error (Exception): The exception to evaluate.

    Returns:
        bool: True if the error is considered retryable.
    """
    message = str(error).lower()

    # Substrings that usually signal transient trouble (HTTP 5xx/429,
    # dropped connections, TLS hiccups, rate limiting, ...).
    transient_markers = (
        "503",
        "500",
        "502",
        "504",
        "overloaded",
        "unavailable",
        "internal",
        "disconnected",
        "connection",
        "timeout",
        "remoteprotocolerror",
        "remote protocol error",
        "network",
        "socket",
        "ssl",
        "tls",
        "rate limit",
        "too many requests",
        "429",
        "service unavailable",
        "temporarily unavailable",
    )
    if any(marker in message for marker in transient_markers):
        return True

    # Exception class names (lower-cased) treated as retryable regardless
    # of the message contents.
    transient_type_names = {
        "connectionerror",
        "timeout",
        "httperror",
        "remoteclosederror",
        "remoteprotocolerror",
        "sslerror",
        "tlserror",
        "valueerror",
    }
    return type(error).__name__.lower() in transient_type_names
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
async def handle_async_api_call_with_retry(
    api_call_func: Callable[[], Any],
    max_retries: int = 10,
    base_delay: float = 2.0,
    fallback_result: Optional[T] = None,
    error_context: str = "API call",
) -> T:
    """
    Executes an asynchronous API call with retry logic and exponential backoff.

    Args:
        api_call_func (Callable): An async function that returns any type (T).
        max_retries (int): Maximum retry attempts.
        base_delay (float): Initial delay for backoff (doubles each retry).
        fallback_result (Optional[T]): Optional result to return on failure.
        error_context (str): Contextual info for logging.

    Returns:
        T: Result from the API call or fallback.

    Raises:
        RuntimeError: If every attempt fails and no fallback is provided;
            chained (``__cause__``) to the last underlying exception.
    """
    # Keep the most recent failure so the final RuntimeError can be chained
    # to it instead of silently discarding the root cause.
    last_error: Optional[Exception] = None

    for attempt in range(1, max_retries + 1):
        try:
            return await api_call_func()

        except Exception as error:
            last_error = error
            is_last_attempt = attempt == max_retries
            retryable = is_retryable_error(error)

            logger.warning(
                f"{error_context} failed (attempt {attempt}/{max_retries}): {error}"
            )

            if retryable and not is_last_attempt:
                # Exponential backoff: base_delay, 2*base_delay, 4*base_delay, ...
                delay = base_delay * (2 ** (attempt - 1))
                logger.info(f"Retrying in {delay:.1f}s...")
                await asyncio.sleep(delay)
                continue

            if not retryable:
                logger.error(f"Non-retryable error encountered: {error}")
            elif is_last_attempt:
                logger.error(f"Max retries reached for {error_context}. Giving up.")

            break

    # NOTE: a fallback of None is indistinguishable from "no fallback given";
    # callers that need None as a valid fallback cannot use this helper.
    if fallback_result is not None:
        logger.warning(f"Returning fallback result for {error_context}")
        return fallback_result

    logger.error(f"No fallback result provided for {error_context}.")
    # Chain the underlying failure so callers see the real cause in tracebacks.
    raise RuntimeError(f"{error_context} failed with no result.") from last_error
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def extract_and_parse_json(raw_text: str) -> Dict[str, Any]:
    """
    Extract and parse the first JSON object found in raw_text.

    Takes the span from the first "{" to the last "}" and attempts to
    parse it. Only returns a dict; falls back to {} on failure or if the
    parsed value isn't a dict.

    Args:
        raw_text (str): Raw text (e.g., from an LLM response)

    Returns:
        Dict[str, Any]: Parsed JSON dict, or {} if none valid is found.
    """
    start = raw_text.find("{")
    end = raw_text.rfind("}")

    if start == -1 or end == -1 or end <= start:
        # Bug fix: extra positional args to logger.error are %-format args,
        # so the message needs a %s placeholder (the old comma-separated
        # call produced an internal logging formatting error).
        logger.error("⚠️ No JSON object found. Snippet: %s", raw_text[:200])
        return {}

    snippet = raw_text[start : end + 1]

    try:
        parsed = json.loads(snippet)
        if isinstance(parsed, dict):
            return parsed
        logger.error(
            "⚠️ JSON parsed but not a dict—got type: %s", type(parsed).__name__
        )
    except json.JSONDecodeError as e:
        logger.error("⚠️ JSON decoding error: %s", e)

    return {}
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def image_to_base64(image_path: Path) -> str:
    """Return the contents of *image_path* as a base64-encoded ASCII string.

    Logs an error and returns an empty string when the path is not an
    existing regular file.
    """
    if image_path.is_file():
        return base64.b64encode(image_path.read_bytes()).decode("utf-8")
    logger.error(f"No such image: {image_path}")
    return ""
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def load_images_as_base64(img_dir: Optional[Path] = None) -> Optional[list[str]]:
    """Load every supported image in a directory as a base64-encoded string.

    Generalized: the directory is now a parameter; the default is the
    previously hard-coded ``data/prompt_data/images`` next to this file,
    so existing zero-argument callers behave identically.

    Args:
        img_dir: Directory to scan, or None for the package default.

    Returns:
        List of base64 strings (one per .png/.jpg/.jpeg/.gif file), or
        None if the directory is missing, empty, or has no matching files.
    """
    if img_dir is None:
        img_dir = Path(__file__).parent / "data" / "prompt_data" / "images"

    # Missing or completely empty directory -> nothing to load.
    if not img_dir.exists() or not any(img_dir.iterdir()):
        return None

    base64_images: list[str] = []
    for file in img_dir.iterdir():
        if file.is_file() and file.suffix.lower() in [".png", ".jpg", ".jpeg", ".gif"]:
            with open(file, "rb") as f:
                encoded = base64.b64encode(f.read()).decode("utf-8")
            base64_images.append(encoded)
    return base64_images if base64_images else None
|