samsonleegh commited on
Commit
cb76974
·
verified ·
1 Parent(s): 87cc493

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -0
app.py CHANGED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import hdbscan
4
+ import numpy as np
5
+ import requests
6
+ import os
7
+ import uuid
8
+ import ollama
9
+
10
+ from sklearn.cluster import KMeans
11
+ from sentence_transformers import SentenceTransformer, util
12
+ from huggingface_hub import login
13
+ from torch.quantization import quantize_dynamic
14
+ from umap import UMAP
15
+ from sklearn.metrics import silhouette_score
16
+
17
# Authenticate with the Hugging Face Hub. The token must come from the
# environment: the previous code passed the literal string "HF_API_KEY"
# as the token, which can never authenticate.
login(os.environ["HF_API_KEY"])

# Multilingual sentence-embedding model used to embed the image tags.
model_st = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

# Scratch directory for the per-cluster sample images shown in the gallery.
TMP_DIR = "./tmp_images"
os.makedirs(TMP_DIR, exist_ok=True)
21
+
22
def parse_with_ollama(text, llm_selector):
    """Summarise a batch of image-caption tags with a local ollama model.

    Args:
        text: concatenated tags/captions for one cluster.
        llm_selector: ollama model name, e.g. 'qwen2.5:3b' or 'llama3.2:latest'.

    Returns:
        The model's reply text (prompted to be under 10 words).
    """
    system_msg = {
        "role": "system",
        "content": "You are an image caption analyser for the trust and safety department. Based on the following image captions, provide an overall summary of these captions in less than 10 words.",
    }
    user_msg = {"role": "user", "content": text}
    reply = ollama.chat(model=llm_selector, messages=[system_msg, user_msg])
    return reply['message']['content']
31
+
32
def download_image(url, cluster_id, idx):
    """Download one image into TMP_DIR and return its local path.

    Args:
        url: image URL to fetch.
        cluster_id: cluster the image belongs to (embedded in the filename).
        idx: index of the image within its cluster (embedded in the filename).

    Returns:
        Path to the saved file, or None when the request fails, the response
        is not an image, or any exception occurs (best-effort by design).
    """
    try:
        response = requests.get(url, timeout=5)
        # Use .get(): a missing Content-Type header previously raised a
        # KeyError that was silently absorbed by the broad except below.
        content_type = response.headers.get('Content-Type', '')
        if response.status_code == 200 and content_type.startswith('image'):
            ext = content_type.split('/')[-1]
            # uuid suffix avoids filename collisions across runs.
            filename = f"cluster_{cluster_id}_{idx}_{uuid.uuid4().hex[:8]}.{ext}"
            filepath = os.path.join(TMP_DIR, filename)
            with open(filepath, 'wb') as f:
                f.write(response.content)
            return filepath
    except Exception as e:
        # Best-effort: log and fall through so one bad URL does not abort
        # the whole clustering run.
        print(f"Failed to fetch image from {url}: {e}")
    return None  # explicit: non-200 / non-image responses also land here
