PrashantGoyal committed on
Commit
e608f06
·
1 Parent(s): 7357b32

Fix model loading: download the CLIP checkpoint from the Hugging Face Hub (hf_hub_download) instead of reading a local model/clip/best.pt path; drop the dotenv dependency

Browse files
Files changed (4) hide show
  1. App/app.py +0 -2
  2. allfiles.txt +32 -0
  3. src/evaluation.py +11 -1
  4. src/training.py +12 -3
App/app.py CHANGED
@@ -5,7 +5,6 @@ from functools import wraps
5
  from supabase import create_client
6
  from flask_jwt_extended import JWTManager, create_access_token,unset_jwt_cookies, jwt_required, get_jwt_identity,decode_token
7
  from App.models import User,LostItem,FoundItem,Match
8
- from dotenv import load_dotenv
9
  from flask_cors import CORS
10
  from src.training import encode_img_and_text
11
  from qdrant_client import QdrantClient
@@ -20,7 +19,6 @@ from datetime import timedelta
20
  warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
21
 
22
 
23
- load_dotenv()
24
 
25
  app = Flask(__name__)
26
  app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")
 
5
  from supabase import create_client
6
  from flask_jwt_extended import JWTManager, create_access_token,unset_jwt_cookies, jwt_required, get_jwt_identity,decode_token
7
  from App.models import User,LostItem,FoundItem,Match
 
8
  from flask_cors import CORS
9
  from src.training import encode_img_and_text
10
  from qdrant_client import QdrantClient
 
19
  warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
20
 
21
 
 
22
 
23
  app = Flask(__name__)
24
  app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")
allfiles.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2fc8608ffc5292bfd2ea1bbf97b7e50c0ca797ac
2
+ 6cca4d3a19bbdbbd248b38fefcffb6e799115e4a
3
+ 8a58eaa680084c7828b7190d72a28a27657e40f5
4
+ a718cdd3df5c477a331c5658ee7243b473e28711
5
+ d24b09fbb595bc12d5edb593a773cfd84888c831
6
+ 26ecb4581bef71354473073c50f0114adee83454
7
+ 32212c33b8a43bacaafa60517b709500202868e0
8
+ 40dd086edb25acc272d44312232b432a65b3cf92
9
+ 4f7c84bb828dc03ace4e24a1d9b4a075b814f966
10
+ b542cfcc9811103dcb34600f304fef079586c745
11
+ 315969e673bb1c2280258455ee10fe654cef7e58 .env
12
+ a6344aac8c09253b3b630fb776ae94478aa0275b .gitattributes
13
+ 354e4823dd1ab296ac6f2a5109a80019bd3cffd6 .gitignore
14
+ de6c6c2d8cb15165b4e8f7ab22f58dfc40fd3f1f .gitignore
15
+ c4e737bc10b777de9ff0b3d6412c67f833e3e285 App
16
+ e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 App/__init__.py
17
+ 025c75b2d2b2df29802ed1f3b707e999fe15ac41 App/app.py
18
+ bcc7d77ffc45f7c1cce1ff2520e29b3f1196d2e3 App/models.py
19
+ 7108ef9623bce7e41f86ed1975f02e545bcf84ff App/scheduler.py
20
+ fecba207f6b5c465c85039cdc28640b82772b0d8 Dockerfile
21
+ f5824832d73309e87386f99ebca332120d0b6e4f model
22
+ b9aa3f250bd0cdb769d753e0847ac454b03b258b model/clip
23
+ 116a8bb85040164dcf8dfde96112b331857465c1 model/clip/best.pt
24
+ 04135bae2a3d5f95ac25a5f7117a822516b4a9ef model/clip/epoch5_val0.7777.pt
25
+ 7b064d8f8f7e759fc42f24447a2ba2e625f77a73 README.md
26
+ 2f808aa33de8a014b003f0c5430f217fa79ca876 requirements.txt
27
+ 5d2e66968630adbd91f3c44caeb34446b44f9360 requirements_scheduler.txt
28
+ 6ad0ff8380d2b4136fe12becdf1928fd301778b7 setup.py
29
+ edfc8a8148ea12a4982caf45fc2a3c3c7b84dfe5 src
30
+ 6bd3ca06c394c2095f295e6889d7f706005d40b5 src/evaluation.py
31
+ 8cd3fa1c487903893ff4c3f5c40495c0dc3f9c1d src/preprocessing.py
32
+ 25d7f73ad394a323774f913e50315fb5d3832a2b src/training.py
src/evaluation.py CHANGED
@@ -6,9 +6,19 @@ import open_clip
6
  from datasets import load_dataset
7
  from tqdm import tqdm
