Spaces:
Running
Running
Commit
·
e608f06
1
Parent(s):
7357b32
model problem solved
Browse files- App/app.py +0 -2
- allfiles.txt +32 -0
- src/evaluation.py +11 -1
- src/training.py +12 -3
App/app.py
CHANGED
|
@@ -5,7 +5,6 @@ from functools import wraps
|
|
| 5 |
from supabase import create_client
|
| 6 |
from flask_jwt_extended import JWTManager, create_access_token,unset_jwt_cookies, jwt_required, get_jwt_identity,decode_token
|
| 7 |
from App.models import User,LostItem,FoundItem,Match
|
| 8 |
-
from dotenv import load_dotenv
|
| 9 |
from flask_cors import CORS
|
| 10 |
from src.training import encode_img_and_text
|
| 11 |
from qdrant_client import QdrantClient
|
|
@@ -20,7 +19,6 @@ from datetime import timedelta
|
|
| 20 |
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
|
| 21 |
|
| 22 |
|
| 23 |
-
load_dotenv()
|
| 24 |
|
| 25 |
app = Flask(__name__)
|
| 26 |
app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")
|
|
|
|
| 5 |
from supabase import create_client
|
| 6 |
from flask_jwt_extended import JWTManager, create_access_token,unset_jwt_cookies, jwt_required, get_jwt_identity,decode_token
|
| 7 |
from App.models import User,LostItem,FoundItem,Match
|
|
|
|
| 8 |
from flask_cors import CORS
|
| 9 |
from src.training import encode_img_and_text
|
| 10 |
from qdrant_client import QdrantClient
|
|
|
|
| 19 |
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
|
| 20 |
|
| 21 |
|
|
|
|
| 22 |
|
| 23 |
app = Flask(__name__)
|
| 24 |
app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv("DATABASE_URL")
|
allfiles.txt
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2fc8608ffc5292bfd2ea1bbf97b7e50c0ca797ac
|
| 2 |
+
6cca4d3a19bbdbbd248b38fefcffb6e799115e4a
|
| 3 |
+
8a58eaa680084c7828b7190d72a28a27657e40f5
|
| 4 |
+
a718cdd3df5c477a331c5658ee7243b473e28711
|
| 5 |
+
d24b09fbb595bc12d5edb593a773cfd84888c831
|
| 6 |
+
26ecb4581bef71354473073c50f0114adee83454
|
| 7 |
+
32212c33b8a43bacaafa60517b709500202868e0
|
| 8 |
+
40dd086edb25acc272d44312232b432a65b3cf92
|
| 9 |
+
4f7c84bb828dc03ace4e24a1d9b4a075b814f966
|
| 10 |
+
b542cfcc9811103dcb34600f304fef079586c745
|
| 11 |
+
315969e673bb1c2280258455ee10fe654cef7e58 .env
|
| 12 |
+
a6344aac8c09253b3b630fb776ae94478aa0275b .gitattributes
|
| 13 |
+
354e4823dd1ab296ac6f2a5109a80019bd3cffd6 .gitignore
|
| 14 |
+
de6c6c2d8cb15165b4e8f7ab22f58dfc40fd3f1f .gitignore
|
| 15 |
+
c4e737bc10b777de9ff0b3d6412c67f833e3e285 App
|
| 16 |
+
e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 App/__init__.py
|
| 17 |
+
025c75b2d2b2df29802ed1f3b707e999fe15ac41 App/app.py
|
| 18 |
+
bcc7d77ffc45f7c1cce1ff2520e29b3f1196d2e3 App/models.py
|
| 19 |
+
7108ef9623bce7e41f86ed1975f02e545bcf84ff App/scheduler.py
|
| 20 |
+
fecba207f6b5c465c85039cdc28640b82772b0d8 Dockerfile
|
| 21 |
+
f5824832d73309e87386f99ebca332120d0b6e4f model
|
| 22 |
+
b9aa3f250bd0cdb769d753e0847ac454b03b258b model/clip
|
| 23 |
+
116a8bb85040164dcf8dfde96112b331857465c1 model/clip/best.pt
|
| 24 |
+
04135bae2a3d5f95ac25a5f7117a822516b4a9ef model/clip/epoch5_val0.7777.pt
|
| 25 |
+
7b064d8f8f7e759fc42f24447a2ba2e625f77a73 README.md
|
| 26 |
+
2f808aa33de8a014b003f0c5430f217fa79ca876 requirements.txt
|
| 27 |
+
5d2e66968630adbd91f3c44caeb34446b44f9360 requirements_scheduler.txt
|
| 28 |
+
6ad0ff8380d2b4136fe12becdf1928fd301778b7 setup.py
|
| 29 |
+
edfc8a8148ea12a4982caf45fc2a3c3c7b84dfe5 src
|
| 30 |
+
6bd3ca06c394c2095f295e6889d7f706005d40b5 src/evaluation.py
|
| 31 |
+
8cd3fa1c487903893ff4c3f5c40495c0dc3f9c1d src/preprocessing.py
|
| 32 |
+
25d7f73ad394a323774f913e50315fb5d3832a2b src/training.py
|
src/evaluation.py
CHANGED
|
@@ -6,9 +6,19 @@ import open_clip
|
|
| 6 |
from datasets import load_dataset
|
| 7 |
from tqdm import tqdm
|
| 8 |
import numpy as np
|
|
|
|
| 9 |
|
| 10 |
import warnings
|
| 11 |
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def collate(batch):
|
| 14 |
img,text=zip(*batch)
|
|
@@ -43,7 +53,7 @@ def main(path='./model/clip/best.pt',arch='ViT-B-32', pretrained='openai'):
|
|
| 43 |
torch.cuda.empty_cache()
|
| 44 |
model, _, preprocess =open_clip.create_model_and_transforms(arch,pretrained=pretrained,device=device,quick_gelu=True )
|
| 45 |
tokenizer=open_clip.get_tokenizer(arch)
|
| 46 |
-
state=torch.load(
|
| 47 |
model.load_state_dict(state, strict=False)
|
| 48 |
model.eval()
|
| 49 |
print('model loaded')
|
|
|
|
| 6 |
from datasets import load_dataset
|
| 7 |
from tqdm import tqdm
|
| 8 |
import numpy as np
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
|
| 11 |
import warnings
|
| 12 |
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
|
| 13 |
+
HF_TOKEN=os.getenv("HF_TOKEN")
|
| 14 |
+
|
| 15 |
+
MODEL_ID = "PrashantGoyal/findr-clip-ft"
|
| 16 |
+
|
| 17 |
+
model_path = hf_hub_download(
|
| 18 |
+
repo_id=MODEL_ID,
|
| 19 |
+
filename="best.pt",
|
| 20 |
+
token=os.getenv("HF_TOKEN")
|
| 21 |
+
)
|
| 22 |
|
| 23 |
def collate(batch):
|
| 24 |
img,text=zip(*batch)
|
|
|
|
| 53 |
torch.cuda.empty_cache()
|
| 54 |
model, _, preprocess =open_clip.create_model_and_transforms(arch,pretrained=pretrained,device=device,quick_gelu=True )
|
| 55 |
tokenizer=open_clip.get_tokenizer(arch)
|
| 56 |
+
state=torch.load(model_path,map_location='cuda')['model']
|
| 57 |
model.load_state_dict(state, strict=False)
|
| 58 |
model.eval()
|
| 59 |
print('model loaded')
|
src/training.py
CHANGED
|
@@ -10,13 +10,22 @@ from src.preprocessing import Preprocessing
|
|
| 10 |
from torch.utils.data import DataLoader,Dataset
|
| 11 |
import warnings
|
| 12 |
import base64
|
|
|
|
| 13 |
from io import BytesIO
|
| 14 |
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
|
| 15 |
|
| 16 |
device='cuda' if torch.cuda.is_available() else 'cpu'
|
| 17 |
torch.cuda.empty_cache()
|
| 18 |
model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device )
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
tokenizer=open_clip.get_tokenizer('ViT-B-32')
|
| 21 |
|
| 22 |
def seed_everything(seed=42):
|
|
@@ -150,7 +159,7 @@ def feedback(model,processor,device,data,epochs=5,batch_size=4,lr=1e-6):
|
|
| 150 |
dataLoader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
|
| 151 |
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 152 |
loss_fn = nn.CosineEmbeddingLoss()
|
| 153 |
-
model.load_state_dict(torch.load(
|
| 154 |
model.train()
|
| 155 |
for epoch in range(epochs):
|
| 156 |
total_loss = 0
|
|
@@ -178,7 +187,7 @@ def feedback(model,processor,device,data,epochs=5,batch_size=4,lr=1e-6):
|
|
| 178 |
def encode_img_and_text(imgs,text):
|
| 179 |
image_feat=[]
|
| 180 |
model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device,quick_gelu=True )
|
| 181 |
-
checkpoint = torch.load(
|
| 182 |
model.to(device)
|
| 183 |
for img in imgs:
|
| 184 |
if hasattr(img, 'read'):
|
|
|
|
| 10 |
from torch.utils.data import DataLoader,Dataset
|
| 11 |
import warnings
|
| 12 |
import base64
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
from io import BytesIO
|
| 15 |
warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")
|
| 16 |
|
| 17 |
device='cuda' if torch.cuda.is_available() else 'cpu'
|
| 18 |
torch.cuda.empty_cache()
|
| 19 |
model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device )
|
| 20 |
+
HF_TOKEN=os.getenv("HF_TOKEN")
|
| 21 |
+
|
| 22 |
+
MODEL_ID = "PrashantGoyal/findr-clip-ft"
|
| 23 |
+
|
| 24 |
+
model_path = hf_hub_download(
|
| 25 |
+
repo_id=MODEL_ID,
|
| 26 |
+
filename="best.pt",
|
| 27 |
+
token=os.getenv("HF_TOKEN")
|
| 28 |
+
)
|
| 29 |
tokenizer=open_clip.get_tokenizer('ViT-B-32')
|
| 30 |
|
| 31 |
def seed_everything(seed=42):
|
|
|
|
| 159 |
dataLoader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
|
| 160 |
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 161 |
loss_fn = nn.CosineEmbeddingLoss()
|
| 162 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 163 |
model.train()
|
| 164 |
for epoch in range(epochs):
|
| 165 |
total_loss = 0
|
|
|
|
| 187 |
def encode_img_and_text(imgs,text):
|
| 188 |
image_feat=[]
|
| 189 |
model, _, preprocess =open_clip.create_model_and_transforms('ViT-B-32',pretrained='openai',device=device,quick_gelu=True )
|
| 190 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 191 |
model.to(device)
|
| 192 |
for img in imgs:
|
| 193 |
if hasattr(img, 'read'):
|