PrashantGoyal commited on
Commit
11e7313
·
0 Parent(s):

Deploy backend to Hugging Face

Browse files
Files changed (14) hide show
  1. .env +16 -0
  2. .gitignore +168 -0
  3. App/__init__.py +0 -0
  4. App/app.py +447 -0
  5. App/models.py +56 -0
  6. App/scheduler.py +195 -0
  7. Dockerfile +27 -0
  8. requirements.txt +19 -0
  9. requirements_scheduler.txt +13 -0
  10. setup.py +25 -0
  11. src/__init__.py +0 -0
  12. src/evaluation.py +59 -0
  13. src/preprocessing.py +29 -0
  14. src/training.py +206 -0
.env ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ DATABASE_URL=https://mrlyvrpxsumashqzcmhd.supabase.co
3
+ SUPABASE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Im1ybHl2cnB4c3VtYXNocXpjbWhkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjE1NzQ5ODAsImV4cCI6MjA3NzE1MDk4MH0.HoD5V3nXSGnLFSbjcqveBn7LUZZPS4KUTEuM3eoQ2uQ
4
+
5
+ JWT_SECRET_KEY=Hello_buddy
6
+ SECRET_KEY=hello_buddy
7
+ CLIENT=http://localhost:3000
8
+ Qdrant_api_key=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.9ZKSLbgV0_jyesC_fqhyuMHS2XRacAKa-jYeo01vCng
9
+ Qdrant_url=https://db450db1-2425-4e2e-b839-b1c585defee3.europe-west3-0.gcp.cloud.qdrant.io
10
+ Qdrant_Collection=FINDR
11
+
12
+ CLOUDINARY_API_SECRET=oaBBFHmbY7i6GNs8q_auVcwd5OM
13
+ CLOUDINARY_API_KEY=456896227428735
14
+ CLOUDINARY_CLIENT_NAME=dc728fl24
15
+ CLOUDINARY_URL=cloudinary://456896227428735:oaBBFHmbY7i6GNs8q_auVcwd5OM@dc728fl24
16
+ scheduler=https://findr-ai-scheduler.onrender.com
.gitignore ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ /get-pip.py
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
111
+ .pdm.toml
112
+ .pdm-python
113
+ .pdm-build/
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ Server/.env
127
+ Client/.env
128
+ .venv
129
+ env/
130
+ venv/
131
+ new_env/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+ /datasets
160
+ /.dockerignore
161
+ /start.sh
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
App/__init__.py ADDED
File without changes
App/app.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from flask import Flask,request,jsonify,make_response
3
+ from flask_bcrypt import Bcrypt
4
+ from functools import wraps
5
+ from supabase import create_client
6
+ from flask_jwt_extended import JWTManager, create_access_token,unset_jwt_cookies, jwt_required, get_jwt_identity,decode_token
7
+ from App.models import User,LostItem,FoundItem,Match
8
+ from dotenv import load_dotenv
9
+ from flask_cors import CORS
10
+ from src.training import encode_img_and_text
11
+ from qdrant_client import QdrantClient
12
+ from qdrant_client.http import models
13
+ import cloudinary
14
+ from cloudinary import uploader
15
+ import warnings
16
+ import base64
17
+ from io import BytesIO
18
+ import threading
19
+ from datetime import timedelta
20
+ warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
21
+
22
+
23
+ load_dotenv()
24
+
25
+ app = Flask(__name__)
26
+ app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")
27
+ app.config["SECRET_KEY"] = os.getenv("SECRET_KEY")
28
+ app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY")
29
+ app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
30
+ app.config["JWT_ACCESS_COOKIE_NAME"] = "Token"
31
+ app.config["JWT_COOKIE_SAMESITE"] = "None"
32
+ app.config["JWT_COOKIE_SECURE"] = False
33
+ app.config["JWT_COOKIE_DOMAIN"] = ".localhost"
34
+ app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(days=3)
35
+
36
+
37
+ qdrant=QdrantClient(
38
+ url=os.getenv("Qdrant_url"),
39
+ api_key=os.getenv("Qdrant_api_key"),
40
+ )
41
+
42
+ cloudinary.config(
43
+ cloud_name=os.getenv("CLOUDINARY_CLIENT_NAME"),
44
+ api_key=os.getenv("CLOUDINARY_API_KEY"),
45
+ api_secret=os.getenv("CLOUDINARY_API_SECRET"),
46
+ )
47
+
48
+
49
+ db = create_client(os.getenv("DATABASE_URL"),os.getenv("SUPABASE_KEY"))
50
+ bcrypt=Bcrypt(app)
51
+ jwt=JWTManager(app)
52
+ CORS(app, supports_credentials=True, origins=[os.getenv("CLIENT")])
53
+
54
+
55
+ print('Qdrant connected')
56
+ print("Posgres Connected")
57
+
58
+
59
+
60
def decode_jwt(fn):
    """Decorator: read the JWT from the "Token" cookie and pass the user id.

    The wrapped view receives the decoded subject as keyword argument
    ``user_id``. Requests without the cookie are rejected.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        token = request.cookies.get("Token")
        if not token:
            # BUG FIX: previously returned HTTP 200 with just a message,
            # so clients could not distinguish auth failure from success.
            return jsonify({
                "message": "No Token Found"
            }), 401
        # "sub" holds the user id set via create_access_token(identity=...).
        user_id = decode_token(token)["sub"]
        return fn(user_id=user_id, *args, **kwargs)
    return wrapper
74
+
75
class DotDict(dict):
    """Dictionary whose keys are also readable/writable as attributes.

    Reading a missing key through attribute access raises AttributeError
    (rather than KeyError) so it cooperates with getattr/hasattr.
    """

    def __getattr__(self, name):
        if name in self:
            return self[name]
        raise AttributeError(f"No such attribute: {name}")

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
84
+
85
+
86
+
87
+
88
+
89
def upload_img(img):
    """Upload one image to Cloudinary and return its HTTPS delivery URL."""
    result = uploader.upload(img, resource_type="image")
    return result["secure_url"]
91
+
92
@app.route('/register',methods=["POST"])
def register():
    """Create a user account, hash the password, and set the JWT cookie.

    Expects JSON body: firstName, lastName, email, phone, password.
    Returns 200 with a "Token" cookie on success, 400 on server error.
    """
    try:
        data = request.get_json()
        first_name = data['firstName']
        last_name = data['lastName']
        email = data['email']
        phone = data['phone']
        password = data['password']
        # Reject duplicate registrations by email.
        if db.table("users").select("*").eq("email", email).limit(1).execute().data:
            return jsonify({
                "success": False,
                "error": "User Already Exist"
            })
        # Store only the bcrypt hash, never the plaintext password.
        hashing = bcrypt.generate_password_hash(password).decode("utf-8")
        new_user = db.table("users").insert({
            "first_name": first_name,
            "last_name": last_name,
            "email": email,
            "phone": phone,
            "password": hashing,
        }).execute().data[0]

        token = create_access_token(identity=str(new_user.get("id")), expires_delta=timedelta(days=3))
        res = make_response({
            "success": True,
            "message": "User Registered Successfully",
        })
        # Cookie lifetime (259200 s) matches the 3-day token expiry.
        res.set_cookie("Token", token, httponly=True, secure=True, samesite="None", max_age=259200, domain=".localhost")
        return res, 200
    except Exception as e:
        print(e)
        # BUG FIX: response key was misspelled "sucsess".
        return jsonify({
            "success": False,
            "message": "Internal Server Error"
        }), 400
125
+
126
+ @app.route("/login",methods=["POST"])
127
+ def login():
128
+ try:
129
+ data=request.get_json()
130
+ print(data)
131
+ email=data['email']
132
+ password=data['password']
133
+ user= db.table("users").select("*").eq("email",email).limit(1).execute().data[0]
134
+ if not user or not bcrypt.check_password_hash(user.get("password"),password):
135
+ return jsonify({
136
+ "sucsess":False,
137
+ "message":"No User Found"
138
+ }),200
139
+ token=create_access_token(identity=str(user.get("id")), expires_delta=timedelta(days=3))
140
+ res=make_response({
141
+ "success":True,
142
+ "message":"User Login Successfully",
143
+
144
+ })
145
+ res.set_cookie("Token",token,httponly=True,secure=True,samesite="None",max_age=259200,domain=".localhost")
146
+ return res,200
147
+ except Exception as e:
148
+ print(e)
149
+ return jsonify({
150
+ "sucsess":False,
151
+ "message":"Internal Server Error"
152
+ }),400
153
+
154
+
155
@app.route('/logout',methods=["GET"])
def logout():
    """Sign the client out by clearing the JWT cookies."""
    response = make_response({
        "success": True,
        "message": "Logout Successfully",
    })
    unset_jwt_cookies(response)
    return response, 200
163
+
164
@app.route('/get_user',methods=['GET'])
@decode_jwt
def get_user(user_id):
    """Return the authenticated user's profile fields."""
    # BUG FIX: supabase-py returns plain dicts, but the original accessed
    # them as attributes (user.first_name -> AttributeError). It also had a
    # duplicated "user=user=" assignment and no empty-result guard.
    rows = db.table("users").select("*").eq("id", user_id).limit(1).execute().data
    if not rows:
        return jsonify({
            "success": False,
            "message": "No User Found"
        }), 404
    user = rows[0]
    return jsonify({
        "success": True,
        "user": {
            "first_name": user.get("first_name"),
            "last_name": user.get("last_name"),
            "email": user.get("email"),
            "phone_number": user.get("phone"),
        }
    }), 200
