| import streamlit as st |
| import numpy as np |
| from PIL import Image |
| import json |
| import tensorflow as tf |
| import keras |
| import cv2 |
| import os |
| from transformers import AutoModelForCausalLM |
| from peft import PeftModel |
|
|
| |
| |
| |
st.set_page_config(page_title="Tomato AI 🍅", layout="wide")

# Page-wide CSS classes referenced by the HTML snippets rendered in this app
# (big-title for the header, card/result/sub for the prediction panel).
_PAGE_CSS = """
<style>
.big-title { font-size:42px; font-weight:bold; color:#2E8B57; }
.card { padding:20px; border-radius:15px; background-color:#f5f5f5; margin-bottom:15px; }
.result { font-size:26px; font-weight:bold; color:#ff4b4b; }
.sub { font-size:18px; color:#555; }
</style>
"""
_TITLE_HTML = '<div class="big-title">🍅 AI Tomato Disease Detection System</div>'

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
|
|
| |
| |
| |
# Filesystem locations of the model artefacts. The defaults are the paths
# baked into the deployment image; each can be overridden through an
# environment variable so the app can also run outside that container.
GEMMA_PATH = os.environ.get("TOMATO_GEMMA_PATH", "/app/gemma-tomato-lora")  # LoRA adapter + tokenizer.model
MODEL_PATH = os.environ.get("TOMATO_MODEL_PATH", "/app/src/best_model.h5")  # Keras CNN classifier
JSON_PATH = os.environ.get("TOMATO_CLASS_INDICES_PATH", "/app/src/class_indices.json")  # class name -> output index
|
|
| |
| |
| |
# Hugging Face access token; it is forwarded to load_llm, which passes it to
# from_pretrained when fetching the Gemma base model. Without a non-empty
# token the app shows an error and halts the current script run.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    st.error("HF_TOKEN tapılmadı!")  # Azerbaijani: "HF_TOKEN not found!"
    st.stop()
|
|
| |
| |
| |
@st.cache_resource
def load_cnn():
    """Return the Keras disease classifier stored at MODEL_PATH.

    Wrapped in st.cache_resource so Streamlit reruns reuse the already loaded
    model instead of reading the file again. ``compile=False`` is used because
    the app only runs inference, so no optimizer/loss state is restored.
    """
    return keras.models.load_model(MODEL_PATH, compile=False)


cnn_model = load_cnn()
|
|
| |
| |
| |
# class_indices.json maps each class label to the model output index that
# represents it (keys become labels, values become positions below).
with open(JSON_PATH, encoding="utf-8") as f:
    class_indices = json.load(f)

# Invert the mapping so class_names[i] is the label of the model's i-th output.
# The indices must form a permutation of 0..N-1. Previously, gaps or duplicates
# silently left None labels, a negative index silently wrote the wrong slot,
# and string indices raised an opaque TypeError; fail with a clear message
# instead.
class_names = [None] * len(class_indices)
for name, index in class_indices.items():
    position = int(index)  # tolerate indices stored as JSON strings, e.g. "3"
    if not 0 <= position < len(class_names) or class_names[position] is not None:
        raise ValueError(
            f"{JSON_PATH}: invalid or duplicate class index {index!r} for {name!r}"
        )
    class_names[position] = name
|
|
| |
| |
| |
IMG_SIZE = 224  # side length, in pixels, of the square image fed to the CNN


def preprocess(img):
    """Convert a PIL image into a one-image float32 batch for the CNN.

    The image is forced to 3-channel RGB and resized to IMG_SIZE x IMG_SIZE.
    Pixel values are cast to float32 but not rescaled (presumably the model
    normalizes internally — TODO confirm against training code).

    Returns:
        numpy array of shape (1, IMG_SIZE, IMG_SIZE, 3), dtype float32.
    """
    square_rgb = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.asarray(square_rgb, dtype=np.float32)
    return pixels[np.newaxis, ...]
|
|
| |
| |
| |
@st.cache_resource
def load_llm(_token):
    """Build the Gemma tokenizer and the LoRA-adapted Gemma model.

    Args:
        _token: Hugging Face access token handed to ``from_pretrained`` for the
            ``google/gemma-2b-it`` base weights. The leading underscore follows
            Streamlit's convention for keeping an argument out of the cache key.

    Returns:
        ``(tokenizer, model)`` where ``model`` is the base causal LM wrapped by
        ``PeftModel`` with the adapter stored in ``GEMMA_PATH``.
    """
    from transformers import GemmaTokenizer

    # Tokenizer comes from the tokenizer.model file in the adapter directory
    # rather than from the Hub; a BOS token is prepended, no EOS appended.
    tokenizer_kwargs = dict(
        vocab_file=f"{GEMMA_PATH}/tokenizer.model",
        add_bos_token=True,
        add_eos_token=False,
    )
    gemma_tokenizer = GemmaTokenizer(**tokenizer_kwargs)

    base_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b-it",
        token=_token,
        torch_dtype="auto",
        device_map="auto",
    )
    adapted_model = PeftModel.from_pretrained(base_model, GEMMA_PATH)
    return gemma_tokenizer, adapted_model