45
+
46
def cluster_data(file, algorithm, umap_dims, llm_selector):
    """Cluster image tags from an uploaded CSV and stream progress to the UI.

    Generator used as a Gradio .click() handler: every yield is a 5-tuple of
    (summary dataframe, gallery items, silhouette text, csv path, log text)
    matching the handler's five outputs.

    Args:
        file: uploaded file object whose .name is the path to a CSV that
            must contain 'top_tags' and 'img_url' columns.
        algorithm: "KMeans" or "HDBSCAN".
        umap_dims: number of UMAP output dimensions.
        llm_selector: ollama model name used to summarise each cluster.
    """
    logs = []  # accumulated progress messages rendered in the log textbox

    def log(msg):
        # Append one message and return the whole log as a single string.
        logs.append(msg)
        return "\n".join(logs)

    try:
        df = pd.read_csv(file.name)

        if 'top_tags' not in df.columns or 'img_url' not in df.columns:
            # NOTE: inside a generator a bare `return value` is swallowed by
            # StopIteration and never reaches Gradio, and the old 2-tuple did
            # not match the handler's 5 outputs — errors must be *yielded*.
            yield "Required columns ('top_tags', 'img_url') not found.", None, None, None, log("❌ Required columns ('top_tags', 'img_url') not found.")
            return

        # Strip the list-literal punctuation ( [ ] ' ) left over from a
        # stringified Python list of tags.
        text_ls = df['top_tags'].str.replace(r"[\[\]']", '', regex=True).to_list()

        # Encode + UMAP
        yield None, None, None, None, log("✅ Converting top_tags to embeddings...")
        embeddings = model_st.encode(text_ls, batch_size=64, show_progress_bar=True)
        yield None, None, None, None, log("✅ Reducing dimensions with UMAP " + str(umap_dims) + " dimensions...")
        umap_model = UMAP(n_components=int(umap_dims), metric='cosine', random_state=42)
        umap_embeddings = umap_model.fit_transform(embeddings)

        # Cluster
        yield None, None, None, None, log(f"✅ Clustering with {algorithm}...")
        if algorithm == "KMeans":
            # sqrt(n) heuristic for the number of clusters, floored at 2.
            n_clusters = max(2, round(np.sqrt(len(df))))
            model = KMeans(n_clusters=n_clusters, random_state=0)
            labels = model.fit_predict(umap_embeddings)
        elif algorithm == "HDBSCAN":
            # First pass: HDBSCAN on the UMAP-reduced space (defaults).
            hdb = hdbscan.HDBSCAN()
            labels = hdb.fit_predict(umap_embeddings)

            found = len(set(labels)) - (1 if -1 in labels else 0)
            n_noise = int((labels == -1).sum())
            print(f"Clusters found: {found}")
            print(f"Noise samples: {n_noise} / {len(labels)} ({n_noise/len(labels)*100:.2f}%)")

            # Second pass: recluster the noise points so fewer samples are
            # discarded. Skipped entirely when there is no noise (the old
            # code fit HDBSCAN on an empty array and crashed).
            noise_mask = labels == -1
            if noise_mask.any():
                noise_embeddings = umap_embeddings[noise_mask]
                noise_labels = hdbscan.HDBSCAN().fit_predict(noise_embeddings)

                labels = labels.copy()
                # Offset second-pass ids past the first-pass maximum so the
                # two label spaces never collide.
                new_cluster_start = labels.max() + 1
                relabelled_noise = np.where(noise_labels != -1, noise_labels + new_cluster_start, -1)
                labels[noise_mask] = relabelled_noise
        else:
            # Yield (not return) so the message actually reaches the UI.
            yield "Unknown algorithm", None, None, None, log("❌ Unknown algorithm")
            return

        # silhouette_score raises unless at least 2 clusters are present
        # (euclidean is appropriate after UMAP reduction).
        if len(set(labels) - {-1}) >= 2:
            cluster_silhouette_score = silhouette_score(umap_embeddings, labels, metric='euclidean')
            silhouette_text = f"Silhouette Score: {cluster_silhouette_score:.3f}"
        else:
            silhouette_text = "Silhouette Score: n/a (fewer than 2 clusters found)"

        # Label the df and drop residual noise points.
        df["cluster"] = labels
        df = df[df["cluster"] != -1]

        # Sample up to 5 images per cluster for the gallery.
        img_clusters = []
        yield None, None, None, None, log("✅ Downloading images...")
        prev_img_path = None  # last successful download, used as a fallback
        prev_cluster_id = None
        for cluster_id in sorted(df['cluster'].unique()):
            urls = df[df['cluster'] == cluster_id]['img_url'].dropna().unique()[:5]
            for idx, url in enumerate(urls):
                img_path = download_image(url, cluster_id, idx)
                if img_path:
                    img_clusters.append((os.path.abspath(img_path), f"Cluster {cluster_id}"))
                    prev_img_path = img_path
                    prev_cluster_id = cluster_id
                elif prev_img_path is not None:
                    # Reuse the last good image so the gallery keeps its
                    # shape (old code raised NameError when the very first
                    # download failed).
                    img_clusters.append((os.path.abspath(prev_img_path), f"Cluster {prev_cluster_id}"))

        # Persist the clustered rows for download.
        file_path = "cluster_output.csv"
        df[['img_url', 'top_tags', 'cluster']].to_csv(file_path, index=False)

        # One row per cluster: joined tags + sample count, then an LLM summary.
        agg_df = df.groupby('cluster').agg(
            top_tags_joined=('top_tags', lambda x: ', '.join(x)),
            num_samples=('top_tags', 'count')
        ).reset_index()
        yield None, None, None, None, log("✅ Summarising cluster image tags with LLM...")
        agg_df['tag_summary'] = agg_df['top_tags_joined'].apply(lambda x: parse_with_ollama(x, llm_selector))
        agg_df = agg_df[['cluster', 'num_samples', 'tag_summary', 'top_tags_joined']]
        yield agg_df, img_clusters, silhouette_text, file_path, log("✅ All done!")

    except Exception as e:
        # Yielded (not returned) so the error is displayed; arity matches
        # the five outputs.
        yield f"Error: {str(e)}", None, None, None, log(f"❌ Error: {str(e)}")
162
+
163
# UI layout: inputs on top (file + algorithm/UMAP/LLM settings), then the
# download link, silhouette score, cluster table, image gallery, and logs.
with gr.Blocks() as demo:
    with gr.Row():

        with gr.Column():
            start_button = gr.Button("Start Clustering")
            file_input = gr.File(file_types=[".csv"], label="Upload CSV")

        with gr.Column():
            # Choices must match the branches handled in cluster_data.
            algo_selector = gr.Dropdown(choices=["KMeans", "HDBSCAN"], label="Clustering Algorithm")
            umap_dims = gr.Slider(minimum=2, maximum=100, value=20, step=1, label="UMAP Dimensions")
            # Ollama model names passed through to parse_with_ollama.
            llm_selector = gr.Dropdown(choices=["qwen2.5:3b", "llama3.2:latest"], value="qwen2.5:3b", label="LLM Model")

    # Clustered CSV produced by cluster_data ("cluster_output.csv").
    download_filepath = gr.File(label="Download Clustered Output", type="filepath")

    with gr.Row():
        silhouette_text = gr.Textbox(label="Silhouette Score compares the average distance to points in the same cluster vs. points in the nearest other cluster. +1 indicate well-separated, compact clusters; 0 indicate overlapping clusters.", lines=1, interactive=False)

    with gr.Row():
        output_df = gr.Dataframe(label="Clustered Output", interactive=False)

    with gr.Row():
        gallery = gr.Gallery(label="Clustered Images (5 per cluster)", columns=5, height="auto")

    with gr.Row():
        # Receives the incrementally built log string from cluster_data.
        log_box = gr.Textbox(label="Processing Logs", lines=10, interactive=False)

    # Button triggers clustering; cluster_data is a generator, so outputs
    # stream as it yields.
    start_button.click(fn=cluster_data, inputs=[file_input, algo_selector, umap_dims, llm_selector], outputs=[output_df, gallery, silhouette_text, download_filepath, log_box])

demo.launch()