Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import json | |
| import shutil | |
| from pathlib import Path | |
| from PIL import Image | |
| import torch | |
| import clip | |
| import numpy as np | |
| import requests | |
| from io import BytesIO | |
| import tempfile | |
# Process-wide cache for the CLIP model: filled on the first call to
# get_clip_model() and reused by every classifier instance afterwards.
_clip_model = _clip_preprocess = _device = None


def get_clip_model():
    """Return the cached ``(model, preprocess, device)`` triple.

    On the first call CLIP ``ViT-B/32`` is loaded onto ``"cuda"`` when
    ``torch.cuda.is_available()`` is true, otherwise onto ``"cpu"``; later
    calls return the same objects without reloading.
    """
    global _clip_model, _clip_preprocess, _device
    if _clip_model is not None:
        # Already loaded: hand back the cached objects.
        return _clip_model, _clip_preprocess, _device
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _clip_model, _clip_preprocess = clip.load("ViT-B/32", device=_device)
    return _clip_model, _clip_preprocess, _device
class SmartCLIPClassifierNextCloudShare:
    """Classify the images in a Nextcloud public share with CLIP and move them.

    All remote access goes through the share's public WebDAV endpoint
    (``<server>/public.php/webdav/``), using an HTTP session whose credentials
    are ``(share token, share password)``. Construction loads the cached CLIP
    model, prepares one embedding per category, and lists the share's root
    folder with ``Depth: 1``, so sub-folders such as ``Classified/`` are not
    descended into. :meth:`run` then classifies every image found and moves
    it to ``Classified/<category>/<name>``.

    Args:
        share_url: Public share link, e.g. ``https://host/s/<token>`` or
            ``https://host/index.php/s/<token>``.
        share_password: Password protecting the share.
        progress_callback: Optional callable that receives every log message.
    """

    # Attempts per WebDAV request; only timeouts are retried.
    MAX_RETRIES = 3
    # Timeout in seconds passed to every HTTP request.
    REQUEST_TIMEOUT = 60
    # Lower-case file suffixes treated as images.
    VALID_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'})

    def __init__(self, share_url, share_password, progress_callback=None):
        self.share_url = share_url.rstrip('/')
        self.share_password = share_password
        self.progress_callback = progress_callback
        self.session = requests.Session()
        # Public-share WebDAV credentials: share token as user, share password.
        self.session.auth = (self.get_share_token(), share_password)
        self.categories = [
            "1_Booth",
            "2_Business_Interaction",
            "3_Buyer_Delegation",
            "4_Aisle",
            "5_Conference",
            "6_Fairground",
            "7_Products",
            "8_Registration",
            "9_Miscellaneous"
        ]
        # Remote folders already created by this instance (avoids repeated MKCOLs).
        self._created_folders = set()
        self.log("Loading CLIP model...")
        self.model, self.preprocess, self.device = get_clip_model()
        self.log(f"β CLIP loaded on {self.device}")
        self.load_deep_analysis()
        self.log("π Scanning NextCloud share...")
        self.all_files = self.list_files("")
        self.log(f"Found {len(self.all_files)} total files")
        self.get_image_list()
        # Created last so a failure above cannot leak an orphaned directory;
        # run() removes it when it finishes.
        self.temp_dir = tempfile.mkdtemp()

    def log(self, message):
        """Forward *message* to the progress callback (if any) and print it."""
        if self.progress_callback:
            self.progress_callback(message)
        print(message)

    # ------------------------------------------------------------------ URLs

    def get_share_token(self):
        """Return the share token: the text after ``/s/``, minus any query/fragment."""
        token = self.share_url.split('/s/')[-1]
        return token.split('?', 1)[0].split('#', 1)[0]

    def get_webdav_url(self, path=""):
        """Return the public WebDAV URL for *path*, relative to the share root.

        The server base is everything before ``/s/`` in the share URL. A
        trailing ``/index.php`` (present in links from servers without pretty
        URLs) is dropped, because ``public.php`` is addressed from the server
        root rather than through ``index.php``.
        """
        base = self.share_url.rsplit('/s/', 1)[0]
        if base.endswith('/index.php'):
            base = base[:-len('/index.php')]
        return f"{base}/public.php/webdav/{path}"

    # -------------------------------------------------------------- WebDAV IO

    def _request_with_retry(self, method, url, timeout_label, **kwargs):
        """Send one WebDAV request and return the response, retrying on timeouts.

        Makes up to ``MAX_RETRIES`` attempts. Non-timeout errors propagate
        immediately; this includes the ``HTTPError`` raised by
        ``raise_for_status`` for 4xx/5xx responses. A timeout on the final
        attempt is re-raised.
        """
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                response = self.session.request(
                    method, url, timeout=self.REQUEST_TIMEOUT, **kwargs
                )
                response.raise_for_status()
                return response
            except requests.exceptions.Timeout:
                if attempt == self.MAX_RETRIES:
                    raise
                self.log(f"{timeout_label} on attempt {attempt}, retrying...")
        raise RuntimeError("MAX_RETRIES must be at least 1")

    def download_file(self, filename):
        """Download *filename* from the share and return its raw bytes."""
        url = self.get_webdav_url(filename)
        return self._request_with_retry('GET', url, "Timeout").content

    def upload_file(self, local_path, remote_filename):
        """Upload *local_path* to *remote_filename*, replacing any existing file.

        The remote file is first deleted on a best-effort basis: the DELETE
        status is not checked, and network errors are only logged. Then it is
        PUT with timeout retries.
        """
        url = self.get_webdav_url(remote_filename)
        try:
            self.session.delete(url, timeout=self.REQUEST_TIMEOUT)
        except requests.exceptions.RequestException as e:
            self.log(f"Note: could not remove existing {remote_filename} before upload: {e}")
        # Read once so every retry sends the complete payload.
        with open(local_path, 'rb') as f:
            payload = f.read()
        self._request_with_retry('PUT', url, "Upload timeout", data=payload)

    def delete_file(self, filename):
        """Delete a file from NextCloud.

        Returns True on success. Any non-timeout error is logged and yields
        False. A timeout on the last attempt is re-raised.
        """
        url = self.get_webdav_url(filename)
        try:
            self._request_with_retry('DELETE', url, "Delete timeout")
        except requests.exceptions.Timeout:
            raise
        except Exception as e:
            self.log(f"Warning: Could not delete {filename}: {e}")
            return False
        return True

    @staticmethod
    def _extract_filenames(propfind_xml):
        """Return the file paths (relative to ``/webdav/``) listed in a PROPFIND reply.

        Entries whose href ends in ``/`` are skipped; this includes the listed
        folder itself and its sub-folders. Hrefs are returned exactly as the
        server sent them, so they stay URL-encoded. Any namespace prefix on
        ``href`` (``d:``, ``D:``, none) is accepted.
        """
        import re

        hrefs = re.findall(
            r"<(?:\w+:)?href>(.*?)</(?:\w+:)?href>",
            propfind_xml,
            flags=re.IGNORECASE | re.DOTALL,
        )
        files = []
        for href in hrefs:
            href = href.strip()
            if '/webdav/' not in href:
                continue
            filename = href.split('/webdav/')[-1]
            if filename and not filename.endswith('/'):
                files.append(filename)
        return files

    def list_files(self, remote_path=""):
        """List the files directly inside *remote_path* (``Depth: 1``, no recursion)."""
        url = self.get_webdav_url(remote_path)
        response = self._request_with_retry(
            'PROPFIND', url, "List files timeout", headers={'Depth': '1'}
        )
        return self._extract_filenames(response.text)

    def create_folder(self, folder_name):
        """Create a remote folder with MKCOL, best effort.

        The response status is not checked, so an "already exists" answer is
        ignored. Network errors are logged instead of raised. After one request
        completes, later calls for the same folder are skipped.
        """
        if folder_name in self._created_folders:
            return
        url = self.get_webdav_url(folder_name)
        try:
            self.session.request('MKCOL', url, timeout=self.REQUEST_TIMEOUT)
        except requests.exceptions.RequestException as e:
            self.log(f"Warning: Could not create folder {folder_name}: {e}")
            return
        self._created_folders.add(folder_name)

    def get_image_list(self):
        """Set ``self.images`` to the sorted files in ``self.all_files`` with an image suffix."""
        self.log("Filtering images...")
        self.images = sorted(
            f for f in self.all_files if Path(f).suffix.lower() in self.VALID_EXTENSIONS
        )
        self.log(f"β Found {len(self.images)} images to classify")

    # -------------------------------------------------------------- Embeddings

    def load_deep_analysis(self):
        """Populate ``self.category_embeddings`` (one unit vector per category).

        Uses ``deep_training_analysis.json`` from the working directory when
        it exists and loads cleanly. Otherwise every category falls back to a
        CLIP text embedding.
        """
        self.log("Looking for deep_training_analysis.json...")
        local_json_path = "deep_training_analysis.json"
        if os.path.exists(local_json_path):
            self.log(f"Found local file: {local_json_path}")
            try:
                self._load_embeddings_from_json(local_json_path)
                return
            except Exception as e:
                self.log(f"β Error loading local deep analysis: {e}")
            self.log("Falling back to text embeddings")
        else:
            self.log("β οΈ deep_training_analysis.json not found - using fallback embeddings")
        self.category_embeddings = {cat: self.create_text_embedding(cat) for cat in self.categories}

    def _load_embeddings_from_json(self, json_path):
        """Build category embeddings from a deep-training-analysis JSON file.

        For each category present as a key in the file, the entry's
        ``avg_embedding`` (a list of numbers) is L2-normalised; its
        ``num_training_images`` is only logged. Categories missing from the
        file use :meth:`create_text_embedding`. ``self.category_embeddings`` is
        replaced only after every category is built, so a failure part-way
        through never leaves a partial mapping.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            self.deep_analysis = json.load(f)
        self.log("π Loaded deep training analysis from local file")
        embeddings = {}
        for category in self.categories:
            if category in self.deep_analysis:
                data = self.deep_analysis[category]
                avg_embedding = torch.tensor(data['avg_embedding'], dtype=torch.float32).to(self.device)
                embeddings[category] = avg_embedding / avg_embedding.norm()
                self.log(f" {category}: {data['num_training_images']} training images")
            else:
                embeddings[category] = self.create_text_embedding(category)
        self.category_embeddings = embeddings

    def create_text_embedding(self, category):
        """Return the unit-normalised CLIP text embedding of a prompt describing *category*."""
        descriptions = {
            "1_Booth": "a photo of an exhibition booth at a trade show",
            "2_Business_Interaction": "a photo of business people talking at a trade show",
            "3_Buyer_Delegation": "a photo of a large group visiting a trade show",
            "4_Aisle": "a photo of a trade show aisle between booths",
            "5_Conference": "a photo of a conference presentation or seminar",
            "6_Fairground": "a photo of an exhibition hall or fairground",
            "7_Products": "a photo of products on display",
            "8_Registration": "a photo of a registration desk or entry gate",
            "9_Miscellaneous": "a miscellaneous trade show photo"
        }
        text = descriptions.get(category, "a photo")
        text_input = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_input)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features[0]

    # ---------------------------------------------------------- Classification

    def _best_category(self, img_features):
        """Return ``(category, similarity)`` for the embedding closest to *img_features*.

        Similarity is the dot product of the unit vectors. Both sides are cast
        to float32, because the CLIP model may produce float16 features (e.g.
        on CUDA) while embeddings loaded from JSON are float32. A mixed-dtype
        matmul raises.
        """
        features = img_features.float()
        similarities = {
            category: (features @ embedding.float()).item()
            for category, embedding in self.category_embeddings.items()
        }
        best_category = max(similarities, key=similarities.get)
        return best_category, similarities[best_category]

    def _predict(self, img_data):
        """Classify encoded image bytes; return ``(category, similarity)``."""
        img = Image.open(BytesIO(img_data)).convert('RGB')
        img_input = self.preprocess(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            img_features = self.model.encode_image(img_input)
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        return self._best_category(img_features[0])

    def _classify_and_move(self, filename):
        """Download, classify and move one file; raise on any failure.

        The original is deleted only after its copy has been uploaded to
        ``Classified/<category>/``. A failed download, prediction or upload
        therefore leaves it in place. Returns ``(category, similarity)``.
        """
        img_data = self.download_file(filename)
        best_category, confidence = self._predict(img_data)

        name = Path(filename).name
        category_folder = f"Classified/{best_category}"
        # run() removes the temp dir when it ends; recreate it if needed.
        os.makedirs(self.temp_dir, exist_ok=True)
        local_path = os.path.join(self.temp_dir, name)
        try:
            with open(local_path, 'wb') as f:
                f.write(img_data)
            self.create_folder("Classified")
            self.create_folder(category_folder)
            self.upload_file(local_path, f"{category_folder}/{name}")
        finally:
            try:
                os.remove(local_path)
            except FileNotFoundError:
                pass
        # MOVE instead of COPY: remove the original only after a successful upload.
        self.delete_file(filename)
        return best_category, confidence

    def classify_image(self, filename):
        """Classify and move one file.

        Returns ``(category, similarity)``. On any error it logs the error and
        returns ``("9_Miscellaneous", 0.0)``.
        """
        try:
            return self._classify_and_move(filename)
        except Exception as e:
            self.log(f"β Error processing {filename}: {str(e)}")
            return "9_Miscellaneous", 0.0

    def run(self):
        """Classify and move every image listed at construction; return a Markdown summary.

        Files that fail are logged and reported separately, not counted under
        any category. The temporary directory is removed when the run ends,
        even if it raises.
        """
        self.log("π Starting classification...")
        confidences = {cat: [] for cat in self.categories}
        failed = []
        try:
            self.create_folder("Classified")
            for cat in self.categories:
                self.create_folder(f"Classified/{cat}")
            total = len(self.images)
            for i, filename in enumerate(self.images, 1):
                self.log(f"[{i}/{total}] Processing {Path(filename).name}...")
                try:
                    category, confidence = self._classify_and_move(filename)
                except Exception as e:
                    self.log(f"β Error processing {filename}: {str(e)}")
                    failed.append(filename)
                    continue
                confidences[category].append(confidence)
                self.log(f" β Moved to {category} (confidence: {confidence:.3f})")
        finally:
            shutil.rmtree(self.temp_dir, ignore_errors=True)

        self.log("β CLASSIFICATION COMPLETE!")
        self.log("π Results moved to: Classified/")
        self.log("π Original files have been removed from root folder")
        if failed:
            self.log(f"{len(failed)} file(s) could not be processed; see errors above")

        parts = [
            "## β Classification Complete!\n\n**Files have been MOVED (not copied) to categorized folders**\n\n**Results Summary:**\n\n"
        ]
        for cat in self.categories:
            scores = confidences[cat]
            if scores:
                avg_conf = sum(scores) / len(scores)
                parts.append(f"- **{cat}**: {len(scores)} images (avg confidence: {avg_conf:.3f})\n")
        if failed:
            parts.append(f"\n**{len(failed)} image(s) could not be processed** - see the log for details.\n")
        return "".join(parts)
def classify_photos(share_url, share_password, progress=gr.Progress()):
    """Gradio click handler: classify every image in a NextCloud share.

    Returns a ``(result_markdown, log_text)`` pair for the two output
    components. Missing inputs short-circuit with an error message. Timeouts,
    other ``requests`` errors and any remaining exception are each turned
    into a user-facing message, returned together with the log collected so
    far.
    """
    if not (share_url and share_password):
        return "β Please enter both the share URL and password", ""

    captured = []

    def joined_logs():
        return "\n".join(captured)

    try:
        for fraction, label in ((0, "Initializing..."), (0.1, "Connecting to NextCloud...")):
            progress(fraction, desc=label)
        # Every classifier log message is appended to `captured`.
        classifier = SmartCLIPClassifierNextCloudShare(
            share_url,
            share_password,
            progress_callback=captured.append,
        )
        progress(0.3, desc=f"Found {len(classifier.images)} images to classify")
        summary = classifier.run()
        progress(1.0, desc="Complete!")
        return summary, joined_logs()
    except requests.exceptions.Timeout:
        # Must precede RequestException, of which Timeout is a subclass.
        message = "β Connection Timeout: NextCloud is taking too long to respond.\n\nPlease check your network connection and try again."
    except requests.exceptions.RequestException as e:
        message = f"β Connection Error: Could not connect to NextCloud.\n\nPlease check:\n- Your share URL is correct\n- Your password is correct\n- The share link has 'Allow upload and editing' enabled\n\nError details: {str(e)}"
    except Exception as e:
        message = f"β Error: {str(e)}\n\nPlease check your share URL and password and try again."
    return message, joined_logs()
# Gradio Interface
# Components are rendered in the order they are created inside the
# Blocks / Row / Column context managers, so statement order here is the layout.
with gr.Blocks(title="Trade Show Photo Classifier", theme=gr.themes.Soft()) as demo:
    # Header and usage instructions.
    gr.Markdown("# π€ Trade Show Photo Classifier")
    gr.Markdown("Automatically classify your trade show photos using AI-powered image recognition")
    gr.Markdown("""