175
+
176
+
177
@app.route('/lostItem',methods=['POST'])
@decode_jwt
def lostItem(user_id):
    """Report a lost item: upload photos, persist the row, index its embedding."""
    print('search start')
    try:
        # Multipart form: one or more image files under 'item' plus text fields.
        imgs_data=request.files.getlist('item')
        img_urls=[upload_img(img) for img in imgs_data]
        description=request.form.get('description')
        print(imgs_data, description)
        # Joint image+text embedding used for similarity search.
        vector=encode_img_and_text(imgs_data,description)
        lastSeenLocation=request.form.get('lastSeenLocation')
        dateTimeLost=request.form.get('dateTimeLost')
        name=request.form.get('name')
        email=request.form.get('email')
        phone=request.form.get('phone')
        reward=request.form.get('reward')
        additionalNotes=request.form.get('additionalNotes')
        item=db.table("lostItem").insert({"user_id":user_id,"name":name,"email":email,"phone":phone,"description":description,"lastSeenLocation":lastSeenLocation,"dateTimeLost":dateTimeLost,"reward":reward,"additionalNotes":additionalNotes,"image_url": img_urls}).execute().data[0]
        print('db save', len(vector))

        # Lazily create the Qdrant collection on first use (512-dim, cosine).
        collections = qdrant.get_collections().collections
        existing_names = [c.name for c in collections]

        if "lost_items" not in existing_names:
            qdrant.create_collection(
                collection_name="lost_items",
                vectors_config=models.VectorParams(
                    size=512,
                    distance=models.Distance.COSINE
                )
            )

        # Point id mirrors the Supabase row id so the scheduler can join them.
        qdrant.upsert(
            collection_name="lost_items",
            points=[
                models.PointStruct(id=item.get("id"), vector=vector, payload={"description":description,"place_lost":lastSeenLocation,"status" : "active"})
            ],
        )
        print('vector save')
        return jsonify({
            "success":True
        }),200
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({
            "success":False
        }),400
225
+
226
+
227
@app.route('/foundItem',methods=['POST'])
@decode_jwt
def foundItem(user_id):
    """Report a found item: upload photos, persist the row, index its embedding."""
    print('search start')
    try:
        # Multipart form: one or more image files under 'item' plus text fields.
        imgs_data = request.files.getlist('item')
        img_urls = [upload_img(img) for img in imgs_data]
        description = request.form.get('description')
        print(imgs_data, description)
        # Joint image+text embedding used for similarity search.
        vector = encode_img_and_text(imgs_data, description)
        found_near = request.form.get('found_near')
        name = request.form.get('name')
        email = request.form.get('email')
        phone = request.form.get('phone')
        item = db.table("foundItem").insert({"user_id": user_id, "name": name, "email": email, "phone": phone, "description": description, "found_near": found_near, "image_url": img_urls}).execute().data[0]

        print('db save', len(vector))

        # Lazily create the Qdrant collection on first use (512-dim, cosine).
        collections = qdrant.get_collections().collections
        existing_names = [c.name for c in collections]

        if "found_items" not in existing_names:
            qdrant.create_collection(
                collection_name="found_items",
                vectors_config=models.VectorParams(
                    size=512,
                    distance=models.Distance.COSINE
                )
            )

        # Point id mirrors the Supabase row id so the scheduler can join them.
        qdrant.upsert(
            collection_name="found_items",
            points=[
                models.PointStruct(id=item.get("id"), vector=vector, payload={"description": description, "place_found": found_near, "status": "active"})
            ],
        )
        print('vector save')
        return jsonify({
            "success": True
        }), 200
    except Exception as e:
        import traceback
        traceback.print_exc()
        # BUG FIX: the failure branch previously returned HTTP 200; the
        # sibling /lostItem route correctly returns 400.
        return jsonify({
            "success": False
        }), 400
