# michaelozon — Update app.py (commit 05d1b28, verified)
import gradio as gr
import torch
import numpy as np
import pickle
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
from collections import Counter
from datasets import load_dataset
# ----- CLIP model setup (runs once at import time) -----
print("Loading CLIP model...")
model_name = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name).to(device)
model.eval()  # inference only: freeze dropout / batch-norm behavior
print(f"Model loaded on {device}")
# ----- Precomputed retrieval index shipped with the Space -----
print("Loading recommendation system...")
# NOTE(review): pickle.load executes arbitrary code from the file; acceptable
# only because recommendation_system.pkl is bundled with this app, never user-supplied.
with open("recommendation_system.pkl", "rb") as f:
    rec_system = pickle.load(f)
embeddings = np.asarray(rec_system["embeddings"])  # one row per sample; presumably (n, 512) CLIP vectors — TODO confirm
df_metadata = rec_system["df_metadata"]  # per-sample metadata; code below reads "disease_name" and "text" columns
label_to_disease = rec_system.get("label_to_disease", {})  # optional mapping; may be empty
print(f"Loaded {len(embeddings)} samples")
# -----------------------------
# Load HF dataset for image retrieval (to show similar images)
# -----------------------------
print("Loading Hugging Face dataset (for gallery images)...")
ds = load_dataset("wellCh4n/tomato-leaf-disease-image")
# Splits may or may not exist for this dataset; keep None fallbacks so the
# gallery code can degrade gracefully instead of raising KeyError.
ds_train = ds.get("train", None)
ds_val = ds.get("validation", None)
ds_test = ds.get("test", None)
# -----------------------------
# Optional informational text (NOT medical advice)
# -----------------------------
# Static lookup keyed by predicted disease label; diagnose_tomato_leaf() only
# renders an info section when the prediction matches one of these keys.
disease_info = {
    "Healthy": {
        "description": "Your tomato leaf appears healthy! No obvious signs of disease.",
        "symptoms": "Vibrant green color, smooth texture, no spots or discoloration",
        "action": "Continue regular care and monitoring"
    },
    "Leaf Mold": {
        "description": "Leaf Mold is a fungal disease causing yellowish patches on leaves.",
        "symptoms": "Yellow spots on upper leaf surface, fuzzy growth underneath",
        "action": "Improve air circulation, reduce humidity"
    },
    "Target Spot": {
        "description": "Target Spot can create concentric ring patterns on leaves.",
        "symptoms": "Brown spots with ring-like (bullseye) patterns",
        "action": "Remove affected leaves, improve plant hygiene"
    },
    "Late Blight": {
        "description": "Late Blight can spread quickly under humid conditions.",
        "symptoms": "Water-soaked lesions, dark brown patches, possible mold in humidity",
        "action": "Remove severely affected leaves, reduce leaf wetness"
    },
    "Early Blight": {
        "description": "Early Blight often appears on older leaves with dark spots.",
        "symptoms": "Dark brown spots with target-like rings, yellowing around lesions",
        "action": "Remove lower affected leaves, improve spacing"
    },
    "Bacterial Spot": {
        "description": "Bacterial Spot causes small dark spots on leaves.",
        "symptoms": "Small dark lesions, sometimes with yellow halos",
        "action": "Avoid overhead watering, improve airflow"
    },
    "Septoria Leaf Spot": {
        "description": "Septoria shows small circular spots with gray centers.",
        "symptoms": "Numerous small spots with dark borders and gray centers",
        "action": "Remove infected leaves, avoid wetting foliage"
    },
    "Yellow Curl Virus": {
        "description": "Yellow Leaf Curl Virus can cause leaf curling and yellowing.",
        "symptoms": "Upward curling, yellowing, stunted growth",
        "action": "Control insect vectors, remove infected plants"
    },
    "Spider Mites": {
        "description": "Spider Mites cause stippling and bronzing of leaves.",
        "symptoms": "Tiny pale spots, possible fine webbing",
        "action": "Rinse leaves, increase humidity, consider insecticidal soap"
    }
}
# -----------------------------
# Core: Embeddings
# -----------------------------
def generate_embedding(image: Image.Image) -> np.ndarray:
    """Encode *image* with CLIP and return its L2-normalized embedding.

    Returns:
        A 1-D numpy array (unit Euclidean norm) — the single image's feature
        vector with the batch dimension stripped.
    """
    with torch.no_grad():
        batch = processor(images=image, return_tensors="pt", padding=True)
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        features = model.get_image_features(**batch)
        # Normalize so downstream cosine similarity reduces to a dot product.
        features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy()[0]
