Commit 11e7313
Deploy backend to Hugging Face

Files changed:
- .env +16 -0
- .gitignore +168 -0
- App/__init__.py +0 -0
- App/app.py +447 -0
- App/models.py +56 -0
- App/scheduler.py +195 -0
- Dockerfile +27 -0
- requirements.txt +19 -0
- requirements_scheduler.txt +13 -0
- setup.py +25 -0
- src/__init__.py +0 -0
- src/evaluation.py +59 -0
- src/preprocessing.py +29 -0
- src/training.py +206 -0
.env
ADDED
@@ -0,0 +1,16 @@

DATABASE_URL=https://mrlyvrpxsumashqzcmhd.supabase.co
SUPABASE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im1ybHl2cnB4c3VtYXNocXpjbWhkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjE1NzQ5ODAsImV4cCI6MjA3NzE1MDk4MH0.HoD5V3nXSGnLFSbjcqveBn7LUZZPS4KUTEuM3eoQ2uQ

JWT_SECRET_KEY=Hello_buddy
SECRET_KEY=hello_buddy
CLIENT=http://localhost:3000
Qdrant_api_key=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9ZKSLbgV0_jyesC_fqhyuMHS2XRacAKa-jYeo01vCng
Qdrant_url=https://db450db1-2425-4e2e-b839-b1c585defee3.europe-west3-0.gcp.cloud.qdrant.io
Qdrant_Collection=FINDR

CLOUDINARY_API_SECRET=oaBBFHmbY7i6GNs8q_auVcwd5OM
CLOUDINARY_API_KEY=456896227428735
CLOUDINARY_CLIENT_NAME=dc728fl24
CLOUDINARY_URL=cloudinary://456896227428735:oaBBFHmbY7i6GNs8q_auVcwd5OM@dc728fl24
scheduler=https://findr-ai-scheduler.onrender.com
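A minimal sketch (not part of the commit) of how this file is consumed: the mixed-case names such as Qdrant_url and Qdrant_api_key are looked up verbatim, and os.getenv is case-sensitive on Linux, so the spelling here must match the lookups in App/app.py exactly.

# Minimal sketch: loading this .env with python-dotenv, as App/app.py does.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
qdrant_url = os.getenv("Qdrant_url")      # case must match the .env key exactly
qdrant_key = os.getenv("Qdrant_api_key")
assert qdrant_url is not None, "Qdrant_url missing from environment"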
.gitignore
ADDED
@@ -0,0 +1,168 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
/get-pip.py
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
Server/.env
Client/.env
.venv
env/
venv/
new_env/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
/datasets
/.dockerignore
/start.sh

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
App/__init__.py
ADDED
File without changes
App/app.py
ADDED
@@ -0,0 +1,447 @@
import os
from flask import Flask, request, jsonify, make_response
from flask_bcrypt import Bcrypt
from functools import wraps
from supabase import create_client
from flask_jwt_extended import JWTManager, create_access_token, unset_jwt_cookies, jwt_required, get_jwt_identity, decode_token
from App.models import User, LostItem, FoundItem, Match
from dotenv import load_dotenv
from flask_cors import CORS
from src.training import encode_img_and_text
from qdrant_client import QdrantClient
from qdrant_client.http import models
import cloudinary
from cloudinary import uploader
import warnings
import base64
from io import BytesIO
import threading
from datetime import timedelta

warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")

load_dotenv()

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")
app.config["SECRET_KEY"] = os.getenv("SECRET_KEY")
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY")
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
app.config["JWT_ACCESS_COOKIE_NAME"] = "Token"
app.config["JWT_COOKIE_SAMESITE"] = "None"
app.config["JWT_COOKIE_SECURE"] = False
app.config["JWT_COOKIE_DOMAIN"] = ".localhost"
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(days=3)

qdrant = QdrantClient(
    url=os.getenv("Qdrant_url"),
    api_key=os.getenv("Qdrant_api_key"),
)

cloudinary.config(
    cloud_name=os.getenv("CLOUDINARY_CLIENT_NAME"),
    api_key=os.getenv("CLOUDINARY_API_KEY"),
    api_secret=os.getenv("CLOUDINARY_API_SECRET"),
)

db = create_client(os.getenv("DATABASE_URL"), os.getenv("SUPABASE_KEY"))
bcrypt = Bcrypt(app)
jwt = JWTManager(app)
CORS(app, supports_credentials=True, origins=[os.getenv("CLIENT")])

print('Qdrant connected')
print("Postgres Connected")


def decode_jwt(fn):
    """Pull the user id out of the JWT stored in the "Token" cookie."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        token = request.cookies.get("Token")
        if not token:
            print('no token found')
            return jsonify({
                "message": "No Token Found"
            }), 401
        id = decode_token(token)["sub"]
        print(id)
        return fn(user_id=id, *args, **kwargs)
    return wrapper


class DotDict(dict):
    """Dict with attribute-style access, for rows returned by Supabase."""
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"No such attribute: {key}")

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def upload_img(img):
    return uploader.upload(img, resource_type="image")["secure_url"]


@app.route('/register', methods=["POST"])
def register():
    try:
        print("received")
        data = request.get_json()
        first_name = data['firstName']
        last_name = data['lastName']
        email = data['email']
        phone = data['phone']
        password = data['password']
        if db.table("users").select("*").eq("email", email).limit(1).execute().data:
            return jsonify({
                "success": False,
                "error": "User Already Exists"
            })
        hashing = bcrypt.generate_password_hash(password).decode("utf-8")
        new_user = db.table("users").insert({"first_name": first_name, "last_name": last_name, "email": email, "phone": phone, "password": hashing}).execute().data[0]

        token = create_access_token(identity=str(new_user.get("id")), expires_delta=timedelta(days=3))
        res = make_response({
            "success": True,
            "message": "User Registered Successfully",
        })
        res.set_cookie("Token", token, httponly=True, secure=True, samesite="None", max_age=259200, domain=".localhost")
        return res, 200
    except Exception as e:
        print(e)
        return jsonify({
            "success": False,
            "message": "Internal Server Error"
        }), 400


@app.route("/login", methods=["POST"])
def login():
    try:
        data = request.get_json()
        print(data)
        email = data['email']
        password = data['password']
        rows = db.table("users").select("*").eq("email", email).limit(1).execute().data
        if not rows or not bcrypt.check_password_hash(rows[0].get("password"), password):
            return jsonify({
                "success": False,
                "message": "No User Found"
            }), 200
        user = rows[0]
        token = create_access_token(identity=str(user.get("id")), expires_delta=timedelta(days=3))
        res = make_response({
            "success": True,
            "message": "User Logged In Successfully",
        })
        res.set_cookie("Token", token, httponly=True, secure=True, samesite="None", max_age=259200, domain=".localhost")
        return res, 200
    except Exception as e:
        print(e)
        return jsonify({
            "success": False,
            "message": "Internal Server Error"
        }), 400


@app.route('/logout', methods=["GET"])
def logout():
    res = make_response({
        "success": True,
        "message": "Logout Successfully"
    })
    unset_jwt_cookies(res)
    return res, 200


@app.route('/get_user', methods=['GET'])
@decode_jwt
def get_user(user_id):
    user = db.table("users").select("*").eq("id", user_id).limit(1).execute().data[0]
    return jsonify({
        "success": True,
        "user": {"first_name": user.get("first_name"),
                 "last_name": user.get("last_name"),
                 "email": user.get("email"),
                 "phone_number": user.get("phone")}
    }), 200


@app.route('/lostItem', methods=['POST'])
@decode_jwt
def lostItem(user_id):
    print('search start')
    try:
        imgs_data = request.files.getlist('item')
        img_urls = [upload_img(img) for img in imgs_data]
        description = request.form.get('description')
        print(imgs_data, description)
        vector = encode_img_and_text(imgs_data, description)
        lastSeenLocation = request.form.get('lastSeenLocation')
        dateTimeLost = request.form.get('dateTimeLost')
        name = request.form.get('name')
        email = request.form.get('email')
        phone = request.form.get('phone')
        reward = request.form.get('reward')
        additionalNotes = request.form.get('additionalNotes')
        item = db.table("lostItem").insert({"user_id": user_id, "name": name, "email": email, "phone": phone, "description": description, "lastSeenLocation": lastSeenLocation, "dateTimeLost": dateTimeLost, "reward": reward, "additionalNotes": additionalNotes, "image_url": img_urls}).execute().data[0]
        print('db save', len(vector))

        collections = qdrant.get_collections().collections
        existing_names = [c.name for c in collections]

        if "lost_items" not in existing_names:
            qdrant.create_collection(
                collection_name="lost_items",
                vectors_config=models.VectorParams(
                    size=512,
                    distance=models.Distance.COSINE
                )
            )

        qdrant.upsert(
            collection_name="lost_items",
            points=[
                models.PointStruct(id=item.get("id"), vector=vector, payload={"description": description, "place_lost": lastSeenLocation, "status": "active"})
            ],
        )
        print('vector save')
        return jsonify({
            "success": True
        }), 200
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            "success": False
        }), 400


@app.route('/foundItem', methods=['POST'])
@decode_jwt
def foundItem(user_id):
    print('search start')
    try:
        imgs_data = request.files.getlist('item')
        img_urls = [upload_img(img) for img in imgs_data]
        description = request.form.get('description')
        print(imgs_data, description)
        vector = encode_img_and_text(imgs_data, description)
        found_near = request.form.get('found_near')
        name = request.form.get('name')
        email = request.form.get('email')
        phone = request.form.get('phone')
        item = db.table("foundItem").insert({"user_id": user_id, "name": name, "email": email, "phone": phone, "description": description, "found_near": found_near, "image_url": img_urls}).execute().data[0]

        print('db save', len(vector))

        collections = qdrant.get_collections().collections
        existing_names = [c.name for c in collections]

        if "found_items" not in existing_names:
            qdrant.create_collection(
                collection_name="found_items",
                vectors_config=models.VectorParams(
                    size=512,
                    distance=models.Distance.COSINE
                )
            )

        qdrant.upsert(
            collection_name="found_items",
            points=[
                models.PointStruct(id=item.get("id"), vector=vector, payload={"description": description, "place_found": found_near, "status": "active"})
            ],
        )
        print('vector save')
        return jsonify({
            "success": True
        }), 200
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            "success": False
        }), 400


@app.route('/allLostItems', methods=['GET'])
@decode_jwt
def allLostItems(user_id):
    rows = db.table("lostItem").select("*").eq("user_id", user_id).execute().data
    items = [DotDict(r) for r in rows]
    output = []
    for item in items:
        output.append({
            "id": item.id,
            "name": item.name,
            "email": item.email,
            "phone": item.phone,
            "description": item.description,
            "lastSeenLocation": item.lastSeenLocation,
            "dateTimeLost": item.dateTimeLost,
            "reward": item.reward,
            "additionalNotes": item.additionalNotes,
            "image_url": item.image_url,
            "status": item.status,
            "created_at": item.created_at
        })
    return jsonify({
        "success": True,
        "lostItems": output
    }), 200


@app.route('/matchLost/<lost_id>', methods=['GET'])
def matchLost(lost_id):
    rows = db.table("lostItem").select("*").eq("id", lost_id).limit(1).execute().data
    found_items = []
    if not rows or not rows[0].get("found_items"):
        return jsonify({
            "success": False,
            "message": "No Lost Item Found"
        }), 400
    items = rows[0]
    for ids in items.get("found_items"):
        matches = db.table("foundItem").select("*").eq("id", int(ids)).limit(1).execute().data
        if matches:
            found_item = matches[0]
            found_items.append({
                "id": found_item.get("id"),
                "name": found_item.get("name"),
                "email": found_item.get("email"),
                "phone": found_item.get("phone"),
                "description": found_item.get("description"),
                "found_near": found_item.get("found_near"),
                "image_url": found_item.get("image_url"),
                "status": found_item.get("status"),
                "created_at": found_item.get("created_at")
            })

    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200


@app.route('/matchFound/<lost_id>', methods=['GET'])
def matchFound(lost_id):
    rows = db.table("foundItem").select("*").eq("id", lost_id).limit(1).execute().data

    found_items = []
    if not rows or not rows[0].get("lost_items"):
        return jsonify({
            "success": False,
            "message": "No Found Item Found"
        }), 400
    items = rows[0]
    for ids in items.get("lost_items"):
        matches = db.table("lostItem").select("*").eq("id", int(ids)).limit(1).execute().data
        if matches:
            found_item = matches[0]
            found_items.append({
                "id": found_item.get("id"),
                "name": found_item.get("name"),
                "email": found_item.get("email"),
                "phone": found_item.get("phone"),
                "description": found_item.get("description"),
                "found_near": found_item.get("lastSeenLocation"),
                "image_url": found_item.get("image_url"),
                "status": found_item.get("status"),
                "created_at": found_item.get("created_at")
            })

    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200


@app.route('/lostMatchDetail/<lost_id>', methods=['GET'])
def lostMatchDetail(lost_id):
    rows = db.table("lostItem").select("*").eq("id", lost_id).limit(1).execute().data
    found_items = []
    if not rows:
        return jsonify({
            "success": False,
            "message": "No Found Item Found"
        }), 400
    found_item = rows[0]

    found_items.append({
        "id": found_item.get("id"),
        "name": found_item.get("name"),
        "email": found_item.get("email"),
        "phone": found_item.get("phone"),
        "description": found_item.get("description"),
        "found_near": found_item.get("lastSeenLocation"),
        "image_url": found_item.get("image_url"),
        "status": found_item.get("status"),
        "date_lost": found_item.get("dateTimeLost"),
        "reward": found_item.get("reward"),
        "additional_notes": found_item.get("additionalNotes"),
        "created_at": found_item.get("created_at")
    })

    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200


@app.route('/allFoundItems', methods=['GET'])
@decode_jwt
def allFoundItems(user_id):
    items = db.table("foundItem").select("*").eq("user_id", user_id).execute().data
    output = []
    for item in items:
        output.append({
            "id": item.get("id"),
            "name": item.get("name"),
            "email": item.get("email"),
            "phone": item.get("phone"),
            "description": item.get("description"),
            "found_near": item.get("found_near"),
            "image_url": item.get("image_url"),
            "status": item.get("status"),
            "created_at": item.get("created_at")
        })
    return jsonify({
        "success": True,
        "foundItems": output
    }), 200


@app.route('/foundMatchDetail/<lost_id>', methods=['GET'])
def foundMatchDetail(lost_id):
    rows = db.table("foundItem").select("*").eq("id", lost_id).limit(1).execute().data
    print(lost_id)
    found_items = []
    if not rows:
        return jsonify({
            "success": False,
            "message": "No Found Item Found"
        }), 400
    found_item = rows[0]

    found_items.append({
        "id": found_item.get("id"),
        "name": found_item.get("name"),
        "email": found_item.get("email"),
        "phone": found_item.get("phone"),
        "description": found_item.get("description"),
        "found_near": found_item.get("found_near"),
        "image_url": found_item.get("image_url"),
        "status": found_item.get("status"),
        "created_at": found_item.get("created_at")
    })

    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200


if __name__ == "__main__":
    app.run(debug=True, port=8000)
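A hedged client-side sketch (not part of the commit) exercising the cookie-based flow above. The base URL is an assumption for local runs; note the server sets the Token cookie with secure=True, so over plain HTTP a standards-compliant cookie jar may refuse to replay it, and you may need to relax that flag locally.

# Hypothetical client sketch using requests; endpoint paths and field names
# come from App/app.py, the base URL and file path are assumptions.
import requests

BASE = "http://localhost:8000"  # assumed local dev address
s = requests.Session()          # keeps the "Token" cookie set by /register

s.post(f"{BASE}/register", json={
    "firstName": "Ada", "lastName": "Lovelace",
    "email": "ada@example.com", "phone": "1234567890",
    "password": "secret",
})

# Report a lost item: multipart form with one or more images under "item".
with open("wallet.jpg", "rb") as f:  # hypothetical image file
    r = s.post(f"{BASE}/lostItem",
               files=[("item", ("wallet.jpg", f, "image/jpeg"))],
               data={"description": "brown leather wallet",
                     "lastSeenLocation": "library",
                     "dateTimeLost": "2025-01-01 10:00",
                     "name": "Ada", "email": "ada@example.com",
                     "phone": "1234567890", "reward": "10",
                     "additionalNotes": ""})
print(r.json())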
App/models.py
ADDED
@@ -0,0 +1,56 @@
from flask_sqlalchemy import SQLAlchemy
import uuid
from sqlalchemy.dialects.postgresql import JSON, ARRAY
from sqlalchemy.ext.mutable import MutableList

db = SQLAlchemy()

class User(db.Model):
    __tablename__ = "users"

    id = db.Column(db.Integer, primary_key=True)
    first_name = db.Column(db.String(50), nullable=False)
    last_name = db.Column(db.String(50), nullable=False)
    email = db.Column(db.String(100), unique=True, nullable=False)
    phone = db.Column(db.String(15), nullable=False)
    password = db.Column(db.String(200), nullable=False)


class LostItem(db.Model):
    __tablename__ = "lostItem"
    id = db.Column(db.Integer, primary_key=True, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False)
    name = db.Column(db.String, nullable=False)
    email = db.Column(db.String, nullable=False)
    phone = db.Column(db.String, nullable=False)
    description = db.Column(db.Text, nullable=False)
    lastSeenLocation = db.Column(db.Text, nullable=False)
    dateTimeLost = db.Column(db.Text, nullable=False)
    reward = db.Column(db.Text)
    additionalNotes = db.Column(db.Text)
    image_url = db.Column(JSON, nullable=False)
    status = db.Column(db.String, nullable=False, default='active')
    found_items = db.Column(MutableList.as_mutable(ARRAY(db.String)), nullable=False)
    created_at = db.Column(db.TIMESTAMP, server_default=db.func.now(), nullable=False)


class FoundItem(db.Model):
    __tablename__ = "foundItem"
    id = db.Column(db.Integer, primary_key=True, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False)
    name = db.Column(db.String, nullable=False)
    email = db.Column(db.String, nullable=False)
    phone = db.Column(db.String, nullable=False)
    description = db.Column(db.Text, nullable=False)
    found_near = db.Column(db.Text, nullable=False)
    image_url = db.Column(JSON, nullable=False)
    status = db.Column(db.String, nullable=False, default='active')
    lost_items = db.Column(MutableList.as_mutable(ARRAY(db.String)), nullable=False)
    created_at = db.Column(db.TIMESTAMP, server_default=db.func.now(), nullable=False)


class Match(db.Model):
    __tablename__ = "matches"
    id = db.Column(db.Integer, primary_key=True, index=True)
    lost_item_id = db.Column(db.Integer, db.ForeignKey("lostItem.id"), nullable=False)
    found_item_id = db.Column(db.Integer, db.ForeignKey("foundItem.id"), nullable=False)
    confidence_score = db.Column(db.Float, nullable=False)
    created_at = db.Column(db.TIMESTAMP, server_default=db.func.now(), nullable=False)
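These models declare the schema with Flask-SQLAlchemy, while the app itself reads and writes through the supabase client. A hedged sketch, assuming you want the same tables in a plain Postgres database (the DSN below is hypothetical):

# Hedged sketch: materialising the models with Flask-SQLAlchemy. Not used
# by the deployed app, which talks to Supabase directly.
from flask import Flask
from App.models import db

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "postgresql://user:pass@localhost/findr"  # hypothetical DSN

db.init_app(app)
with app.app_context():
    db.create_all()  # creates users, lostItem, foundItem, matches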
App/scheduler.py
ADDED
@@ -0,0 +1,195 @@
import os
from flask import Flask, request, jsonify, make_response
from flask_bcrypt import Bcrypt
from functools import wraps
from flask_jwt_extended import JWTManager, create_access_token, unset_jwt_cookies, jwt_required, get_jwt_identity, decode_token
from dotenv import load_dotenv
from flask_cors import CORS
from qdrant_client import QdrantClient
from qdrant_client.http import models
from datetime import datetime, timedelta
from qdrant_client.models import PayloadSchemaType
from supabase import create_client

load_dotenv()

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")

db = create_client(os.getenv("DATABASE_URL"), os.getenv("SUPABASE_KEY"))

qdrant = QdrantClient(
    url=os.getenv("Qdrant_url"),
    api_key=os.getenv("Qdrant_api_key"),
)
qdrant.create_payload_index(
    collection_name="found_items",
    field_name="status",
    field_schema=PayloadSchemaType.KEYWORD
)

DEFAULT_MIN_AGE_MINUTES = 60
DEFAULT_CONFIDENCE = 0.78
DEFAULT_TOP_K = 5

CORS(app, supports_credentials=True, origins=[os.getenv("CLIENT")])

print('Qdrant connected')
print("Postgres Connected")


def iso_now_minus(minutes):
    return (datetime.utcnow() - timedelta(minutes=minutes)).isoformat()


@app.route("/admin/match-active-lost", methods=["POST"])
def match_active_lost():
    body = request.get_json(silent=True) or {}
    min_age_minutes = int(body.get("min_age_minutes", DEFAULT_MIN_AGE_MINUTES))
    confidence_threshold = float(body.get("confidence_threshold", DEFAULT_CONFIDENCE))
    top_k = int(body.get("top_k", DEFAULT_TOP_K))

    cutoff = iso_now_minus(min_age_minutes)

    lost_rows = db.table("lostItem").select("*").eq("status", "active").execute()
    lost_items = [
        {
            "id": r.get("id"),
            "user_id": r.get("user_id"),
            "created_at": r.get("created_at")
        }
        for r in lost_rows.data
    ]
    if not lost_items:
        return jsonify({"message": "No eligible lost items found", "checked": 0}), 200

    created_matches = []
    errors = []

    for lost in lost_items:
        lost_id = int(lost["id"])
        lost_user_id = str(lost["user_id"])
        try:
            point = qdrant.retrieve(collection_name="lost_items", ids=[lost_id], with_vectors=True)
            if not point:
                errors.append({"lost_id": lost_id, "error": "No vector found in Qdrant for this lost item"})
                continue
            lost_vector = point[0].vector
        except Exception as e:
            errors.append({"lost_id": lost_id, "error": f"Qdrant get_point failed: {str(e)}"})
            continue

        qfilter = models.Filter(
            must=[models.FieldCondition(key="status", match=models.MatchValue(value="active"))]
        )

        try:
            results = qdrant.search(
                collection_name="found_items",
                query_vector=lost_vector,
                limit=top_k,
                query_filter=qfilter
            )
        except Exception as e:
            errors.append({"lost_id": lost_id, "error": f"Qdrant search failed: {str(e)}"})
            continue

        for r in results:
            score = float(r.score) if r.score is not None else None
            if score is None:
                continue
            if score < confidence_threshold:
                continue

            found_point_payload = r.payload or {}
            found_supabase_id = found_point_payload.get("id") or r.id

            # NOTE: the found_items payload as upserted by App/app.py carries no
            # user_id, so this self-match guard only takes effect if the payload
            # is extended to include it.
            if lost_user_id == str(found_point_payload.get("user_id")):
                continue

            try:
                existing_match = db.table("matches").select("lost_item_id", "found_item_id").eq("found_item_id", found_supabase_id).eq("lost_item_id", lost_id).limit(1).execute()
                if existing_match.data:
                    continue
            except Exception as e:
                errors.append({"lost_id": lost_id, "error": f"Database query failed: {str(e)}"})
                continue

            try:
                match_record = db.table("matches").insert({
                    "lost_item_id": lost_id,
                    "found_item_id": found_supabase_id,
                    "confidence_score": score
                }).execute()
                created_matches.append({"lost_id": lost_id, "found_id": found_supabase_id, "similarity": score})
            except Exception as e:
                errors.append({"lost_id": lost_id, "found_id": found_supabase_id, "error": f"Supabase insert failed: {str(e)}"})

    matches = db.table("matches").select("id", "lost_item_id", "found_item_id", "confidence_score").execute()

    match_items = [
        {
            "id": r.get("id"),
            "lost_id": int(r.get("lost_item_id")),
            "found_id": int(r.get("found_item_id")),
            "confidence_score": r.get("confidence_score")
        }
        for r in matches.data
    ]

    for match in match_items:
        # Mirror the match onto the lost item's found_items array.
        already_exist_found = db.table("lostItem").select("*").eq("id", match["lost_id"]).contains("found_items", [str(match["found_id"])]).limit(1).execute()
        if already_exist_found.data:
            continue
        res = db.table("lostItem").select("*").eq("id", match["lost_id"]).limit(1).execute()
        if not res.data:
            continue
        lost_item = res.data[0]

        current_found_items = lost_item.get("found_items") or []
        if str(match["found_id"]) not in current_found_items:
            current_found_items.append(str(match["found_id"]))
            db.table("lostItem").update({"found_items": current_found_items}).eq("id", match["lost_id"]).execute()

        # Mirror the match onto the found item's lost_items array.
        already_exist_lost = db.table("foundItem").select("*").eq("id", match["found_id"]).contains("lost_items", [str(match["lost_id"])]).limit(1).execute()
        if already_exist_lost.data:
            continue
        res = db.table("foundItem").select("lost_items").eq("id", match["found_id"]).limit(1).execute()
        if not res.data:
            continue
        found_item = res.data[0]

        current_lost_items = found_item.get("lost_items") or []
        if str(match["lost_id"]) not in current_lost_items:
            current_lost_items.append(str(match["lost_id"]))
            db.table("foundItem").update({"lost_items": current_lost_items}).eq("id", match["found_id"]).execute()

    return jsonify({
        "checked_lost_count": len(lost_rows.data),
        "created_matches_count": len(created_matches),
        "created_matches": created_matches,
        "errors": errors
    }), 200


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
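The scheduler service only exposes POST /admin/match-active-lost; nothing in this commit triggers it on a timer. A hedged sketch of one way to drive it, where the interval is an assumption chosen to mirror DEFAULT_MIN_AGE_MINUTES:

# Hypothetical trigger loop; the URL comes from the .env `scheduler` key,
# the hourly interval is an assumption.
import time
import requests

SCHEDULER_URL = "https://findr-ai-scheduler.onrender.com"

while True:
    r = requests.post(f"{SCHEDULER_URL}/admin/match-active-lost",
                      json={"confidence_threshold": 0.78, "top_k": 5})
    print(r.status_code, r.json().get("created_matches_count"))
    time.sleep(3600)  # hourly, matching DEFAULT_MIN_AGE_MINUTES = 60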
Dockerfile
ADDED
@@ -0,0 +1,27 @@
FROM python:3.10-slim

WORKDIR /Server

# 🔴 REQUIRED system libraries for pyarrow / datasets
RUN apt-get update && apt-get install -y \
    build-essential \
    gcc \
    g++ \
    cmake \
    curl \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .

# 🔴 Upgrade tooling and install deps
RUN python -m pip install --upgrade pip setuptools wheel \
    && python -m pip install -r requirements.txt --no-cache-dir

COPY . .

CMD ["python", "-m", "App.app"]
requirements.txt
ADDED
@@ -0,0 +1,19 @@
torch
torchvision
numpy
pillow
tqdm
transformers
huggingface_hub
flask
flask-cors
flask-sqlalchemy
flask-bcrypt
flask-jwt-extended
psycopg2-binary
python-dotenv
cloudinary
qdrant-client
datasets
open-clip-torch
supabase
requirements_scheduler.txt
ADDED
@@ -0,0 +1,13 @@
flask
flask-cors
flask-sqlalchemy
flask-bcrypt
flask-jwt-extended
psycopg2-binary
python-dotenv
cloudinary
qdrant-client
numpy
pillow
tqdm
supabase
setup.py
ADDED
@@ -0,0 +1,25 @@
from setuptools import setup, find_packages
from typing import List

HYPHEN_E_DOT = '-e .'

def get_requirements(filename: str) -> List[str]:
    requirements = []
    with open(filename, 'r') as f:
        requirements = f.readlines()
        requirements = [req.replace('\n', '') for req in requirements]

        if HYPHEN_E_DOT in requirements:
            requirements.remove(HYPHEN_E_DOT)

    return requirements

setup(
    name="FINDR",
    version="0.0.1",
    packages=find_packages(),
    author="Prashant",
    author_email="prashant.goyal2002@gmail.com",
    description="a simple project to predict student performance",
    install_requires=get_requirements('requirements.txt')
)
src/__init__.py
ADDED
File without changes
src/evaluation.py
ADDED
@@ -0,0 +1,59 @@
import os, json
from training import clip_dataset
import torch
from torch.utils.data import DataLoader
import open_clip
from datasets import load_dataset
from tqdm import tqdm
import numpy as np

import warnings
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")

def collate(batch):
    img, text = zip(*batch)
    return torch.stack(img, 0), torch.stack(text, 0)

@torch.no_grad()
def encode_img(model, processor, tokenizer, split, device):
    ds = clip_dataset(split=split, processor=processor, tokenizer=tokenizer)
    print('dataset Loaded')
    dl = DataLoader(ds, batch_size=4, shuffle=False, num_workers=4, collate_fn=collate)
    all_img, all_text = [], []
    for img, text in tqdm(dl, desc=f"Encode {split}"):
        img = img.to(device)
        text = text.to(device)
        img_f = model.encode_image(img)
        text_f = model.encode_text(text)
        img_f = img_f / img_f.norm(keepdim=True, dim=-1)
        text_f = text_f / text_f.norm(keepdim=True, dim=-1)
        all_img.append(img_f.cpu())
        all_text.append(text_f.cpu())
    return torch.cat(all_img), torch.cat(all_text)


def gold_k(sims, k):
    # Recall@k: row i's gold caption is column i of the similarity matrix.
    ranks = (-sims).argsort(axis=1)
    hits = (ranks[:, :k] == np.arange(sims.shape[0])[:, None]).any(axis=1)
    return hits.mean()


def main(path='./model/clip/best.pt', arch='ViT-B-32', pretrained='openai'):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()
    model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained=pretrained, device=device, quick_gelu=True)
    tokenizer = open_clip.get_tokenizer(arch)
    state = torch.load(path, map_location=device)['model']
    model.load_state_dict(state, strict=False)
    model.eval()
    print('model loaded')
    img_f, text_f = encode_img(model, processor=preprocess, tokenizer=tokenizer, split='test', device=device)
    sim = (img_f @ text_f.T).numpy()
    g1 = gold_k(sim, 1)
    g5 = gold_k(sim, 5)
    g10 = gold_k(sim, 10)
    print(f"Image->Text R@1={g1:.3f} R@5={g5:.3f} R@10={g10:.3f}")

if __name__ == "__main__":
    main()
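A tiny worked example of gold_k, where row i's gold caption sits on the diagonal (column i):

# Recall@k by hand on a 3x3 similarity matrix.
import numpy as np

sims = np.array([
    [0.9, 0.1, 0.3],   # image 0: gold text 0 ranked 1st -> hit at k=1
    [0.2, 0.4, 0.8],   # image 1: gold text 1 ranked 2nd -> hit only at k>=2
    [0.1, 0.6, 0.7],   # image 2: gold text 2 ranked 1st -> hit at k=1
])

ranks = (-sims).argsort(axis=1)
hits_at_1 = (ranks[:, :1] == np.arange(3)[:, None]).any(axis=1)
print(hits_at_1.mean())  # 0.666... -> R@1 = 2/3; R@2 would be 3/3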
src/preprocessing.py
ADDED
@@ -0,0 +1,29 @@
from datasets import load_dataset
import datasets
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
import requests
import os

class Preprocessing():
    def __init__(self):
        pass

    def load_dataset(self, split):
        os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "10500"
        dataset = load_dataset("lmms-lab/COCO-Caption", split=split, cache_dir="D:/Java Projects/Findr/Server/datasets")
        ds = dataset.filter(lambda x: x['image'] is not None and x['question_id'] is not None and len(x['answer']) > 0)
        return ds

    def image_caption_pairs(self, ds):
        import random
        for data in ds:
            img: Image.Image = data['image'].convert('RGB')
            cap = random.choice(data['answer']).strip()
            print(img, cap)
            yield img, cap

if __name__ == "__main__":
    obj = Preprocessing()
    obj.load_dataset('val')
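A hedged usage sketch for the generator above. Note the cache_dir in load_dataset is a machine-specific Windows path; adjust it before running elsewhere.

# Stream (image, caption) pairs from the filtered split.
from src.preprocessing import Preprocessing

prep = Preprocessing()
ds = prep.load_dataset(split="val")
for img, cap in prep.image_caption_pairs(ds):
    print(img.size, cap)
    break  # just peek at the first pair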
src/training.py
ADDED
@@ -0,0 +1,206 @@
import os, random, math
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm
import open_clip
from datasets import load_dataset
from PIL import Image
from src.preprocessing import Preprocessing
from torch.utils.data import DataLoader, Dataset
import warnings
import base64
from io import BytesIO
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.empty_cache()
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', device=device)
SAVE_DIR = 'model/clip/best.pt'
tokenizer = open_clip.get_tokenizer('ViT-B-32')

def seed_everything(seed=42):
    random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

class clip_dataset(torch.utils.data.Dataset):

    def __init__(self, split='val', processor=None, tokenizer=None):
        preprocessor = Preprocessing()
        self.ds = preprocessor.load_dataset(split=split)
        self.tokenizer = tokenizer
        self.processor = processor

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        data = self.ds[index]
        img: Image.Image = data['image'].convert('RGB')
        text = random.choice(data['answer']).strip()
        image = self.processor(img) if self.processor else img
        token_text = self.tokenizer([text])[0]
        return image, token_text

def clip_loss(image_features, text_features, temperature):
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logits_per_image = (image_features @ text_features.t()) * torch.exp(temperature)
    logits_per_text = logits_per_image.t()
    targets = torch.arange(image_features.size(0), device=image_features.device)
    loss_i = nn.CrossEntropyLoss()(logits_per_image, targets)
    loss_t = nn.CrossEntropyLoss()(logits_per_text, targets)
    return (loss_i + loss_t) / 2

def collate(batch):
    imgs, toks = zip(*batch)
    imgs = torch.stack(imgs, 0)
    toks = torch.stack(toks, 0)
    return imgs, toks

def train(arch='ViT-B-32', pretrained='openai', batchSize=2, epochs=5, lr=5e-5, warmup_steps=200, grad_accum=1, output_dir='model/clip'):
    seed_everything(42)
    torch.cuda.empty_cache()
    os.makedirs(output_dir, exist_ok=True)

    tokenizer = open_clip.get_tokenizer(arch)

    train_ds = clip_dataset(split='val', processor=preprocess, tokenizer=tokenizer)
    val_ds = clip_dataset(split='test', processor=preprocess, tokenizer=tokenizer)

    train_dl = DataLoader(train_ds, batch_size=batchSize, shuffle=True, num_workers=4, collate_fn=collate, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=batchSize, shuffle=False, num_workers=4, collate_fn=collate, pin_memory=True)

    total_steps = epochs * math.ceil(len(train_dl) / grad_accum)
    def lr_lambda(step):
        # Linear warmup, then cosine decay.
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    scaler = torch.cuda.amp.GradScaler(enabled=(device.startswith("cuda")))
    best_val = float("inf")

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0
        step = 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{epochs}")
        optimizer.zero_grad(set_to_none=True)
        for images, tokens in pbar:
            images = images.to(device, non_blocking=True)
            tokens = tokens.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=(device.startswith("cuda"))):
                image_features = model.encode_image(images)
                text_features = model.encode_text(tokens)
                loss = clip_loss(image_features, text_features, model.logit_scale)

            scaler.scale(loss / grad_accum).backward()
            step += 1
            running += loss.item()
            if step % grad_accum == 0:
                scaler.step(optimizer); scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

            pbar.set_postfix(loss=running / step, lr=optimizer.param_groups[0]["lr"])

        model.eval()
        with torch.no_grad():
            val_losses = []
            for images, tokens in tqdm(val_dl, leave=False, desc="Val"):
                images = images.to(device); tokens = tokens.to(device)
                with torch.cuda.amp.autocast(enabled=(device.startswith("cuda"))):
                    image_features = model.encode_image(images)
                    text_features = model.encode_text(tokens)
                    val_loss = clip_loss(image_features, text_features, model.logit_scale)
                val_losses.append(val_loss.item())
            val_mean = sum(val_losses) / len(val_losses)

        ckpt_path = os.path.join(output_dir, f"epoch{epoch}_val{val_mean:.4f}.pt")
        torch.save({"model": model.state_dict()}, ckpt_path)
        if val_mean < best_val:
            best_val = val_mean
            torch.save({"model": model.state_dict()}, os.path.join(output_dir, "best.pt"))

        print(f"Epoch {epoch} done. TrainLoss ~{running/step:.4f} ValLoss {val_mean:.4f}")

class FeedbackDataset(Dataset):
    def __init__(self, examples, processor=None):
        self.examples = examples
        self.processor = processor

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        image = ex["image"]
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert("RGB")
        return image, ex["text"], ex["label"]

def feedback(model, processor, device, data, epochs=5, batch_size=4, lr=1e-6):
    # NOTE: this loop expects a Hugging Face-style CLIP model/processor
    # (get_text_features / get_image_features), unlike the open_clip model
    # used elsewhere in this file; pass compatible objects in.
    dataset = FeedbackDataset(data, processor=processor)
    dataLoader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CosineEmbeddingLoss()
    model.load_state_dict(torch.load(SAVE_DIR, map_location=device)["model"])
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, texts, labels in dataLoader:
            inputs = processor(text=texts, images=images,
                               return_tensors="pt", padding=True).to(device)
            text_embeds = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])
            image_embeds = model.get_image_features(inputs["pixel_values"])
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

            labels = torch.tensor(labels, dtype=torch.float, device=device)
            loss = loss_fn(image_embeds, text_embeds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"{epoch+1}/{epochs} , Loss :{total_loss/len(dataLoader):.4f}")


def encode_img_and_text(imgs, text):
    image_feat = []
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', device=device, quick_gelu=True)
    checkpoint = torch.load(SAVE_DIR, map_location=device)
    model.load_state_dict(checkpoint["model"], strict=False)  # apply the fine-tuned weights
    model.to(device)
    model.eval()
    for img in imgs:
        if hasattr(img, 'read'):
            image = Image.open(img.stream).convert("RGB")
        else:
            if isinstance(img, dict) and 'preview' in img:
                img_data = img['preview'].split(",")[1]
                image = Image.open(BytesIO(base64.b64decode(img_data))).convert("RGB")
            else:
                raise ValueError("Unsupported image input")
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_feat.append(image_features)
    image_embedding = torch.stack(image_feat).mean(dim=0)
    text_tokens = tokenizer([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    alpha = 0.7
    combined = alpha * image_embedding + (1 - alpha) * text_features
    combined = combined / combined.norm(dim=-1, keepdim=True)
    return combined.squeeze(0).cpu().tolist()
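A hedged sketch of calling encode_img_and_text outside Flask. It assumes the fine-tuned checkpoint exists at model/clip/best.pt (importing src.training also loads the base CLIP weights); encode_img_and_text reads uploads via a .stream attribute, which a plain open file lacks, so a thin wrapper is used here. The image path is hypothetical.

# Hypothetical standalone call into the encoder.
from src.training import encode_img_and_text

class FileLike:
    """Mimic a werkzeug FileStorage just enough for encode_img_and_text."""
    def __init__(self, path):
        self.stream = open(path, "rb")
    def read(self, *a):               # satisfies the hasattr(img, 'read') check
        return self.stream.read(*a)

vec = encode_img_and_text([FileLike("wallet.jpg")], "brown leather wallet")
print(len(vec))  # 512-d fused image+text embedding, unit-normalised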