8
  import numpy as np
 
9
 
10
  import warnings
11
  warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
 
 
 
 
 
 
 
 
 
12
 
13
  def collate(batch):
14
  img,text=zip(*batch)
@@ -43,7 +53,7 @@ def main(path='./model/clip/best.pt',arch='ViT-B-32', pretrained='openai'):
43
  torch.cuda.empty_cache()
44
  model, _, preprocess =open_clip.create_model_and_transforms(arch,pretrained=pretrained,device=device,quick_gelu=True )
45
  tokenizer=open_clip.get_tokenizer(arch)
46
- state=torch.load(path,map_location='cuda')['model']
47
  model.load_state_dict(state, strict=False)
48
  model.eval()
49
  print('model loaded')
 
6
  from datasets import load_dataset
7
  from tqdm import tqdm
8
  import numpy as np
9
+ from huggingface_hub import hf_hub_download
10
 
11
  import warnings
12
  warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
13
+ HF_TOKEN=os.getenv("HF_TOKEN")
14
+
15
+ MODEL_ID = "PrashantGoyal/findr-clip-ft"
16
+
17
+ model_path = hf_hub_download(
18
+ repo_id=MODEL_ID,
19
+ filename="best.pt",
20
+ token=os.getenv("HF_TOKEN")
21
+ )
22
 
23
  def collate(batch):
24
  img,text=zip(*batch)
 
53
  torch.cuda.empty_cache()
54
  model, _, preprocess =open_clip.create_model_and_transforms(arch,pretrained=pretrained,device=device,quick_gelu=True )
55
  tokenizer=open_clip.get_tokenizer(arch)
56
+ state=torch.load(model_path,map_location='cuda')['model']
57
  model.load_state_dict(state, strict=False)
58
  model.eval()
59
  print('model loaded')
src/training.py CHANGED
@@ -10,13 +10,22 @@ from src.preprocessing import Preprocessing
10
  from torch.utils.data import DataLoader,Dataset
11
  import warnings
12
  import base64
 
13
  from io import BytesIO
14
  warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
15
 
16
  device='cuda' if torch.cuda.is_available() else 'cpu'
17
  torch.cuda.empty_cache()
18
  model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device )
19
- SAVE_DIR='model/clip/best.pt'
 
 
 
 
 
 
 
 
20
  tokenizer=open_clip.get_tokenizer('ViT-B-32')
21
 
22
  def seed_everything(seed=42):
@@ -150,7 +159,7 @@ def feedback(model,processor,device,data,epochs=5,batch_size=4,lr=1e-6):
150
  dataLoader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
151
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
152
  loss_fn = nn.CosineEmbeddingLoss()
153
- model.load_state_dict(torch.load(SAVE_DIR, map_location=device))
154
  model.train()
155
  for epoch in range(epochs):
156
  total_loss = 0
@@ -178,7 +187,7 @@ def feedback(model,processor,device,data,epochs=5,batch_size=4,lr=1e-6):
178
  def encode_img_and_text(imgs,text):
179
  image_feat=[]
180
  model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device,quick_gelu=True )
181
- checkpoint = torch.load(SAVE_DIR, map_location=device)
182
  model.to(device)
183
  for img in imgs:
184
  if hasattr(img, 'read'):
 
10
  from torch.utils.data import DataLoader,Dataset
11
  import warnings
12
  import base64
13
+ from huggingface_hub import hf_hub_download
14
  from io import BytesIO
15
  warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
16
 
17
  device='cuda' if torch.cuda.is_available() else 'cpu'
18
  torch.cuda.empty_cache()
19
  model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device )
20
+ HF_TOKEN=os.getenv("HF_TOKEN")
21
+
22
+ MODEL_ID = "PrashantGoyal/findr-clip-ft"
23
+
24
+ model_path = hf_hub_download(
25
+ repo_id=MODEL_ID,
26
+ filename="best.pt",
27
+ token=os.getenv("HF_TOKEN")
28
+ )
29
  tokenizer=open_clip.get_tokenizer('ViT-B-32')
30
 
31
  def seed_everything(seed=42):
 
159
  dataLoader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
160
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
161
  loss_fn = nn.CosineEmbeddingLoss()
162
+ model.load_state_dict(torch.load(model_path, map_location=device))
163
  model.train()
164
  for epoch in range(epochs):
165
  total_loss = 0
 
187
  def encode_img_and_text(imgs,text):
188
  image_feat=[]
189
  model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device,quick_gelu=True )
190
+ checkpoint = torch.load(model_path, map_location=device)
191
  model.to(device)
192
  for img in imgs:
193
  if hasattr(img, 'read'):