def find_similar_cases(image: Image.Image, top_k: int = 3, unique_diseases: bool = False):
    """Rank stored samples by cosine similarity to *image* and return the top-k.

    Args:
        image: Query leaf photo.
        top_k: Maximum number of results to return.
        unique_diseases: When True, keep at most one result per disease label.

    Returns:
        List of dicts with keys "index", "disease", "similarity", "text",
        ordered from most to least similar.
    """
    query = generate_embedding(image).reshape(1, -1)
    scores = cosine_similarity(query, embeddings)[0]
    order = np.argsort(scores)[::-1]  # descending similarity

    picked = []
    labels_seen = set()
    for i in order:
        row = df_metadata.iloc[i]
        label = row.get("disease_name", "Unknown")
        if unique_diseases and label in labels_seen:
            continue
        labels_seen.add(label)
        picked.append({
            "index": int(i),
            "disease": label,
            "similarity": float(scores[i]),
            "text": row.get("text", ""),
        })
        if len(picked) >= top_k:
            break
    return picked
def majority_vote_prediction(image: Image.Image, vote_k: int = 5):
    """Classify *image* by majority vote among its top-*vote_k* nearest neighbors.

    Returns:
        Tuple ``(label, support, neighbors)`` — the winning disease label, the
        fraction of the vote_k neighbors carrying that label, and the full
        neighbor list used for the vote.
    """
    neighbors = find_similar_cases(image, top_k=vote_k, unique_diseases=False)
    votes = Counter(item["disease"] for item in neighbors)
    winner, hits = votes.most_common(1)[0]
    return winner, hits / len(neighbors), neighbors
# -----------------------------
# Image retrieval for gallery
# -----------------------------
def _get_dataset_image_for_result(result_row):
    """
    Try to retrieve the actual image from HF dataset for a result.
    Works best if df_metadata contains split+row_id (or hf_idx).
    Fallback: use result_row['index'] as index into train split.
    """
    idx = result_row["index"]

    def _first_column_value(candidates):
        # Return the value of the first candidate column present in df_metadata.
        for col in candidates:
            if col in df_metadata.columns:
                return df_metadata.iloc[idx][col]
        return None

    split_val = _first_column_value(["split", "hf_split", "dataset_split"])
    row_id = _first_column_value(["row_id", "hf_idx", "hf_index", "dataset_idx", "original_index"])

    if split_val is not None and row_id is not None:
        split_name = str(split_val).strip().lower()
        row = int(row_id)
        # Map normalized split names onto the loaded splits (None if missing).
        by_split = {"train": ds_train, "validation": ds_val, "val": ds_val, "test": ds_test}
        target = by_split.get(split_name)
        if target is not None:
            return target[row]["image"]

    # Fallback: treat the metadata index as a train-split row index.
    if ds_train is not None and idx < len(ds_train):
        return ds_train[idx]["image"]
    return None
# -----------------------------
# UI function
# -----------------------------
def _downscale(image: Image.Image, max_size: int = 512) -> Image.Image:
    """Return *image* fit within max_size x max_size (copy only if resizing)."""
    if max(image.size) > max_size:
        image = image.copy()  # thumbnail() mutates in place; don't touch the caller's image
        image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    return image


def _diagnosis_markdown(pred_label: str, support: float, best_sim: float) -> str:
    """Build the main result markdown, plus general info when the label is known."""
    md = (
        f"## Result\n"
        f"**Predicted disease (majority vote over top-5):** **{pred_label}** \n"
        f"**Support:** {support*100:.1f}% (how many of the top-5 neighbors match this label) \n"
        f"**Best-match similarity:** {best_sim:.1f}% \n\n"
        f"> This is a **similarity-based retrieval tool** (CLIP embeddings + cosine similarity) for educational use.\n"
    )
    if pred_label in disease_info:
        info = disease_info[pred_label]
        md += (
            f"\n### General Info (not medical advice)\n"
            f"**Description:** {info['description']} \n"
            f"**Typical symptoms:** {info['symptoms']} \n"
            f"**General action:** {info['action']} \n"
        )
    return md


def _cases_markdown(top3) -> str:
    """Build the 'Top Similar Cases' section from the top-3 retrieval results."""
    md = "## Top Similar Cases (Top-3)\n\n"
    for i, r in enumerate(top3, 1):
        md += f"**{i}. {r['disease']}** β€” Similarity: {r['similarity']*100:.2f}%\n\n"
    return md


def _technical_markdown(neighbors5) -> str:
    """Build the technical-details section, including the top-5 voting neighbors."""
    md = "## Technical Details\n\n"
    md += f"- Model: CLIP (ViT-B/32)\n"
    md += f"- Embedding dimension: {embeddings.shape[1] if len(embeddings.shape) == 2 else 'Unknown'}\n"
    md += f"- Similarity metric: Cosine similarity\n\n"
    md += "### Top-5 neighbors used for majority vote\n"
    for i, r in enumerate(neighbors5, 1):
        md += f"{i}. **{r['disease']}**: {r['similarity']*100:.2f}%\n"
    return md


def _gallery_items(top3):
    """Build (image, caption) pairs for the gallery; skip rows with no image."""
    # IMPORTANT: return gallery in a stable format
    items = []
    for r in top3:
        img = _get_dataset_image_for_result(r)
        if img is not None:
            items.append((img, f"{r['disease']} ({r['similarity']*100:.1f}%)"))
    return items