### π Setup Instructions:
1. Upload your photos to a NextCloud folder
2. Create a public share link for that folder
3. **Important:** When creating the share, enable **"Allow upload and editing"** (click the three dots β Share settings)
4. Set a password for the share
5. Copy the share URL and password below
6. Click "Start Classification"
**β οΈ Note: Files will be MOVED (not copied) to the Classified folder. Original files will be deleted from the root.**
""")
    # Input row: share URL, share password and the start button.
    with gr.Row():
        with gr.Column():
            share_url = gr.Textbox(
                label="NextCloud Share URL",
                placeholder="https://cloud2.messefrankfurtexchange.com/s/...",
                info="Enter the public share link to your NextCloud folder"
            )
            share_password = gr.Textbox(
                label="Share Password",
                type="password",
                info="Enter the password for the NextCloud share"
            )
            classify_btn = gr.Button("π Start Classification", variant="primary", size="lg")
    # Output row: Markdown summary on the left, read-only log text on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### π Results")
            output = gr.Markdown()
        with gr.Column():
            gr.Markdown("### π Classification Log")
            logs_output = gr.Textbox(
                label="Progress",
                lines=15,
                max_lines=20,
                interactive=False,
                show_label=False
            )
    # classify_photos(url, password) returns (summary, log text), which are
    # written to `output` and `logs_output` in that order.
    classify_btn.click(
        fn=classify_photos,
        inputs=[share_url, share_password],
        outputs=[output, logs_output]
    )
    # Footer: category legend.
    gr.Markdown("---")
    gr.Markdown("""
**Categories:**
- 1_Booth: Exhibition booths
- 2_Business_Interaction: Business conversations
- 3_Buyer_Delegation: Group visits
- 4_Aisle: Walkways between booths
- 5_Conference: Presentations and seminars
- 6_Fairground: Exhibition halls
- 7_Products: Product displays
- 8_Registration: Entry and registration areas
- 9_Miscellaneous: Other trade show content
*Powered by OpenAI CLIP | Deployed on Hugging Face Spaces*
""")

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()