tokenizer, gemma = load_llm(hf_token)
|
|
def explain(label, max_new_tokens=200):
    """Ask the LoRA-tuned Gemma model to describe a predicted tomato disease.

    Args:
        label: class name predicted by the CNN; inserted verbatim into the prompt.
        max_new_tokens: generation budget for the answer (default 200, as before).

    Returns:
        Only the newly generated text (the prompt is not included), with
        surrounding whitespace stripped.
    """
    prompt = f"""You are a plant disease expert. Explain the following tomato disease clearly and concisely in English.
Disease: {label}
Provide:
1. Cause
2. Symptoms
3. Prevention
4. Treatment advice for farmers
Answer:"""
    inputs = tokenizer(prompt, return_tensors="pt").to(gemma.device)
    out = gemma.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.3,
    )
    # generate() returns prompt ids followed by the new ids. Slice at the token
    # level: re-decoding the whole sequence and cutting len(prompt) characters
    # is fragile because decoding need not reproduce the prompt text exactly.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
|
|
| |
| |
| |
def _locate_feature_extractor(model):
    """Find the nested sub-model to run Grad-CAM on, plus the layers around it.

    Mirrors the original search: take the first keras.Model among
    ``model.layers`` and, inside it, the first nested keras.Model (two levels
    at most). ``pre_layers`` / ``post_layers`` are the sibling layers before /
    after each chosen sub-model, in call order, so the full prediction can be
    recomputed around the extractor. Assumes the wrapping models are linear
    stacks (typical transfer-learning heads) — TODO confirm for this model.

    Returns:
        (extractor, pre_layers, post_layers); ``extractor`` is ``model`` itself
        and both lists are empty when nothing is nested.
    """
    extractor = model
    pre_layers, post_layers = [], []
    container = model
    for _depth in range(2):
        layers = [l for l in container.layers if not isinstance(l, keras.layers.InputLayer)]
        idx = next((i for i, l in enumerate(layers) if isinstance(l, keras.Model)), None)
        if idx is None:
            break
        pre_layers = pre_layers + layers[:idx]
        post_layers = layers[idx + 1:] + post_layers
        extractor = container = layers[idx]
    return extractor, pre_layers, post_layers


def _output_rank(layer):
    """Best-effort rank of a layer's symbolic output; None if Keras can't tell."""
    try:
        return len(layer.output.shape)
    except (AttributeError, ValueError, RuntimeError, TypeError):
        return None


def _find_last_conv_name(search_model):
    """Return the name of the last conv-like layer with spatial (4-D) output.

    "Conv-like" keeps the original heuristic (class name contains "Conv2D" or
    layer name contains "conv"); layers whose output is known not to be 4-D are
    skipped because Grad-CAM pools over the batch/height/width axes.
    """
    for layer in reversed(search_model.layers):
        looks_conv = "Conv2D" in type(layer).__name__ or "conv" in layer.name.lower()
        if looks_conv and _output_rank(layer) in (4, None):
            return layer.name
    return None


