Sanj12 committed
Commit 5e90518 · verified · Parent(s): c80c289

Upload 14 files
src/.env ADDED
@@ -0,0 +1,3 @@
+ CLOUD_NAME=dmiu2ccxh
+ API_KEY=681379735874788
+ API_SECRET=yyEZbsOBe8j9XsBWoYsA2qpHu_I
src/README.md ADDED
@@ -0,0 +1,13 @@
+ To see the tables, either use the sqlite3 CLI (example below) or run the helper script:
+
+ - pip install tabulate
+ - python db_test.py
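+
+ Example sqlite3 session (a quick sketch; assumes you run it from the folder containing curato.db):
+
+ ```
+ sqlite3 curato.db
+ .tables
+ .schema artworks
+ SELECT * FROM artworks LIMIT 5;
+ ```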
src/build_gallery_embeddings.py ADDED
@@ -0,0 +1,35 @@
+ # build_gallery_embeddings.py
+ from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image
+ import torch
+ import os
+ import pickle
+
+ gallery_dir = "gallery"
+ embedding_file = "gallery_embeddings.pkl"
+
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ embeddings = []
+
+ for fname in os.listdir(gallery_dir):
+     if fname.lower().endswith(('.jpg', '.jpeg', '.png')):  # lower() so .JPG etc. aren't skipped
+         img_path = os.path.join(gallery_dir, fname)
+         image = Image.open(img_path).convert("RGB")
+
+         inputs = processor(images=image, return_tensors="pt")
+         with torch.no_grad():
+             image_emb = model.get_image_features(**inputs)
+             image_emb = image_emb / image_emb.norm(p=2, dim=-1)  # L2-normalize for cosine similarity
+
+         embeddings.append({
+             "filename": fname,
+             "embedding": image_emb.squeeze().cpu()
+         })
+
+ # Save embeddings
+ with open(embedding_file, "wb") as f:
+     pickle.dump(embeddings, f)
+
+ print(f"Saved {len(embeddings)} image embeddings.")
src/cloudinary_utils.py ADDED
@@ -0,0 +1,16 @@
+ import os
+ from dotenv import load_dotenv
+ import cloudinary
+ import cloudinary.uploader
+
+ load_dotenv()
+
+ cloudinary.config(
+     cloud_name=os.getenv("CLOUD_NAME"),
+     api_key=os.getenv("API_KEY"),
+     api_secret=os.getenv("API_SECRET"),
+ )
+
+ def upload_to_cloudinary(filepath):
+     response = cloudinary.uploader.upload(filepath)
+     return response.get("secure_url")
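A quick usage sketch for the helper above (assumes the credentials in src/.env are loadable and that "art.jpg" exists locally; the filename is illustrative):

    from cloudinary_utils import upload_to_cloudinary

    url = upload_to_cloudinary("art.jpg")
    print(url)  # secure_url returned by Cloudinary, e.g. https://res.cloudinary.com/...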
src/curato.db ADDED
Binary file (12.3 kB)
 
src/curato_api.py ADDED
@@ -0,0 +1,24 @@
+ from flask import Flask, jsonify
+ import sqlite3
+
+ app = Flask(__name__)
+
+ DATABASE = r'C:\Users\sanjana\Desktop\curato\curato.db'
+
+ def get_db_connection():
+     conn = sqlite3.connect(DATABASE)
+     conn.row_factory = sqlite3.Row  # To access columns by name
+     return conn
+
+ @app.route('/artworks', methods=['GET'])
+ def get_artworks():
+     conn = get_db_connection()
+     cursor = conn.execute('SELECT * FROM artworks')  # replace 'artworks' with your table name
+     rows = cursor.fetchall()
+     conn.close()
+
+     artworks = [dict(row) for row in rows]
+     return jsonify(artworks)
+
+ if __name__ == '__main__':
+     app.run(debug=True)
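With the server running (python curato_api.py), the endpoint returns every row as JSON. A minimal client sketch (assumes the requests package and Flask's default port 5000):

    import requests

    resp = requests.get("http://127.0.0.1:5000/artworks")
    for artwork in resp.json():
        print(artwork["filename"], artwork.get("cloud_url"))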
src/db_sqlite.py ADDED
@@ -0,0 +1,34 @@
+ import sqlite3
+
+ def init_db():
+     conn = sqlite3.connect("curato.db")
+     cursor = conn.cursor()
+
+     # Create table if it doesn't exist
+     cursor.execute("""
+         CREATE TABLE IF NOT EXISTS artworks (
+             filename TEXT PRIMARY KEY,
+             style TEXT,
+             tags TEXT,
+             caption TEXT
+         )
+     """)
+
+     # Add cloud_url column if not exists
+     cursor.execute("PRAGMA table_info(artworks)")
+     columns = [col[1] for col in cursor.fetchall()]
+     if 'cloud_url' not in columns:
+         cursor.execute("ALTER TABLE artworks ADD COLUMN cloud_url TEXT")
+
+     conn.commit()
+     conn.close()
+
+ def save_metadata(filename, style, tags, caption, cloud_url):
+     conn = sqlite3.connect("curato.db")
+     cursor = conn.cursor()
+     cursor.execute("""
+         INSERT OR REPLACE INTO artworks (filename, style, tags, caption, cloud_url)
+         VALUES (?, ?, ?, ?, ?)
+     """, (filename, style, ",".join(tags), caption, cloud_url))
+     conn.commit()
+     conn.close()
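How these two helpers are meant to be called (a sketch; the metadata values are illustrative):

    from db_sqlite import init_db, save_metadata

    init_db()  # creates curato.db and the artworks table if needed
    save_metadata(
        filename="art.jpg",
        style="Impressionism",
        tags=["portrait", "warm colors"],  # stored comma-joined in a TEXT column
        caption="This is a placeholder caption.",
        cloud_url="https://res.cloudinary.com/...",
    )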
src/db_test.py ADDED
@@ -0,0 +1,42 @@
+ import sqlite3
+
+ # tabulate is optional (pretty table output); fall back to plain printing without it
+ try:
+     from tabulate import tabulate
+ except ImportError:
+     tabulate = None
+
+ # Connect to your SQLite DB (change the path accordingly)
+ conn = sqlite3.connect(r'C:\Users\sanjana\Desktop\curato\curato.db')
+ cursor = conn.cursor()
+
+ # 1. List all tables
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+ tables = cursor.fetchall()
+ print("Tables in the database:")
+ for table_name in tables:
+     print("-", table_name[0])
+
+ # Just pick the first table for demo; replace with your actual table name
+ table_to_show = tables[0][0]
+
+ print(f"\nShowing schema for table '{table_to_show}':")
+ cursor.execute(f"PRAGMA table_info({table_to_show})")
+ columns = cursor.fetchall()
+ for col in columns:
+     print(f"Column: {col[1]}, Type: {col[2]}")
+
+ print(f"\nAll data from table '{table_to_show}':")
+ cursor.execute(f"SELECT * FROM {table_to_show}")
+ rows = cursor.fetchall()
+
+ # Print rows in a nice table format if tabulate is available
+ if tabulate is not None:
+     print(tabulate(rows, headers=[col[1] for col in columns], tablefmt="grid"))
+ else:
+     # Otherwise print raw headers and rows
+     print([col[1] for col in columns])
+     for row in rows:
+         print(row)
+
+ conn.close()
src/gallery_embeddings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:193936530b98c1d7cd49ee06cea454b7ec472b360da56e165bd388f0b0e3f3f6
+ size 40099
src/requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ torchvision
+ transformers
+ Pillow
+ numpy
+ scikit-learn
+ streamlit
+ cloudinary
+ flask
+ python-dotenv
+ datasets
+ tabulate
+ tqdm
+
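To install the dependencies in one go:

    pip install -r src/requirements.txt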
src/search.py ADDED
@@ -0,0 +1,29 @@
+ # search.py
+ import torch
+ import pickle
+ from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image
+
+ # Load model once
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Load saved gallery
+ with open("gallery_embeddings.pkl", "rb") as f:
+     GALLERY = pickle.load(f)
+
+ def find_similar_images(query_image_path, top_k=5):
+     image = Image.open(query_image_path).convert("RGB")
+     inputs = processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         query_emb = model.get_image_features(**inputs)
+         query_emb = query_emb / query_emb.norm(p=2, dim=-1)
+
+     similarities = []
+     for item in GALLERY:
+         gallery_emb = item["embedding"]
+         score = torch.nn.functional.cosine_similarity(query_emb, gallery_emb.unsqueeze(0)).item()
+         similarities.append((item["filename"], score))
+
+     top_matches = sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
+     return top_matches
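Usage sketch (assumes build_gallery_embeddings.py has already been run so gallery_embeddings.pkl exists; "query.jpg" is illustrative):

    from search import find_similar_images

    for filename, score in find_similar_images("query.jpg", top_k=3):
        print(f"{filename}: cosine similarity {score:.3f}")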
src/style_classifier.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models, transforms
+ from PIL import Image
+
+
+ def load_model(model_path="models/style_model.pth", class_names=[]):
+     # ResNet-18 with the final layer resized to the number of style classes
+     model = models.resnet18(pretrained=False)
+     model.fc = nn.Linear(model.fc.in_features, len(class_names))
+     model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+     model.eval()
+     return model
+
+ def predict_style(image_path, model, class_names):
+     # NOTE: training (see train_style_classifier_hf.py) normalizes with ImageNet
+     # mean/std; applying the same Normalize here would keep inference consistent.
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor()
+     ])
+     image = Image.open(image_path).convert("RGB")
+     image = transform(image).unsqueeze(0)
+
+     with torch.no_grad():
+         output = model(image)
+         _, predicted = torch.max(output, 1)
+
+     return class_names[predicted.item()]
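A usage sketch tying load_model to the class names the training script saves (assumes models/style_classes.json exists; note the training script saves models/style_model_hf.pth, so pass that path if that is the checkpoint you trained):

    import json
    from style_classifier import load_model, predict_style

    with open("models/style_classes.json") as f:
        class_names = json.load(f)

    model = load_model("models/style_model_hf.pth", class_names=class_names)
    print(predict_style("art.jpg", model, class_names))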
src/tagger.py ADDED
@@ -0,0 +1,31 @@
+ from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image
+ import torch
+
+ # Load model + processor
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Candidate tags
+ CANDIDATE_TAGS = [
+     "portrait", "landscape", "abstract", "surreal", "dark", "bright",
+     "melancholy", "joyful", "blue tones", "warm colors", "minimalist", "detailed"
+ ]
+
+ def generate_tags(image_path):
+     image = Image.open(image_path).convert("RGB")
+     inputs = processor(text=CANDIDATE_TAGS, images=image, return_tensors="pt", padding=True)
+     with torch.no_grad():  # inference only; skip gradient tracking
+         outputs = model(**inputs)
+
+     logits_per_image = outputs.logits_per_image
+     probs = logits_per_image.softmax(dim=1)
+
+     top_probs, indices = probs.topk(5)
+     tags = [CANDIDATE_TAGS[i] for i in indices[0]]
+
+     return tags
+
+ def generate_caption(image_path):
+     # Placeholder caption - replace this with real captioning logic
+     return "This is a placeholder caption."
src/train_style_classifier_hf.py ADDED
@@ -0,0 +1,104 @@
+ from datasets import load_dataset
+ from torchvision import transforms
+ from torch.utils.data import Dataset
+ from PIL import Image
+
+ # Define transform
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225])
+ ])
+
+ # Load HF dataset (use 'test' split because 'train' doesn't exist)
+ hf_dataset = load_dataset("asahi417/wikiart-all", split="test")
+
+ # Your custom Dataset class
+ class WikiArtDataset(Dataset):
+     def __init__(self, hf_dataset, transform=None):
+         self.dataset = hf_dataset
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         image = self.dataset[idx]["image"]
+         label = self.dataset[idx]["style"]
+
+         if self.transform:
+             image = self.transform(image)
+
+         return image, label
+
+ # Create PyTorch dataset
+ dataset = WikiArtDataset(hf_dataset, transform=transform)
+
+ '''from torchvision import datasets, transforms, models
+ from torch.utils.data import DataLoader
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from tqdm import tqdm
+ import os
+ import json
+
+ # Styles (should match folder names in data/wikiart_hf_small)
+ STYLES = ["Realism", "Cubism", "Impressionism", "Abstract Art"]
+ DATA_DIR = "data/wikiart"
+
+ # Define transforms
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225])
+ ])
+
+ # Load dataset from folders
+ dataset = datasets.ImageFolder(root="data/wikiart", transform=transform)
+
+ dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
+
+ # Device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load pretrained model
+ model = models.resnet18(pretrained=True)
+ model.fc = nn.Linear(model.fc.in_features, len(STYLES))
+ model.to(device)
+
+ # Loss & Optimizer
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(model.parameters(), lr=1e-4)
+
+ # Train
+ model.train()
+ for epoch in range(20):
+     running_loss = 0.0
+     for images, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
+         images = images.to(device)
+         labels = labels.to(device)
+
+         outputs = model(images)
+         loss = criterion(outputs, labels)
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         running_loss += loss.item()
+     print(f"Epoch {epoch+1} Loss: {running_loss:.4f}")
+
+ # Save model and label map
+ os.makedirs("models", exist_ok=True)
+ torch.save(model.state_dict(), "models/style_model_hf.pth")
+
+ # Save class names
+ with open("models/style_classes.json", "w") as f:
+     json.dump(dataset.classes, f)
+
+ print("✅ Model saved to models/style_model_hf.pth")
+ print("✅ Style classes saved to models/style_classes.json")
+ '''