import os, random, math
import base64
import warnings
from io import BytesIO

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
import open_clip
from datasets import load_dataset
from huggingface_hub import hf_hub_download

from src.preprocessing import Preprocessing

warnings.filterwarnings("ignore", message=".*QuickGELU mismatch.*")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.empty_cache()
# Base ViT-B-32 CLIP with OpenAI weights; the fine-tuned checkpoint below is
# loaded on top of this architecture where needed.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', device=device)
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "PrashantGoyal/findr-clip-ft"
# Fetch the fine-tuned checkpoint (saved by train() as {"model": state_dict}).
model_path = hf_hub_download(
    repo_id=MODEL_ID,
    force_download=True,
    filename="clip/best.pt",
    token=HF_TOKEN,
)
# model_path = "model/clip/best.pt"  # local fallback
tokenizer = open_clip.get_tokenizer('ViT-B-32')
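
# Reproducibility helper: pins Python's and PyTorch's RNGs (CPU and all GPUs)
# so runs with the same seed give the same shuffling and initialization.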
def seed_everything(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
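
# Wraps the project's Preprocessing loader as a torch Dataset that yields
# (preprocessed image tensor, tokenized caption) pairs for contrastive training.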
class clip_dataset(torch.utils.data.Dataset):
    def __init__(self, split='val', processor=None, tokenizer=None):
        preprocessor = Preprocessing()
        self.ds = preprocessor.load_dataset(split=split)
        self.tokenizer = tokenizer
        self.processor = processor

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        data = self.ds[index]
        img: Image.Image = data['image'].convert('RGB')
        # Each sample may carry several reference captions; pick one at random.
        text = random.choice(data['answer']).strip()
        image = self.processor(img) if self.processor else img
        token_text = self.tokenizer([text])[0]
        return image, token_text
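
# Contrastive objective (CLIP): for a batch of N pairs, build the N x N
# cosine-similarity matrix, scale it by exp(logit_scale), and apply
# cross-entropy in both directions so each image must pick out its own
# caption and each caption its own image:
#     L = (CE(logits, diag) + CE(logits.T, diag)) / 2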
def clip_loss(image_features, text_features, temperature):
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # `temperature` is the model's logit_scale, a learnable log-temperature,
    # so it is exponentiated before scaling the similarity logits.
    logits_per_image = (image_features @ text_features.t()) * torch.exp(temperature)
    logits_per_text = logits_per_image.t()
    targets = torch.arange(image_features.size(0), device=image_features.device)
    loss_i = nn.CrossEntropyLoss()(logits_per_image, targets)
    loss_t = nn.CrossEntropyLoss()(logits_per_text, targets)
    return (loss_i + loss_t) / 2

def collate(batch):
    # Stack preprocessed images and token tensors into batch tensors.
    imgs, toks = zip(*batch)
    imgs = torch.stack(imgs, 0)
    toks = torch.stack(toks, 0)
    return imgs, toks
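
# Fine-tuning loop: AdamW with linear warmup + cosine decay, mixed precision
# on CUDA, optional gradient accumulation, and per-epoch checkpoints with the
# best model (lowest validation loss) kept as best.pt.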
def train(arch='ViT-B-32', pretrained='openai', batchSize=2, epochs=5, lr=5e-5, warmup_steps=200, grad_accum=1, output_dir='model/clip'):
    # `pretrained` is unused here; the module-level model is fine-tuned in place.
    seed_everything(42)
    torch.cuda.empty_cache()
    os.makedirs(output_dir, exist_ok=True)
    tokenizer = open_clip.get_tokenizer(arch)
    # NOTE: as written, this fine-tunes on the 'val' split and validates on 'test'.
    train_ds = clip_dataset(split='val', processor=preprocess, tokenizer=tokenizer)
    val_ds = clip_dataset(split='test', processor=preprocess, tokenizer=tokenizer)
    train_dl = DataLoader(train_ds, batch_size=batchSize, shuffle=True, num_workers=4, collate_fn=collate, pin_memory=True)
    val_dl = DataLoader(val_ds, batch_size=batchSize, shuffle=False, num_workers=4, collate_fn=collate, pin_memory=True)
    total_steps = epochs * math.ceil(len(train_dl) / grad_accum)

    def lr_lambda(step):
        # Linear warmup followed by cosine decay to zero.
        if step < warmup_steps:
            return (step + 1) / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = torch.cuda.amp.GradScaler(enabled=device.startswith("cuda"))
    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0
        step = 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{epochs}")
        optimizer.zero_grad(set_to_none=True)
        for images, tokens in pbar:
            images = images.to(device, non_blocking=True)
            tokens = tokens.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=device.startswith("cuda")):
                image_features = model.encode_image(images)
                text_features = model.encode_text(tokens)
                loss = clip_loss(image_features, text_features, model.logit_scale)
            scaler.scale(loss / grad_accum).backward()
            step += 1
            running += loss.item()
            if step % grad_accum == 0:
                scaler.step(optimizer); scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
            pbar.set_postfix(loss=running / step, lr=optimizer.param_groups[0]["lr"])
        # Flush gradients left over from an incomplete accumulation window.
        if step % grad_accum != 0:
            scaler.step(optimizer); scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
        model.eval()
        with torch.no_grad():
            val_losses = []
            for images, tokens in tqdm(val_dl, leave=False, desc="Val"):
                images = images.to(device); tokens = tokens.to(device)
                with torch.cuda.amp.autocast(enabled=device.startswith("cuda")):
                    image_features = model.encode_image(images)
                    text_features = model.encode_text(tokens)
                    val_loss = clip_loss(image_features, text_features, model.logit_scale)
                val_losses.append(val_loss.item())
        val_mean = sum(val_losses) / len(val_losses)
        # Save a per-epoch checkpoint and keep the best one by validation loss.
        ckpt_path = os.path.join(output_dir, f"epoch{epoch}_val{val_mean:.4f}.pt")
        torch.save({"model": model.state_dict()}, ckpt_path)
        if val_mean < best_val:
            best_val = val_mean
            torch.save({"model": model.state_dict()}, os.path.join(output_dir, "best.pt"))
        print(f"Epoch {epoch} done. TrainLoss ~{running/step:.4f} ValLoss {val_mean:.4f}")
class FeedbackDataset(Dataset):
    def __init__(self, examples, processor=None):
        self.examples = examples
        self.processor = processor

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        image = ex["image"]
        if not isinstance(image, Image.Image):
            image = Image.open(image).convert("RGB")
        # Apply the CLIP preprocessing transform here so the default
        # DataLoader collate can stack the resulting tensors.
        if self.processor is not None:
            image = self.processor(image)
        return image, ex["text"], ex["label"]
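
# Lightly fine-tunes the saved checkpoint on feedback pairs with
# CosineEmbeddingLoss, which pulls matching image/text embeddings together
# (label +1) and pushes mismatches apart (label -1). The very small learning
# rate limits drift from the contrastively trained weights.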
def feedback(model, processor, device, data, epochs=5, batch_size=4, lr=1e-6):
    dataset = FeedbackDataset(data, processor=processor)
    dataLoader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CosineEmbeddingLoss()
    # Checkpoints are saved as {"model": state_dict} in train().
    model.load_state_dict(torch.load(model_path, map_location=device)["model"])
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, texts, labels in dataLoader:
            # Embeddings use the open_clip API (encode_image/encode_text) to
            # match the ViT-B-32 model and checkpoint used elsewhere in this file.
            images = images.to(device)
            tokens = tokenizer(list(texts)).to(device)
            image_embeds = model.encode_image(images)
            text_embeds = model.encode_text(tokens)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            labels = labels.to(device=device, dtype=torch.float)
            loss = loss_fn(image_embeds, text_embeds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{epoch+1}/{epochs} , Loss: {total_loss/len(dataLoader):.4f}")
def encode_img_and_text(imgs, text):
    image_feat = []
    # quick_gelu=True matches the OpenAI weights' activation (the mismatch
    # warning filtered at the top of this file).
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai', device=device, quick_gelu=True)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model"])  # apply the fine-tuned weights
    model.to(device)
    model.eval()
    for img in imgs:
        if hasattr(img, 'read'):
            # File upload (e.g. a Werkzeug FileStorage or a plain file object).
            image = Image.open(getattr(img, 'stream', img)).convert("RGB")
        elif isinstance(img, dict) and 'preview' in img:
            # Base64 data URL of the form "data:image/...;base64,<payload>".
            img_data = img['preview'].split(",")[1]
            image = Image.open(BytesIO(base64.b64decode(img_data))).convert("RGB")
        else:
            raise ValueError("Unsupported image input")
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_feat.append(image_features)
    # Average the per-image embeddings into a single query vector.
    image_embedding = torch.stack(image_feat).mean(dim=0)
    text_tokens = tokenizer([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Blend image and text embeddings, weighting images more heavily.
    alpha = 0.7
    combined = alpha * image_embedding + (1 - alpha) * text_features
    combined = combined / combined.norm(dim=-1, keepdim=True)
    return combined.squeeze(0).cpu().tolist()
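
# Minimal usage sketch (hypothetical inputs; "sample.jpg" is a placeholder):
#     train(batchSize=2, epochs=5)          # fine-tune; saves model/clip/best.pt
#     with open("sample.jpg", "rb") as f:
#         vec = encode_img_and_text([f], "red wooden chair")
#     # `vec` is a list[float] unit vector suitable for a vector-DB query.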