def make_gradcam(img_array, model, class_index=None):
    """Compute a Grad-CAM heatmap for a (possibly nested) Keras classifier.

    The gradient target is the class score at the classifier's real output.
    If the conv layer lives in a nested base model, the layers wrapped around
    that base are re-applied, instead of targeting the base model's own
    feature output.

    Args:
        img_array: preprocessed batch; only the first image is visualized.
        model: the Keras classifier.
        class_index: class to explain; defaults to the highest-scoring one.

    Returns:
        (heatmap, conv_layer_name): ``heatmap`` is a 2-D float array over the
        conv layer's spatial grid, ReLU'd and scaled so its max is 1 (left at
        zero when nothing is positive).

    Raises:
        ValueError: if no suitable conv layer exists, or it is not connected to
            the class score.
    """
    search_model, pre_layers, post_layers = _locate_feature_extractor(model)

    last_conv_name = _find_last_conv_name(search_model)
    if last_conv_name is None:
        layer_names = [f"{l.name} ({type(l).__name__})" for l in search_model.layers[-20:]]
        # Message is Azerbaijani for "Conv2D not found", followed by the tail
        # of the layer list to help pick the right layer.
        raise ValueError("Conv2D tapılmadı:\n" + "\n".join(layer_names))

    grad_model = keras.models.Model(
        inputs=search_model.inputs,
        outputs=[search_model.get_layer(last_conv_name).output, search_model.output],
    )

    with tf.GradientTape() as tape:
        x = img_array
        for layer in pre_layers:
            x = layer(x, training=False)  # inference mode, as in predict()
        conv_outputs, x = grad_model(x, training=False)
        for layer in post_layers:
            x = layer(x, training=False)
        # Batch of one assumed: flatten the predictions and take one class score
        # (the maximum by default, matching the previous reduce_max target).
        scores = tf.reshape(x, [-1])
        if class_index is None:
            class_index = int(tf.argmax(scores))
        loss = scores[class_index]

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise ValueError(f"Layer {last_conv_name!r} is not connected to the model output")
    pooled = tf.reduce_mean(grads, axis=(0, 1, 2))  # per-channel importance weights

    # Weighted sum over channels; [..., 0] (not squeeze) keeps the map 2-D
    # even when the conv grid is 1x1.
    heatmap = (conv_outputs[0] @ pooled[..., tf.newaxis])[..., 0]
    heatmap = np.maximum(heatmap.numpy(), 0)

    max_val = np.max(heatmap)
    if max_val > 0:
        heatmap /= max_val

    return heatmap, last_conv_name
|
|
| |
| |
| |
uploaded_file = st.file_uploader("Upload tomato leaf image 🍅", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # Decode eagerly: Image.open is lazy, so a corrupt or non-image upload
    # would otherwise fail later as an unhandled exception mid-page.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except (OSError, ValueError) as e:  # PIL.UnidentifiedImageError is an OSError
        st.error(f"Could not read the uploaded image: {e}")
        st.stop()

    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption="Input Image", use_column_width=True)

    img = preprocess(image)
    preds = cnn_model.predict(img, verbose=0)
    idx = np.argmax(preds[0])
    label = class_names[idx]
    conf = float(preds[0][idx])

    with col2:
        # Every st.markdown call renders as a separate element, so the card's
        # opening and closing <div> must be emitted in the same call for the
        # .card styling to actually wrap the prediction text.
        card_html = (
            '<div class="card">'
            f"<div class='result'>🔍 Prediction: {label}</div>"
            f"<div class='sub'>Confidence: {conf:.2%}</div>"
            "</div>"
        )
        st.markdown(card_html, unsafe_allow_html=True)
        # st.progress rejects values outside [0.0, 1.0]; clamp float round-off
        # (e.g. a softmax summing to 1.0000001) instead of crashing the page.
        st.progress(min(max(conf, 0.0), 1.0))

        st.markdown("#### Top 3 Predictions")
        top3_idx = np.argsort(preds[0])[::-1][:3]
        for i in top3_idx:
            st.markdown(f"- **{class_names[i]}**: {preds[0][i]:.2%}")

    st.markdown("### 🧠 AI Explanation (Gemma)")
    with st.spinner("Generating explanation..."):
        try:
            explanation = explain(label)
            st.info(explanation)
        except Exception as e:  # UI boundary: degrade to a warning
            st.warning(f"Gemma explanation unavailable: {e}")

    st.markdown("### 🔥 Model Attention (GradCAM)")
    with st.spinner("Generating GradCAM..."):
        try:
            heatmap, conv_name = make_gradcam(img, cnn_model)
            st.caption(f"Layer: {conv_name}")

            heatmap_resized = cv2.resize(heatmap, (IMG_SIZE, IMG_SIZE))
            heatmap_uint8 = np.uint8(255 * heatmap_resized)
            # OpenCV colour maps are BGR, while the photo below and st.image
            # (default channels="RGB") are RGB; convert so hot areas show red.
            heatmap_colored = cv2.cvtColor(
                cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET),
                cv2.COLOR_BGR2RGB,
            )

            # Force 3-channel RGB so RGBA/greyscale uploads match the heatmap's
            # shape in addWeighted.
            img_np = np.array(image.convert("RGB").resize((IMG_SIZE, IMG_SIZE)))
            superimposed = cv2.addWeighted(img_np, 0.6, heatmap_colored, 0.4, 0)

            st.image(superimposed, caption=f"GradCAM ({conv_name})", use_column_width=True)
        except Exception as e:  # UI boundary: degrade to a warning
            st.warning(f"GradCAM unavailable: {e}")

else:
    st.info("📤 Upload a tomato leaf image to start prediction")