# Hugging Face Space: tomato leaf disease similarity demo (Gradio app)
import pickle
from collections import Counter

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel
# -----------------------------
# Startup: load model, precomputed embeddings, and the source dataset.
# Runs once at import time (module-level side effects by design for Spaces).
# -----------------------------
print("Loading CLIP model...")
model_name = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name).to(device)
model.eval()  # inference only: disable dropout / training-mode behavior
print(f"Model loaded on {device}")
print("Loading recommendation system...")
# NOTE(review): pickle.load is only safe on trusted input — this artifact must
# always be produced by this project, never user-supplied.
with open("recommendation_system.pkl", "rb") as f:
    rec_system = pickle.load(f)
embeddings = np.asarray(rec_system["embeddings"])  # one CLIP vector per dataset row; presumably L2-normalized — confirm against the builder script
df_metadata = rec_system["df_metadata"]  # DataFrame aligned row-for-row with `embeddings`
label_to_disease = rec_system.get("label_to_disease", {})  # optional mapping; may be empty
print(f"Loaded {len(embeddings)} samples")
# -----------------------------
# Load HF dataset for image retrieval (to show similar images)
# -----------------------------
print("Loading Hugging Face dataset (for gallery images)...")
ds = load_dataset("wellCh4n/tomato-leaf-disease-image")
# Splits may be absent; downstream code checks each for None before use.
ds_train = ds.get("train", None)
ds_val = ds.get("validation", None)
ds_test = ds.get("test", None)
# -----------------------------
# Optional informational text (NOT medical advice)
# -----------------------------
# Static lookup keyed by disease label as stored in df_metadata["disease_name"].
# Each entry supplies the description/symptoms/action shown under a prediction;
# labels without an entry simply get no extra info section.
disease_info = {
    "Healthy": {
        "description": "Your tomato leaf appears healthy! No obvious signs of disease.",
        "symptoms": "Vibrant green color, smooth texture, no spots or discoloration",
        "action": "Continue regular care and monitoring"
    },
    "Leaf Mold": {
        "description": "Leaf Mold is a fungal disease causing yellowish patches on leaves.",
        "symptoms": "Yellow spots on upper leaf surface, fuzzy growth underneath",
        "action": "Improve air circulation, reduce humidity"
    },
    "Target Spot": {
        "description": "Target Spot can create concentric ring patterns on leaves.",
        "symptoms": "Brown spots with ring-like (bullseye) patterns",
        "action": "Remove affected leaves, improve plant hygiene"
    },
    "Late Blight": {
        "description": "Late Blight can spread quickly under humid conditions.",
        "symptoms": "Water-soaked lesions, dark brown patches, possible mold in humidity",
        "action": "Remove severely affected leaves, reduce leaf wetness"
    },
    "Early Blight": {
        "description": "Early Blight often appears on older leaves with dark spots.",
        "symptoms": "Dark brown spots with target-like rings, yellowing around lesions",
        "action": "Remove lower affected leaves, improve spacing"
    },
    "Bacterial Spot": {
        "description": "Bacterial Spot causes small dark spots on leaves.",
        "symptoms": "Small dark lesions, sometimes with yellow halos",
        "action": "Avoid overhead watering, improve airflow"
    },
    "Septoria Leaf Spot": {
        "description": "Septoria shows small circular spots with gray centers.",
        "symptoms": "Numerous small spots with dark borders and gray centers",
        "action": "Remove infected leaves, avoid wetting foliage"
    },
    "Yellow Curl Virus": {
        "description": "Yellow Leaf Curl Virus can cause leaf curling and yellowing.",
        "symptoms": "Upward curling, yellowing, stunted growth",
        "action": "Control insect vectors, remove infected plants"
    },
    "Spider Mites": {
        "description": "Spider Mites cause stippling and bronzing of leaves.",
        "symptoms": "Tiny pale spots, possible fine webbing",
        "action": "Rinse leaves, increase humidity, consider insecticidal soap"
    }
}
| # ----------------------------- | |
| # Core: Embeddings | |
| # ----------------------------- | |
def generate_embedding(image: Image.Image) -> np.ndarray:
    """Encode a PIL image into a unit-length CLIP embedding.

    Returns a 1-D float array (the 512-dim ViT-B/32 image feature), divided
    by its own L2 norm so dot products with other normalized embeddings are
    cosine similarities.
    """
    batch = processor(images=image, return_tensors="pt", padding=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        features = model.get_image_features(**batch)
        features = features / features.norm(dim=-1, keepdim=True)
    # Single-image batch: return the lone row as a NumPy vector.
    return features.cpu().numpy()[0]
def find_similar_cases(image: Image.Image, top_k: int = 3, unique_diseases: bool = False):
    """Rank the corpus by cosine similarity to *image* and return the top-k.

    Each result is a dict with the corpus row index, its disease label, the
    similarity score, and the row's text. When ``unique_diseases`` is True,
    each disease label appears at most once, so the top-k covers k distinct
    labels (or fewer if the corpus has fewer distinct labels).
    """
    query = generate_embedding(image).reshape(1, -1)
    scores = cosine_similarity(query, embeddings)[0]
    matches = []
    labels_seen = set()
    # Walk candidates from most to least similar until top_k are collected.
    for position in np.argsort(scores)[::-1]:
        row = df_metadata.iloc[position]
        label = row.get("disease_name", "Unknown")
        if unique_diseases and label in labels_seen:
            continue
        labels_seen.add(label)
        matches.append({
            "index": int(position),
            "disease": label,
            "similarity": float(scores[position]),
            "text": row.get("text", ""),
        })
        if len(matches) >= top_k:
            break
    return matches
def majority_vote_prediction(image: Image.Image, vote_k: int = 5):
    """Predict a label by majority vote over the vote_k nearest neighbors.

    Returns ``(label, support, neighbors)`` where *support* is the fraction
    of the vote_k neighbors sharing the winning label. Ties fall to the
    label first encountered, per Counter.most_common ordering.
    """
    neighbors = find_similar_cases(image, top_k=vote_k, unique_diseases=False)
    tally = Counter(neighbor["disease"] for neighbor in neighbors)
    winner, votes = tally.most_common(1)[0]
    return winner, votes / len(neighbors), neighbors
| # ----------------------------- | |
| # Image retrieval for gallery | |
| # ----------------------------- | |
def _get_dataset_image_for_result(result_row):
    """Resolve one retrieval result to its actual image in the HF dataset.

    Preferred path: df_metadata carries a split column plus a row-id column
    that point into the corresponding dataset split. Fallback: treat the
    embedding index itself as a row of the train split (only valid if the
    embeddings were built from train in order — TODO confirm against the
    embedding-builder script).

    Returns the PIL image, or None when no image can be resolved.
    """
    idx = result_row["index"]
    split_col_candidates = ["split", "hf_split", "dataset_split"]
    rowid_col_candidates = ["row_id", "hf_idx", "hf_index", "dataset_idx", "original_index"]
    split_val = None
    row_id = None
    for c in split_col_candidates:
        if c in df_metadata.columns:
            split_val = df_metadata.iloc[idx][c]
            break
    for c in rowid_col_candidates:
        if c in df_metadata.columns:
            row_id = df_metadata.iloc[idx][c]
            break
    if split_val is not None and row_id is not None:
        split_val = str(split_val).strip().lower()
        # Fix: int() raised on NaN/garbage metadata and row ids were never
        # bounds-checked; both now degrade to the fallback instead of
        # crashing the UI callback.
        try:
            row_id = int(row_id)
        except (TypeError, ValueError):
            row_id = None
        if row_id is not None and row_id >= 0:
            if split_val == "train" and ds_train is not None and row_id < len(ds_train):
                return ds_train[row_id]["image"]
            if split_val in ("validation", "val") and ds_val is not None and row_id < len(ds_val):
                return ds_val[row_id]["image"]
            if split_val == "test" and ds_test is not None and row_id < len(ds_test):
                return ds_test[row_id]["image"]
    # Fallback: assume the embedding index maps directly into the train split.
    if ds_train is not None and idx < len(ds_train):
        return ds_train[idx]["image"]
    return None
| # ----------------------------- | |
| # UI function | |
| # ----------------------------- | |
def diagnose_tomato_leaf(image, unique_alt: bool):
    """Gradio click handler: run retrieval + majority vote on an uploaded leaf.

    Parameters:
        image: PIL image from the upload widget, or None if nothing uploaded.
        unique_alt: when True, the Top-3 list shows 3 distinct disease labels.

    Returns a 5-tuple matching the wired outputs:
        (diagnosis markdown, similar-cases markdown, technical markdown,
         gallery items as (image, caption) tuples, status string).
    """
    if image is None:
        # Keep the tuple arity identical to the success path so Gradio
        # can populate every output component.
        return "β Please upload an image first.", "", "", [], f"Processed on {device.upper()}"
    # Resize input to keep inference light
    max_size = 512
    if max(image.size) > max_size:
        image = image.copy()  # thumbnail mutates in place; don't touch the caller's image
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    # Prediction uses the 5 nearest neighbors; the displayed Top-3 is a
    # separate query so the unique-labels toggle doesn't affect the vote.
    pred_label, support, neighbors5 = majority_vote_prediction(image, vote_k=5)
    top3 = find_similar_cases(image, top_k=3, unique_diseases=bool(unique_alt))
    best_sim = top3[0]["similarity"] * 100
    diagnosis_md = (
        f"## Result\n"
        f"**Predicted disease (majority vote over top-5):** **{pred_label}** \n"
        f"**Support:** {support*100:.1f}% (how many of the top-5 neighbors match this label) \n"
        f"**Best-match similarity:** {best_sim:.1f}% \n\n"
        f"> This is a **similarity-based retrieval tool** (CLIP embeddings + cosine similarity) for educational use.\n"
    )
    if pred_label in disease_info:
        info = disease_info[pred_label]
        diagnosis_md += (
            f"\n### General Info (not medical advice)\n"
            f"**Description:** {info['description']} \n"
            f"**Typical symptoms:** {info['symptoms']} \n"
            f"**General action:** {info['action']} \n"
        )
    cases_md = "## Top Similar Cases (Top-3)\n\n"
    for i, r in enumerate(top3, 1):
        cases_md += f"**{i}. {r['disease']}** β Similarity: {r['similarity']*100:.2f}%\n\n"
    technical = "## Technical Details\n\n"
    technical += f"- Model: CLIP (ViT-B/32)\n"
    technical += f"- Embedding dimension: {embeddings.shape[1] if len(embeddings.shape) == 2 else 'Unknown'}\n"
    technical += f"- Similarity metric: Cosine similarity\n\n"
    technical += "### Top-5 neighbors used for majority vote\n"
    for i, r in enumerate(neighbors5, 1):
        technical += f"{i}. **{r['disease']}**: {r['similarity']*100:.2f}%\n"
    # IMPORTANT: return gallery in a stable format
    gallery_items = []
    for r in top3:
        img = _get_dataset_image_for_result(r)
        if img is not None:  # silently skip results whose image can't be resolved
            gallery_items.append((img, f"{r['disease']} ({r['similarity']*100:.1f}%)"))
    status = f"β Analysis complete! Processed on {device.upper()}"
    return diagnosis_md, cases_md, technical, gallery_items, status
# -----------------------------
# Gradio app
# -----------------------------
# Declarative UI: components created inside the Blocks context are laid out
# in definition order; only the Analyze button triggers computation.
with gr.Blocks(theme=gr.themes.Soft(), title="π Tomato Disease Detector") as demo:
    # Header with video link
    gr.Markdown(
        "# π Tomato Leaf Similarity & Disease Finder\n"
        "Upload a tomato leaf image and retrieve the most similar labeled cases from the dataset.\n\n"
        "**Note:** This is a similarity-based tool for educational purposes (not professional diagnosis)."
    )
    # Prominent video link box
    gr.Markdown(
        """
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    padding: 20px;
                    border-radius: 10px;
                    text-align: center;
                    margin: 20px 0;">
            <h2 style="color: white; margin: 0 0 10px 0;">π¬ Watch Project Presentation</h2>
            <p style="color: #f0f0f0; margin: 0 0 15px 0;">Complete walkthrough in Hebrew (3-5 minutes)</p>
            <a href="https://drive.google.com/drive/folders/1IoUJWKOcHUc6m53uWl5CIHLHRscT-xrS?usp=sharing"
               target="_blank"
               style="background: white;
                      color: #667eea;
                      padding: 12px 30px;
                      border-radius: 25px;
                      text-decoration: none;
                      font-weight: bold;
                      font-size: 16px;
                      display: inline-block;">
                πΊ Watch Video
            </a>
        </div>
        """
    )
    with gr.Row():
        # Left column: input image, options, and the trigger button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Tomato Leaf Image", height=380)
            unique_alt = gr.Checkbox(
                value=False,
                label="Show 3 different diseases (unique labels) in Top-3"
            )
            diagnose_btn = gr.Button("Analyze", variant="primary", size="lg")
            gr.Markdown(
                "### Tips\n"
                "- Use clear, well-lit photos\n"
                "- Focus on the affected area\n"
                "- Avoid blurry images\n"
            )
        # Right column: diagnosis text, similar-case list, and image gallery.
        with gr.Column(scale=2):
            diagnosis_output = gr.Markdown()
            cases_output = gr.Markdown()
            gallery_output = gr.Gallery(
                label="Top-3 Similar Images",
                columns=3,
                height=220
            )
            with gr.Accordion("Technical Details", open=False):
                technical_output = gr.Markdown()
            status_output = gr.Markdown()
    # Wiring: output order must match diagnose_tomato_leaf's return tuple.
    diagnose_btn.click(
        fn=diagnose_tomato_leaf,
        inputs=[image_input, unique_alt],
        outputs=[diagnosis_output, cases_output, technical_output, gallery_output, status_output],
        api_name=False,  # keep this handler off the auto-generated API page
    )
    # Footer with project info
    gr.Markdown(
        """
        ---
        ### π Project Information
        **Assignment #3:** Embeddings, RecSys, Spaces
        **Author:** Michael Ozon
        **Technologies:** CLIP ViT-B/32 β’ Gradio β’ HuggingFace β’ scikit-learn
        **Dataset:** 14,218 tomato images | **Embeddings:** 512-dim | **Method:** Cosine similarity + Majority voting
        **Built with:** π€ HuggingFace β’ π¨ Gradio β’ π§ OpenAI CLIP β’ π scikit-learn
        """
    )
if __name__ == "__main__":
    # 0.0.0.0:7860 is the conventional bind address/port for a Spaces container.
    demo.launch(server_name="0.0.0.0", server_port=7860)