273
+
274
@app.route('/allLostItems',methods=['GET'])
@decode_jwt
def allLostItems(user_id):
    """Return every lost-item report created by the authenticated user."""
    rows = db.table("lostItem").select("*").eq("user_id", user_id).execute().data
    # BUG FIX: the original iterated `items[0]` — i.e. the KEYS of the first
    # row — producing garbage instead of one entry per row. Iterate the row
    # list itself; plain dict access also removes the DotDict indirection.
    output = []
    for item in rows:
        output.append({
            "id": item.get("id"),
            "name": item.get("name"),
            "email": item.get("email"),
            "phone": item.get("phone"),
            "description": item.get("description"),
            "lastSeenLocation": item.get("lastSeenLocation"),
            "dateTimeLost": item.get("dateTimeLost"),
            "reward": item.get("reward"),
            "additionalNotes": item.get("additionalNotes"),
            "image_url": item.get("image_url"),
            "status": item.get("status"),
            "created_at": item.get("created_at"),
        })
    return jsonify({
        "success": True,
        "lostItems": output
    }), 200
300
+
301
@app.route('/matchLost/<lost_id>',methods=['GET'])
def matchLost(lost_id):
    """Return the found items cross-referenced on one lost item."""
    # BUG FIX: unguarded .data[0] raised IndexError (-> 500) for unknown ids.
    rows = db.table("lostItem").select("*").eq("id", lost_id).limit(1).execute().data
    if not rows:
        return jsonify({
            "success": False,
            "message": "No Lost Item Found"
        }), 404
    item = rows[0]
    if not item.get("found_items"):
        return jsonify({
            "success": False,
            "message": "No Lost Item Found"
        }), 400
    found_items = []
    # found_items holds foundItem row ids (stored as strings).
    for fid in item.get("found_items"):
        matches = db.table("foundItem").select("*").eq("id", int(fid)).limit(1).execute().data
        if not matches:
            continue
        found_item = matches[0]
        found_items.append({
            "id": found_item.get("id"),
            "name": found_item.get("name"),
            "email": found_item.get("email"),
            "phone": found_item.get("phone"),
            "description": found_item.get("description"),
            "found_near": found_item.get("found_near"),
            "image_url": found_item.get("image_url"),
            "status": found_item.get("status"),
            "created_at": found_item.get("created_at"),
        })
    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200
329
+
330
+
331
+
332
@app.route('/matchFound/<lost_id>',methods=['GET'])
def matchFound(lost_id):
    """Return the lost items cross-referenced on one found item."""
    # BUG FIX: unguarded .data[0] raised IndexError (-> 500) for unknown ids.
    rows = db.table("foundItem").select("*").eq("id", lost_id).limit(1).execute().data
    if not rows:
        return jsonify({
            "success": False,
            "message": "No Found Item Found"
        }), 404
    item = rows[0]
    if not item.get("lost_items"):
        # Status aligned with the sibling /matchLost route (was 200).
        return jsonify({
            "success": False,
            "message": "No Found Item Found"
        }), 400
    found_items = []
    # lost_items holds lostItem row ids (stored as strings).
    for lid in item.get("lost_items"):
        matches = db.table("lostItem").select("*").eq("id", int(lid)).limit(1).execute().data
        if not matches:
            continue
        found_item = matches[0]
        found_items.append({
            "id": found_item.get("id"),
            "name": found_item.get("name"),
            "email": found_item.get("email"),
            "phone": found_item.get("phone"),
            "description": found_item.get("description"),
            # Lost items have no found_near; reuse the last-seen location.
            "found_near": found_item.get("lastSeenLocation"),
            "image_url": found_item.get("image_url"),
            "status": found_item.get("status"),
            "created_at": found_item.get("created_at"),
        })
    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200
361
+
362
+
363
@app.route('/lostMatchDetail/<lost_id>',methods=['GET'])
def lostMatchDetail(lost_id):
    """Return full details of one lost item (for the match-detail view)."""
    # BUG FIX: the original indexed .data[0] before checking for emptiness,
    # so the "not found" branch was unreachable and unknown ids returned 500.
    rows = db.table("lostItem").select("*").eq("id", lost_id).limit(1).execute().data
    if not rows:
        return jsonify({
            "success": False,
            "message": "No Found Item Found"
        }), 400
    found_item = rows[0]
    found_items = [{
        "id": found_item.get("id"),
        "name": found_item.get("name"),
        "email": found_item.get("email"),
        "phone": found_item.get("phone"),
        "description": found_item.get("description"),
        # Lost items have no found_near; reuse the last-seen location.
        "found_near": found_item.get("lastSeenLocation"),
        "image_url": found_item.get("image_url"),
        "status": found_item.get("status"),
        "date_lost": found_item.get("dateTimeLost"),
        "reward": found_item.get("reward"),
        "additional_notes": found_item.get("additionalNotes"),
        "created_at": found_item.get("created_at"),
    }]
    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200
393
+
394
+
395
@app.route('/allFoundItems',methods=['GET'])
@decode_jwt
def allFoundItems(user_id):
    """Return every found-item report created by the authenticated user."""
    # BUG FIX: the original iterated `.data[0]` — the KEYS of the first row —
    # and crashed outright when the user had no reports. Iterate the rows.
    rows = db.table("foundItem").select("*").eq("user_id", user_id).execute().data
    output = []
    for item in rows:
        output.append({
            "id": item.get("id"),
            "name": item.get("name"),
            "email": item.get("email"),
            "phone": item.get("phone"),
            "description": item.get("description"),
            "found_near": item.get("found_near"),
            "image_url": item.get("image_url"),
            "status": item.get("status"),
            "created_at": item.get("created_at"),
        })
    return jsonify({
        "success": True,
        "foundItems": output
    }), 200