def diagnose_tomato_leaf(image, unique_alt: bool):
    """Gradio handler: analyze an uploaded leaf image.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        unique_alt: When True, the Top-3 list shows three distinct disease labels.

    Returns:
        (diagnosis_md, cases_md, technical_md, gallery_items, status) matching
        the five output components wired up in the Blocks UI.
    """
    if image is None:
        return "❗ Please upload an image first.", "", "", [], f"Processed on {device.upper()}"

    # Resize input to keep inference light
    image = _downscale(image)

    # NOTE(review): these two calls each run a CLIP forward pass on the same
    # image; the embedding could be computed once and shared — confirm before
    # changing the retrieval helpers' interface.
    pred_label, support, neighbors5 = majority_vote_prediction(image, vote_k=5)
    top3 = find_similar_cases(image, top_k=3, unique_diseases=bool(unique_alt))
    best_sim = top3[0]["similarity"] * 100

    diagnosis_md = _diagnosis_markdown(pred_label, support, best_sim)
    cases_md = _cases_markdown(top3)
    technical = _technical_markdown(neighbors5)
    gallery_items = _gallery_items(top3)
    status = f"βœ… Analysis complete! Processed on {device.upper()}"
    return diagnosis_md, cases_md, technical, gallery_items, status
# -----------------------------
# Gradio app
# -----------------------------
# Builds the UI declaratively; widget nesting (Row/Column/Accordion) determines
# on-screen layout. `demo` is launched in the __main__ guard below.
with gr.Blocks(theme=gr.themes.Soft(), title="πŸ… Tomato Disease Detector") as demo:
    # Header with video link
    gr.Markdown(
        "# πŸ… Tomato Leaf Similarity & Disease Finder\n"
        "Upload a tomato leaf image and retrieve the most similar labeled cases from the dataset.\n\n"
        "**Note:** This is a similarity-based tool for educational purposes (not professional diagnosis)."
    )
    # Prominent video link box (raw HTML inside Markdown; content intentionally
    # kept at column 0 so Markdown does not treat indented lines as code).
    gr.Markdown(
        """
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 20px;
border-radius: 10px;
text-align: center;
margin: 20px 0;">
<h2 style="color: white; margin: 0 0 10px 0;">🎬 Watch Project Presentation</h2>
<p style="color: #f0f0f0; margin: 0 0 15px 0;">Complete walkthrough in Hebrew (3-5 minutes)</p>
<a href="https://drive.google.com/drive/folders/1IoUJWKOcHUc6m53uWl5CIHLHRscT-xrS?usp=sharing"
target="_blank"
style="background: white;
color: #667eea;
padding: 12px 30px;
border-radius: 25px;
text-decoration: none;
font-weight: bold;
font-size: 16px;
display: inline-block;">
πŸ“Ί Watch Video
</a>
</div>
"""
    )
    with gr.Row():
        # Left column: inputs and tips
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Tomato Leaf Image", height=380)
            unique_alt = gr.Checkbox(
                value=False,
                label="Show 3 different diseases (unique labels) in Top-3"
            )
            diagnose_btn = gr.Button("Analyze", variant="primary", size="lg")
            gr.Markdown(
                "### Tips\n"
                "- Use clear, well-lit photos\n"
                "- Focus on the affected area\n"
                "- Avoid blurry images\n"
            )
        # Right column: result panels
        with gr.Column(scale=2):
            diagnosis_output = gr.Markdown()
            cases_output = gr.Markdown()
            gallery_output = gr.Gallery(
                label="Top-3 Similar Images",
                columns=3,
                height=220
            )
            with gr.Accordion("Technical Details", open=False):
                technical_output = gr.Markdown()
            status_output = gr.Markdown()
    # Wire the button to the handler; output order must match the handler's
    # 5-tuple return. api_name=False hides this endpoint from the public API.
    diagnose_btn.click(
        fn=diagnose_tomato_leaf,
        inputs=[image_input, unique_alt],
        outputs=[diagnosis_output, cases_output, technical_output, gallery_output, status_output],
        api_name=False,
    )
    # Footer with project info
    gr.Markdown(
        """
---
### πŸ“Š Project Information
**Assignment #3:** Embeddings, RecSys, Spaces
**Author:** Michael Ozon
**Technologies:** CLIP ViT-B/32 β€’ Gradio β€’ HuggingFace β€’ scikit-learn
**Dataset:** 14,218 tomato images | **Embeddings:** 512-dim | **Method:** Cosine similarity + Majority voting
**Built with:** πŸ€— HuggingFace β€’ 🎨 Gradio β€’ 🧠 OpenAI CLIP β€’ πŸ“Š scikit-learn
"""
    )
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the standard Hugging Face Spaces setup.
    demo.launch(server_name="0.0.0.0", server_port=7860)