416
+
417
@app.route('/foundMatchDetail/<lost_id>',methods=['GET'])
def foundMatchDetail(lost_id):
    """Return full details of one found item (for the match-detail view)."""
    # BUG FIX: the original indexed .data[0] before checking for emptiness,
    # so the "not found" branch was unreachable and unknown ids returned 500.
    rows = db.table("foundItem").select("*").eq("id", lost_id).limit(1).execute().data
    if not rows:
        return jsonify({
            "success": False,
            "message": "No Found Item Found"
        }), 400
    found_item = rows[0]
    found_items = [{
        "id": found_item.get("id"),
        "name": found_item.get("name"),
        "email": found_item.get("email"),
        "phone": found_item.get("phone"),
        "description": found_item.get("description"),
        "found_near": found_item.get("found_near"),
        "image_url": found_item.get("image_url"),
        "status": found_item.get("status"),
        "created_at": found_item.get("created_at"),
    }]
    return jsonify({
        "success": True,
        "foundItems": found_items
    }), 200
445
+
446
+ if __name__ == "__main__":
447
+ app.run(debug=True,port=8000)
App/models.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask_sqlalchemy import SQLAlchemy
2
+ import uuid
3
+ from sqlalchemy.dialects.postgresql import JSON,ARRAY
4
+ from sqlalchemy.ext.mutable import MutableList
5
+
6
+ db = SQLAlchemy()
7
+
8
class User(db.Model):
    """Registered account; referenced by lostItem/foundItem rows via user_id."""
    __tablename__ = "users"

    id = db.Column(db.Integer, primary_key=True)
    first_name = db.Column(db.String(50), nullable=False)
    last_name = db.Column(db.String(50), nullable=False)
    # Login identifier — must be unique.
    email = db.Column(db.String(100), unique=True, nullable=False)
    phone=db.Column(db.String(15), nullable=False)
    # bcrypt hash, never the plaintext password.
    password = db.Column(db.String(200), nullable=False)
+
18
+
19
+ class LostItem(db.Model):
20
+ __tablename__="lostItem"
21
+ id = db.Column(db.Integer, primary_key=True, index=True)
22
+ user_id = db.Column(db.Integer, db.ForeignKey("users.id"),nullable=False)
23
+ name=db.Column(db.String,nullable=False)
24
+ email=db.Column(db.String,nullable=False)
25
+ phone=db.Column(db.String,nullable=False)
26
+ description = db.Column(db.Text, nullable=False)
27
+ lastSeenLocation = db.Column(db.Text, nullable=False)
28
+ dateTimeLost = db.Column(db.Text, nullable=False)
29
+ reward = db.Column(db.Text)
30
+ additionalNotes = db.Column(db.Text)
31
+ image_url = db.Column(JSON, nullable=False)
32
+ status=db.Column(db.String,nullable=False,default='active')
33
+ found_items=db.Column(MutableList.as_mutable(ARRAY(db.String)),nullable=False)
34
+ created_at = db.Column(db.TIMESTAMP, server_default=db.func.now(),nullable=False)
35
+
36
class FoundItem(db.Model):
    """A user's report of a found item, plus ids of matched lost items."""
    __tablename__="foundItem"
    id = db.Column(db.Integer, primary_key=True, index=True)
    user_id = db.Column(db.Integer, db.ForeignKey("users.id"),nullable=False)
    # Contact details captured on the report form (may differ from the account).
    name=db.Column(db.String,nullable=False)
    email=db.Column(db.String,nullable=False)
    phone=db.Column(db.String,nullable=False)
    description = db.Column(db.Text, nullable=False)
    found_near= db.Column(db.Text, nullable=False)
    # List of Cloudinary URLs for the uploaded photos.
    image_url = db.Column(JSON, nullable=False)
    status=db.Column(db.String,nullable=False,default='active')
    # Ids of matched lostItem rows (as strings); maintained by the scheduler.
    lost_items=db.Column(MutableList.as_mutable(ARRAY(db.String)),nullable=False)
    created_at = db.Column(db.TIMESTAMP, server_default=db.func.now(),nullable=False)
49
+
50
class Match(db.Model):
    """Scored pairing of a lost item and a found item, created by the scheduler."""
    __tablename__="matches"
    id = db.Column(db.Integer, primary_key=True, index=True)
    lost_item_id = db.Column(db.Integer, db.ForeignKey("lostItem.id"),nullable=False)
    found_item_id = db.Column(db.Integer, db.ForeignKey("foundItem.id"),nullable=False)
    # Cosine similarity from the Qdrant search that produced the match.
    confidence_score=db.Column(db.Float,nullable=False)
    created_at = db.Column(db.TIMESTAMP, server_default=db.func.now(),nullable=False)
App/scheduler.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from flask import Flask,request,jsonify,make_response
3
+ from flask_bcrypt import Bcrypt
4
+ from functools import wraps
5
+ from flask_jwt_extended import JWTManager, create_access_token,unset_jwt_cookies, jwt_required, get_jwt_identity,decode_token
6
+ from dotenv import load_dotenv
7
+ from flask_cors import CORS
8
+ from qdrant_client import QdrantClient
9
+ from qdrant_client.http import models
10
+ from datetime import datetime, timedelta
11
+ from qdrant_client.models import PayloadSchemaType
12
+ from supabase import create_client
13
+
14
+ load_dotenv()
15
+
16
+ app = Flask(__name__)
17
+ app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")
18
+
19
+ db = create_client(os.getenv("DATABASE_URL"),os.getenv("SUPABASE_KEY"))
20
+
21
+
22
+ qdrant=QdrantClient(
23
+ url=os.getenv("Qdrant_url"),
24
+ api_key=os.getenv("Qdrant_api_key"),
25
+ )
26
+ qdrant.create_payload_index(
27
+ collection_name="found_items",
28
+ field_name="status",
29
+ field_schema=PayloadSchemaType.KEYWORD
30
+ )
31
+
32
+ DEFAULT_MIN_AGE_MINUTES = 60
33
+ DEFAULT_CONFIDENCE = 0.78
34
+ DEFAULT_TOP_K = 5
35
+
36
+ CORS(app, supports_credentials=True, origins=[os.getenv("CLIENT")])
37
+
38
+ print('Qdrant connected')
39
+ print("Posgres Connected")
40
+
41
+
42
def iso_now_minus(minutes):
    """Return the UTC time `minutes` ago as an ISO-8601 string."""
    moment = datetime.utcnow() - timedelta(minutes=minutes)
    return moment.isoformat()
44
+
45
+
46
+ @app.route("/admin/match-active-lost", methods=["POST"])
47
+ def match_active_lost():
48
+ body = request.get_json(silent=True) or {}
49
+ min_age_minutes = int(body.get("min_age_minutes", DEFAULT_MIN_AGE_MINUTES))
50
+ confidence_threshold = float(body.get("confidence_threshold", DEFAULT_CONFIDENCE))
51
+ top_k = int(body.get("top_k", DEFAULT_TOP_K))
52
+
53
+ cutoff = iso_now_minus(min_age_minutes)
54
+
55
+ lost_rows = (db.table("lostItem").select("*").eq("status","active").execute())
56
+ lost_items = [
57
+ {
58
+ "id": r.get("id"),
59
+ "user_id": r.get("user_id"),
60
+ "created_at": r.get("created_at")
61
+ }
62
+
63
+
64
+ for r in lost_rows.data
65
+ ]
66
+ if not lost_items:
67
+ return jsonify({"message": "No eligible lost items found", "checked": 0}), 200
68
+
69
+ created_matches = []
70
+ errors = []
71
+
72
+ for lost in lost_items:
73
+ lost_id=int(lost["id"])
74
+ lost_user_id=str(lost["user_id"])
75
+ try:
76
+ point=qdrant.retrieve(collection_name="lost_items", ids=[lost_id], with_vectors=True)
77
+ if not point:
78
+ errors.append({"lost_id": lost_id, "error": "No vector found in Qdrant for this lost item"})
79
+ continue
80
+ lost_vector = point[0].vector
81
+ except Exception as e:
82
+ errors.append({"lost_id": lost_id, "error": f"Qdrant get_point failed: {str(e)}"})
83
+ continue
84
+
85
+ qfilter = models.Filter(
86
+ must=[models.FieldCondition(key="status", match=models.MatchValue(value="active"))]
87
+ )
88
+
89
+ try:
90
+ results = qdrant.search(
91
+ collection_name="found_items",
92
+ query_vector=lost_vector,
93
+ limit=top_k,
94
+ query_filter=qfilter
95
+ )
96
+ except Exception as e:
97
+ errors.append({"lost_id": lost_id, "error": f"Qdrant search failed: {str(e)}"})
98
+ continue
99
+
100
+ for r in results:
101
+ score = float(r.score) if r.score is not None else None
102
+ if score is None:
103
+ continue
104
+ if score < confidence_threshold:
105
+ continue
106
+
107
+ found_point_payload = r.payload or {}
108
+ found_supabase_id = found_point_payload.get("id") or r.id
109
+
110
+ if lost_user_id == str(found_point_payload.get("user_id")):
111
+ continue
112
+
113
+ try:
114
+ existing_match=db.table("matches").select("lost_item_id","found_item_id").eq("found_item_id", found_supabase_id).eq("lost_item_id", lost_id).limit(1).execute()
115
+ if existing_match:
116
+ continue
117
+ except Exception as e:
118
+ errors.append({"lost_id": lost_id, "error": f"Database query failed: {str(e)}"})
119
+ continue
120
+
121
+
122
+ try:
123
+ match_record = db.table("matches").insert({
124
+ "lost_item_id": lost_id,
125
+ "found_item_id": found_supabase_id,
126
+ "confidence_score": score
127
+ }).execute()
128
+ created_matches.append({"lost_id": lost_id, "found_id": found_supabase_id, "similarity": score})
129
+ except Exception as e:
130
+ errors.append({"lost_id": lost_id, "found_id": found_supabase_id, "error": f"Supabase insert failed: {str(e)}"})
131
+
132
+ matches=(db.table("matches").select("id","lost_item_id","found_item_id","confidence_score").execute())
133
+
134
+ match_items = [
135
+ {
136
+ "id": r.get("id"),
137
+ "lost_id": int(r.get("lost_item_id")),
138
+ "found_id": int(r.get("found_item_id")),
139
+ "confidence_score": r.get("confidence_score")
140
+ }
141
+ for r in matches.data
142
+ ]
143
+
144
+ for match in match_items:
145
+ already_exist_found=(db.table("lostItem").select("*").eq("user_id", match["lost_id"]).contains("found_items", [str(match["found_id"])]).limit(1).execute()
146
+ )
147
+ if already_exist_found:
148
+ continue
149
+ else:
150
+ res=(db.table("lostItem").select("*").eq("id", match["lost_id"]).limit(1).execute()
151
+ )
152
+ if not res.data:
153
+ found_item = None
154
+ else:
155
+ found_item = res.data[0]
156
+
157
+ current_lost_items = found_item.get("found_items") or []
158
+
159
+ if match["found_id"] not in current_lost_items:
160
+ current_lost_items.append(match["found_id"])
161
+
162
+ db.table("lostItem").update({"found_items": current_lost_items }).eq("id", match["lost_id"]).execute()
163
+
164
+ already_exist_found=( db.table("foundItem").select("*").eq("user_id", match["found_id"]).contains("lost_items", [str(match["lost_id"])]).limit(1).execute()
165
+ )
166
+ if already_exist_found:
167
+ continue
168
+
169
+ else:
170
+ res=(db.table("foundItem").select("lost_items").eq("id", match["found_id"]).limit(1).execute()
171
+ )
172
+ if not res.data:
173
+ found_item = None
174
+ else:
175
+ found_item = res.data[0]
176
+
177
+ current_lost_items = found_item.get("lost_items") or []
178
+
179
+ if match["lost_id"] not in current_lost_items:
180
+ current_lost_items.append(match["lost_id"])
181
+
182
+ db.table("foundItem").update({"lost_items": current_lost_items }).eq("id", match["found_id"]).execute()
183
+
184
+
185
+ return jsonify({
186
+ "checked_lost_count": len(lost_rows.data),
187
+ "created_matches_count": len(created_matches),
188
+ "created_matches": created_matches,
189
+ "errors": errors
190
+ }), 200
191
+
192
+
193
+
194
+ if __name__ == "__main__":
195
+ app.run(host="0.0.0.0",port=5000)
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /Server
4
+
5
+ # 🔴 REQUIRED system libraries for pyarrow / datasets
6
+ RUN apt-get update && apt-get install -y \
7
+ build-essential \
8
+ gcc \
9
+ g++ \
10
+ cmake \
11
+ curl \
12
+ libglib2.0-0 \
13
+ libsm6 \
14
+ libxext6 \
15
+ libxrender-dev \
16
+ && rm -rf /var/lib/apt/lists/*
17
+
18
+ COPY requirements.txt .
19
+
20
+ # 🔴 Upgrade tooling and install deps
21
+ RUN python -m pip install --upgrade pip setuptools wheel \
22
+ && python -m pip install -r requirements.txt --no-cache-dir
23
+
24
+ COPY . .
25
+
26
+ CMD ["python", "-m", "App.app"]
27
+
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ pillow
5
+ tqdm
6
+ transformers
7
+ huggingface_hub
8
+ flask
9
+ flask-cors
10
+ flask-sqlalchemy
11
+ flask-bcrypt
12
+ flask-jwt-extended
13
+ psycopg2-binary
14
+ python-dotenv
15
+ cloudinary
16
+ qdrant-client
17
+ datasets
18
+ open-clip-torch
19
+ supabase
requirements_scheduler.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ flask-sqlalchemy
4
+ flask-bcrypt
5
+ flask-jwt-extended
6
+ psycopg2-binary
7
+ python-dotenv
8
+ cloudinary
9
+ qdrant-client
10
+ numpy
11
+ pillow
12
+ tqdm
13
+ supabase
setup.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup,find_packages
2
+ from typing import List
3
+
4
# Marker used in requirements.txt for an editable self-install; it must not
# be passed to install_requires.
HYPHEN_E_DOT = '-e .'

def get_requiremnets(filename: str) -> List[str]:
    """Read *filename* and return its package requirements.

    Lines are stripped of surrounding whitespace; blank lines, ``#`` comments
    and the editable-install marker (``-e .``) are dropped so the result is
    safe to feed to ``install_requires``.  The previous implementation only
    removed the trailing newline, which left empty strings and stray
    whitespace in the list.
    """
    with open(filename, 'r') as f:
        requirements = [line.strip() for line in f]
    return [
        req for req in requirements
        if req and not req.startswith('#') and req != HYPHEN_E_DOT
    ]
16
+
17
setup(
    name="FINDR",
    version="0.0.1",
    packages=find_packages(),
    author="Prashant",
    author_email="prashant.goyal2002@gmail.com",
    # Fixed: the description was copy-pasted from an unrelated
    # "student performance" project.
    description="FINDR backend: CLIP-based lost-and-found item matching service",
    install_requires=get_requiremnets('requirements.txt'),
)
src/__init__.py ADDED
File without changes
src/evaluation.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,json
2
+ from training import clip_dataset
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ import open_clip
6
+ from datasets import load_dataset
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+ import warnings
11
+ warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
12
+
13
def collate(batch):
    """Stack a list of (image, token) pairs into two batched tensors."""
    images, tokens = zip(*batch)
    return torch.stack(images, 0), torch.stack(tokens, 0)
16
+
17
@torch.no_grad
def encode_img(model, processor, tokenizer, split, device):
    """Encode every (image, caption) pair of *split* with the given CLIP model.

    Returns two CPU tensors: the L2-normalised image features and text
    features, row-aligned (row i of each comes from the same dataset item).
    """
    ds = clip_dataset(split=split, processor=processor, tokenizer=tokenizer)
    print('dataset Loaded')
    dl = DataLoader(ds, batch_size=4, shuffle=False, num_workers=4, collate_fn=collate)
    all_img, all_text = [], []
    for img, text in tqdm(dl, desc=f"Encode {split}"):
        img = img.to(device)
        text = text.to(device)
        img_f = model.encode_image(img)
        text_f = model.encode_text(text)
        # Normalise so that dot products between rows equal cosine similarity.
        img_f = img_f / img_f.norm(keepdim=True, dim=-1)
        text_f = text_f / text_f.norm(keepdim=True, dim=-1)
        # Move to CPU per batch to keep GPU memory flat.
        all_img.append(img_f.cpu())
        all_text.append(text_f.cpu())
    return torch.cat(all_img), torch.cat(all_text)
33
+
34
+
35
def gold_k(sims, k):
    """Recall@k for a square similarity matrix.

    Row i's ground-truth match is column i; returns the fraction of rows
    whose diagonal column ranks among the k most similar columns.
    """
    n = sims.shape[0]
    top_k = (-sims).argsort(axis=1)[:, :k]
    gold = np.arange(n)[:, None]
    return (top_k == gold).any(axis=1).mean()
39
+
40
+
41
def main(path='./model/clip/best.pt', arch='ViT-B-32', pretrained='openai'):
    """Evaluate a fine-tuned CLIP checkpoint with image->text retrieval.

    Loads the checkpoint at *path* (saved by train() as {"model": state_dict}),
    encodes the test split, and prints Recall@1/5/10.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model, _, preprocess = open_clip.create_model_and_transforms(
        arch, pretrained=pretrained, device=device, quick_gelu=True
    )
    tokenizer = open_clip.get_tokenizer(arch)
    # Fix: map the checkpoint onto the active device instead of hard-coding
    # 'cuda', which crashed on CPU-only machines.
    state = torch.load(path, map_location=device)['model']
    # strict=False tolerates minor key differences between checkpoints.
    model.load_state_dict(state, strict=False)
    model.eval()
    print('model loaded')
    img_f, text_f = encode_img(model, processor=preprocess, tokenizer=tokenizer,
                               split='test', device=device)
    sim = (img_f @ text_f.T).numpy()
    g1 = gold_k(sim, 1)
    g5 = gold_k(sim, 5)
    g10 = gold_k(sim, 10)
    print(f"Image->Text R@1={g1:.3f} R@5={g5:.3f} R@10={g10:.3f}")

if __name__ == "__main__":
    main()
59
+
src/preprocessing.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import datasets
3
+ from huggingface_hub import hf_hub_download
4
+ from PIL import Image
5
+ import torch
6
+ import requests
7
+ import os
8
+
9
class Preprocessing():
    """Loads the COCO-Caption dataset and yields (image, caption) pairs."""

    def __init__(self, cache_dir=None):
        """*cache_dir* overrides the `datasets` cache location.

        Fix: a Windows-only absolute path ("D:/Java Projects/...") was
        hard-coded before, breaking every other machine.  None falls back to
        the HF_DATASETS_CACHE_DIR env var, then the library default.
        """
        self.cache_dir = cache_dir or os.environ.get("HF_DATASETS_CACHE_DIR")

    def load_dataset(self, split):
        """Load *split* and drop rows missing an image, question_id or caption."""
        # Very generous timeout: the dataset is large and mirrors can be slow.
        os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "10500"
        dataset = load_dataset("lmms-lab/COCO-Caption", split=split,
                               cache_dir=self.cache_dir)
        ds = dataset.filter(lambda x: x['image'] is not None
                            and x['question_id'] is not None
                            and len(x['answer']) > 0)
        return ds

    def image_caption_pairs(self, ds):
        """Yield (RGB PIL image, one randomly chosen stripped caption) per row.

        The per-row debug print has been removed — it flooded stdout for
        large splits.
        """
        import random
        for data in ds:
            img: Image.Image = data['image'].convert('RGB')
            cap = random.choice(data['answer']).strip()
            yield img, cap

if __name__ == "__main__":
    obj = Preprocessing()
    obj.load_dataset('val')
src/training.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,random,math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from tqdm import tqdm
6
+ import open_clip
7
+ from datasets import load_dataset
8
+ from PIL import Image
9
+ from src.preprocessing import Preprocessing
10
+ from torch.utils.data import DataLoader,Dataset
11
+ import warnings
12
+ import base64
13
+ from io import BytesIO
14
+ warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
15
+
16
# --- Module-level shared state used by train(), feedback() and inference ---
device='cuda' if torch.cuda.is_available() else 'cpu'
# Frees cached GPU memory; presumably a no-op when CUDA was never
# initialised — confirm on CPU-only hosts.
torch.cuda.empty_cache()
# NOTE(review): the CLIP model is built at import time, so merely importing
# this module downloads/loads weights — confirm that is intended.
model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device )
# Path where train() writes the best checkpoint, as {"model": state_dict}.
SAVE_DIR='model/clip/best.pt'
tokenizer=open_clip.get_tokenizer('ViT-B-32')
21
+
22
def seed_everything(seed=42):
    """Seed Python's and PyTorch's RNGs for reproducible runs."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
24
+
25
class clip_dataset(torch.utils.data.Dataset):
    """Dataset of (preprocessed image, tokenized caption) pairs drawn from
    the COCO-Caption split loaded via Preprocessing."""

    def __init__(self,split='val',processor=None,tokenizer=None):
        # processor: open_clip image transform; tokenizer: open_clip tokenizer.
        preprocessor=Preprocessing()
        self.ds=preprocessor.load_dataset(split=split)
        self.tokenizer=tokenizer
        self.processor=processor

    def __len__(self):
        return len(self.ds)

    def __getitem__(self,index):
        data=self.ds[index]
        img:Image.Image=data['image'].convert('RGB')
        # A caption is sampled per access, so repeated epochs see varied captions.
        text=random.choice(data['answer']).strip()
        # Fall back to the raw PIL image when no processor was supplied.
        image=self.processor(img) if self.processor else img
        token_text=self.tokenizer([text])[0]
        return image,token_text
43
+
44
def clip_loss(image_features, text_features, temperature):
    """Symmetric InfoNCE loss used by CLIP.

    *temperature* is the model's log logit scale; logits are scaled by
    exp(temperature).  Cross-entropy is averaged over both retrieval
    directions (image->text and text->image), with the matching pair on
    the diagonal as the target class.
    """
    img = image_features / image_features.norm(dim=-1, keepdim=True)
    txt = text_features / text_features.norm(dim=-1, keepdim=True)
    logits_i2t = (img @ txt.t()) * torch.exp(temperature)
    logits_t2i = logits_i2t.t()
    targets = torch.arange(img.size(0), device=img.device)
    ce = nn.CrossEntropyLoss()
    return (ce(logits_i2t, targets) + ce(logits_t2i, targets)) / 2
53
+
54
def collate(batch):
    """Batch a list of (image, token) pairs into two stacked tensors."""
    columns = list(zip(*batch))
    return torch.stack(columns[0], 0), torch.stack(columns[1], 0)
59
+
60
def train(arch='ViT-B-32',pretrained='openai',batchSize=2,epochs=5,lr=5e-5,warmup_steps=200,grad_accum=1,output_dir='model/clip'):
    """Fine-tune the module-level CLIP `model` with the symmetric CLIP loss.

    Uses mixed precision (on CUDA), gradient accumulation (`grad_accum`
    micro-batches per optimizer step) and a linear-warmup + cosine-decay LR
    schedule.  Saves a checkpoint per epoch and the best-by-val-loss model
    to `output_dir/best.pt`, each as {"model": state_dict}.

    NOTE(review): trains on the 'val' split and validates on 'test' —
    presumably because the full train split is too large; confirm.
    """
    seed_everything(42)
    torch.cuda.empty_cache()
    os.makedirs(output_dir,exist_ok=True)

    tokenizer=open_clip.get_tokenizer(arch)

    train_ds=clip_dataset(split='val',processor=preprocess,tokenizer=tokenizer)
    val_ds=clip_dataset(split='test',processor=preprocess,tokenizer=tokenizer)

    train_dl = DataLoader(train_ds, batch_size=batchSize, shuffle=True, num_workers=4, collate_fn=collate, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=batchSize, shuffle=False, num_workers=4, collate_fn=collate, pin_memory=True)

    # One scheduler step per optimizer step, not per micro-batch.
    total_steps = epochs * math.ceil(len(train_dl) / grad_accum)
    def lr_lambda(step):
        # Linear warmup to 1.0 over warmup_steps, then cosine decay to 0.
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))


    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # AMP is only enabled when actually running on CUDA.
    scaler = torch.cuda.amp.GradScaler(enabled=(device.startswith("cuda")))
    best_val = float("inf")

    for epoch in range(1,epochs+1):
        model.train()
        running = 0.0
        step = 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{epochs}")
        optimizer.zero_grad(set_to_none=True)
        for images, tokens in pbar:
            images = images.to(device, non_blocking=True)
            tokens = tokens.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=(device.startswith("cuda"))):
                image_features = model.encode_image(images)
                text_features = model.encode_text(tokens)
                # model.logit_scale is the log temperature; clip_loss exponentiates it.
                loss = clip_loss(image_features, text_features, model.logit_scale)

            # Scale down so grad_accum micro-batches sum to one full-batch gradient.
            scaler.scale(loss / grad_accum).backward()
            step += 1
            running += loss.item()
            if step % grad_accum == 0:
                scaler.step(optimizer); scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

            pbar.set_postfix(loss=running / step, lr=optimizer.param_groups[0]["lr"])


        # ---- validation at the end of every epoch ----
        model.eval()
        with torch.no_grad():
            val_losses = []
            for images, tokens in tqdm(val_dl, leave=False, desc="Val"):
                images = images.to(device); tokens = tokens.to(device)
                with torch.cuda.amp.autocast(enabled=(device.startswith("cuda"))):
                    image_features = model.encode_image(images)
                    text_features = model.encode_text(tokens)
                    val_loss = clip_loss(image_features, text_features, model.logit_scale)
                val_losses.append(val_loss.item())
            val_mean = sum(val_losses)/len(val_losses)

        # Always save a per-epoch checkpoint; additionally track the best one.
        ckpt_path = os.path.join(output_dir, f"epoch{epoch}_val{val_mean:.4f}.pt")
        torch.save({"model": model.state_dict()}, ckpt_path)
        if val_mean < best_val:
            best_val = val_mean
            torch.save({"model": model.state_dict()}, os.path.join(output_dir, "best.pt"))

        print(f"Epoch {epoch} done. TrainLoss ~{running/step:.4f} ValLoss {val_mean:.4f}")
132
+
133
class FeedbackDataset(Dataset):
    """Wraps a list of feedback examples.

    Each example is a dict with keys 'image' (PIL image or openable
    path/file), 'text' and 'label'.  The processor is stored but not
    applied here; feedback() runs it over the raw images itself.
    """

    def __init__(self, examples, processor=None):
        self.examples = examples
        self.processor = processor

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        entry = self.examples[idx]
        img = entry["image"]
        # Accept either an already-open PIL image or something Image.open
        # understands (path / file-like object).
        if not isinstance(img, Image.Image):
            img = Image.open(img).convert("RGB")
        return img, entry["text"], entry["label"]
147
+
148
def feedback(model, processor, device, data, epochs=5, batch_size=4, lr=1e-6):
    """Fine-tune *model* on user feedback triples (image, text, label).

    *data* is a list of {'image', 'text', 'label'} dicts; label is +1 for a
    confirmed match, -1 for a rejected one (CosineEmbeddingLoss semantics).

    NOTE(review): processor(text=..., images=...) and get_text_features /
    get_image_features are the Hugging Face CLIP API, while the rest of this
    module uses open_clip — confirm which model object callers pass in.
    """
    dataset = FeedbackDataset(data, processor=processor)
    dataLoader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CosineEmbeddingLoss()
    # Fix: train() saves checkpoints as {"model": state_dict}; the raw file
    # object was previously fed to load_state_dict, which fails on the
    # wrapper key.  Unwrap when present, accept bare state dicts otherwise.
    checkpoint = torch.load(SAVE_DIR, map_location=device)
    state = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint
    model.load_state_dict(state)
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, texts, labels in dataLoader:
            inputs = processor(text=texts, images=images,
                               return_tensors="pt", padding=True).to(device)
            text_embeds = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])
            image_embeds = model.get_image_features(inputs["pixel_values"])
            # Normalise so CosineEmbeddingLoss sees unit vectors.
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

            # as_tensor avoids the copy warning when labels is already a tensor
            # (the default collate produces tensors for numeric labels).
            labels = torch.as_tensor(labels, dtype=torch.float).to(device)
            loss = loss_fn(image_embeds, text_embeds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"{epoch+1}/{epochs} , Loss :{total_loss/len(dataLoader):.4f}")
175
+
176
+
177
+
178
def encode_img_and_text(imgs, text):
    """Encode uploaded images plus a text description into one embedding.

    *imgs* is an iterable of either Werkzeug-style uploads (objects with a
    .read/.stream) or dicts carrying a base64 data URL under 'preview'.
    Returns a list[float]: the L2-normalised blend
    0.7 * mean(image embeddings) + 0.3 * text embedding.
    """
    clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
        'ViT-B-32', pretrained='openai', device=device, quick_gelu=True
    )
    checkpoint = torch.load(SAVE_DIR, map_location=device)
    # Fix: the fine-tuned checkpoint was loaded from disk but never applied,
    # so inference silently ran on the stock pretrained weights.
    state = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint
    clip_model.load_state_dict(state, strict=False)
    clip_model.to(device)
    clip_model.eval()

    image_feat = []
    for img in imgs:
        if hasattr(img, 'read'):
            # Werkzeug FileStorage-style upload.
            image = Image.open(img.stream).convert("RGB")
        elif isinstance(img, dict) and 'preview' in img:
            # Base64 data URL: "data:image/...;base64,<payload>".
            payload = img['preview'].split(",")[1]
            image = Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")
        else:
            raise ValueError("Unsupported image input")
        image_input = clip_preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            feats = clip_model.encode_image(image_input)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        image_feat.append(feats)

    # Average the per-image embeddings into a single (1, D) vector.
    image_embedding = torch.stack(image_feat).mean(dim=0)
    text_tokens = tokenizer([text]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    alpha = 0.7  # weight of the visual signal vs. the text description
    combined = alpha * image_embedding + (1 - alpha) * text_features
    combined = combined / combined.norm(dim=-1, keepdim=True)
    return combined.squeeze(0).cpu().tolist()
206
+