# tomat / src/streamlit_app.py
# Author: Reyal
# Last commit: "Update src/streamlit_app.py" (276211d, verified)
import streamlit as st
import numpy as np
from PIL import Image
import json
import tensorflow as tf
import keras
import cv2
import os
from transformers import AutoModelForCausalLM
from peft import PeftModel
# ======================
# CONFIG — page setup and global styling
# ======================
# Page configuration is issued before any other Streamlit element is rendered.
st.set_page_config(page_title="Tomato AI 🍅", layout="wide")

# CSS classes used by the HTML snippets rendered further down the page.
_CUSTOM_CSS = """
<style>
.big-title { font-size:42px; font-weight:bold; color:#2E8B57; }
.card { padding:20px; border-radius:15px; background-color:#f5f5f5; margin-bottom:15px; }
.result { font-size:26px; font-weight:bold; color:#ff4b4b; }
.sub { font-size:18px; color:#555; }
</style>
"""
_TITLE_HTML = '<div class="big-title">🍅 AI Tomato Disease Detection System</div>'

st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
# ======================
# PATHS
# ======================
# Defaults are the /app deployment paths. Each one can be overridden with an
# environment variable of the same name, e.g. to run the app from another
# checkout without editing the source.
GEMMA_PATH = os.environ.get("GEMMA_PATH", "/app/gemma-tomato-lora")
MODEL_PATH = os.environ.get("MODEL_PATH", "/app/src/best_model.h5")
JSON_PATH = os.environ.get("JSON_PATH", "/app/src/class_indices.json")
# ======================
# TOKEN — read once from the environment, never deleted
# ======================
# The Hugging Face token is later passed to the Gemma model loader. A missing
# or empty value shows an error and halts this script run.
hf_token = os.getenv("HF_TOKEN")
token_missing = not hf_token
if token_missing:
    st.error("HF_TOKEN tapılmadı!")
    st.stop()
# ======================
# LOAD CNN
# ======================
@st.cache_resource
def load_cnn():
    """Load the Keras disease classifier from MODEL_PATH.

    ``st.cache_resource`` keeps the returned model, so Streamlit reruns reuse it
    instead of reloading the file. ``compile=False`` skips restoring any
    training configuration; the app only runs inference.
    """
    return keras.models.load_model(MODEL_PATH, compile=False)


cnn_model = load_cnn()
# ======================
# LOAD CLASSES
# ======================
# class_indices.json maps class name -> CNN output index. Invert it so that
# class_names[i] is the label for output unit i (used with argmax below).
with open(JSON_PATH, encoding="utf-8") as f:
    class_indices = json.load(f)

class_names = [None] * len(class_indices)
for name, index in class_indices.items():
    # Accept indices stored as numeric strings as well as ints.
    index = int(index)
    # Out-of-range or duplicate indices would otherwise crash obscurely or
    # leave a None label that gets displayed and sent to the LLM.
    if not 0 <= index < len(class_names) or class_names[index] is not None:
        st.error(f"Invalid class index {index!r} for '{name}' in {JSON_PATH}")
        st.stop()
    class_names[index] = name
# ======================
# PREPROCESS
# ======================
IMG_SIZE = 224


def preprocess(img):
    """Turn a PIL image into a one-image float32 batch for the CNN.

    The image is converted to RGB and resized to IMG_SIZE x IMG_SIZE. Pixel
    values stay in the raw 0-255 range, because no normalisation happens here.
    Returns an array of shape (1, IMG_SIZE, IMG_SIZE, 3).
    """
    square = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(square, dtype=np.float32)
    return pixels[np.newaxis, ...]
# ======================
# GEMMA EXPLANATION
# ======================
@st.cache_resource
def load_llm(_token):
    """Build the Gemma tokenizer and the LoRA-adapted ``google/gemma-2b-it`` model.

    The tokenizer vocab file (``tokenizer.model``) and the PEFT adapter are both
    read from GEMMA_PATH. ``st.cache_resource`` does not hash parameters whose
    names start with an underscore, so the token never becomes part of the cache
    key. Returns ``(tokenizer, model)``.
    """
    from transformers import GemmaTokenizer

    tokenizer_options = {"add_bos_token": True, "add_eos_token": False}
    llm_tokenizer = GemmaTokenizer(
        vocab_file=f"{GEMMA_PATH}/tokenizer.model",
        **tokenizer_options,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b-it",
        token=_token,
        device_map="auto",
        torch_dtype="auto",
    )
    adapted_model = PeftModel.from_pretrained(base_model, GEMMA_PATH)
    return llm_tokenizer, adapted_model


tokenizer, gemma = load_llm(hf_token)
def explain(label, max_new_tokens=200):
    """Ask the module-level ``gemma`` model for an explanation of *label*.

    Args:
        label: Disease class name predicted by the CNN.
        max_new_tokens: Upper bound on generated tokens (default 200).

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.
        Sampling is enabled (``do_sample=True``), so output varies between calls.
    """
    prompt = f"""You are a plant disease expert. Explain the following tomato disease clearly and concisely in English.
Disease: {label}
Provide:
1. Cause
2. Symptoms
3. Prevention
4. Treatment advice for farmers
Answer:"""
    inputs = tokenizer(prompt, return_tensors="pt").to(gemma.device)
    out = gemma.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.3,
    )
    # Decode only the tokens generated after the prompt. Slicing the decoded
    # string by len(prompt) is fragile, because detokenization need not
    # reproduce the prompt character-for-character, which cuts or leaks text.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
# ======================
# GRADCAM
# ======================
def _backbone_path(model):
    """Return the chain ``[model, nested, inner]`` down to the conv-holding sub-model.

    Searches like the original code: take the first nested ``keras.Model`` in
    ``model.layers`` and, inside that, the first nested ``keras.Model`` again
    (at most two levels). A flat model yields ``[model]``.
    """
    path = [model]
    for _ in range(2):
        nested = next(
            (layer for layer in path[-1].layers if isinstance(layer, keras.Model)),
            None,
        )
        if nested is None:
            break
        path.append(nested)
    return path


def _last_conv_layer_name(search_model):
    """Return the name of the last conv-like layer in *search_model*.

    A layer counts as conv-like if its type name contains "Conv2D" or its name
    contains "conv" (case-insensitive).

    Raises:
        ValueError: if no such layer exists; the message lists the last 20 layers.
    """
    for layer in reversed(search_model.layers):
        if "Conv2D" in type(layer).__name__ or "conv" in layer.name.lower():
            return layer.name
    layer_names = [f"{l.name} ({type(l).__name__})" for l in search_model.layers[-20:]]
    raise ValueError("Conv2D tapılmadı:\n" + "\n".join(layer_names))


def _sibling_layers(parent, child):
    """Split *parent*'s layers (InputLayers excluded) into those before/after *child*."""
    layers = [layer for layer in parent.layers if not isinstance(layer, keras.layers.InputLayer)]
    pos = next(i for i, layer in enumerate(layers) if layer is child)
    return layers[:pos], layers[pos + 1:]


def _forward(path, feature_model, x):
    """Run *x* through the whole classifier; return (conv activations, predictions).

    Layers that wrap the backbone at each nesting level are replayed in
    ``parent.layers`` order: those before it on the way in, those after it on
    the way out. NOTE(review): this assumes each wrapper is a simple linear
    stack (Sequential-style); branching wrappers will raise here.
    """
    levels = list(zip(path, path[1:]))
    for parent, child in levels:
        before, _ = _sibling_layers(parent, child)
        for layer in before:
            x = layer(x, training=False)
    # training=False matches model.predict() and keeps random augmentation
    # layers inactive.
    conv_outputs, x = feature_model(x, training=False)
    for parent, child in reversed(levels):
        _, after = _sibling_layers(parent, child)
        for layer in after:
            x = layer(x, training=False)
    return conv_outputs, x


def make_gradcam(img_array, model, pred_index=None):
    """Compute a Grad-CAM heatmap explaining *model*'s prediction on *img_array*.

    Args:
        img_array: Input batch as built by ``preprocess``; assumed
            (1, H, W, 3) float32 (TODO confirm for other callers).
        model: Keras classifier, either flat or wrapping a nested backbone model.
        pred_index: Class to explain; defaults to the top-scoring class.

    Returns:
        ``(heatmap, last_conv_name)``. The heatmap is a 2-D array over the conv
        layer's spatial grid, clipped at 0 and scaled so its max is 1 (left
        unscaled if all zeros). ``last_conv_name`` names the conv layer used.

    Raises:
        ValueError: if no conv layer is found or no gradient reaches it.
    """
    path = _backbone_path(model)
    search_model = path[-1]
    last_conv_name = _last_conv_layer_name(search_model)
    conv_output = search_model.get_layer(last_conv_name).output

    try:
        # Expose the conv activations next to the backbone's own output. The
        # wrapper layers are replayed in _forward, so the gradient target is the
        # classifier's class score rather than the backbone's feature map.
        feature_model = keras.models.Model(
            inputs=search_model.inputs,
            outputs=[conv_output, search_model.output],
        )
    except Exception:
        # Fallback kept from the original: wire the conv activations straight
        # into the full model's graph, with no layer replay.
        feature_model = keras.models.Model(
            inputs=model.inputs,
            outputs=[conv_output, model.output],
        )
        path = [model]

    inputs = tf.convert_to_tensor(img_array, dtype=tf.float32)
    with tf.GradientTape() as tape:
        # Watch the input so that the conv activations are recorded even when
        # the backbone's weights are frozen (not auto-watched by the tape).
        tape.watch(inputs)
        conv_outputs, preds = _forward(path, feature_model, inputs)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_score = preds[:, pred_index]
    grads = tape.gradient(class_score, conv_outputs)
    if grads is None:
        raise ValueError(f"No gradient reaches layer '{last_conv_name}'")

    # Channel weights are the spatially averaged gradients; the heatmap is the
    # weighted sum of the activation channels. [..., 0] keeps it 2-D even for a
    # 1x1 spatial map, where tf.squeeze would collapse it to a scalar.
    pooled = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = (conv_outputs[0] @ pooled[..., tf.newaxis])[..., 0]
    heatmap = np.maximum(heatmap.numpy(), 0)
    max_val = np.max(heatmap)
    if max_val > 0:
        heatmap /= max_val
    return heatmap, last_conv_name
# ======================
# UI
# ======================
def _gradcam_overlay(pil_image, heatmap):
    """Blend a [0, 1] Grad-CAM heatmap over *pil_image*; return an RGB uint8 array."""
    heatmap_resized = cv2.resize(np.asarray(heatmap, dtype=np.float32), (IMG_SIZE, IMG_SIZE))
    heatmap_uint8 = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
    # applyColorMap produces BGR, while the PIL image and st.image use RGB;
    # without this conversion the hot (red) regions would render blue.
    heatmap_colored = cv2.cvtColor(
        cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET),
        cv2.COLOR_BGR2RGB,
    )
    # Force 3 channels so RGBA PNGs / grayscale images match the heatmap shape.
    img_np = np.array(pil_image.convert("RGB").resize((IMG_SIZE, IMG_SIZE)))
    return cv2.addWeighted(img_np, 0.6, heatmap_colored, 0.4, 0)


uploaded_file = st.file_uploader("Upload tomato leaf image 🍅", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # Image.open is lazy; convert() forces a full decode, so unreadable or
        # truncated uploads fail here (as OSError) instead of deeper in the app.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        st.error(f"Could not read the uploaded image: {e}")
        st.stop()

    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Input Image", use_column_width=True)

    img = preprocess(image)
    preds = cnn_model.predict(img, verbose=0)
    idx = int(np.argmax(preds[0]))
    label = class_names[idx]
    conf = float(preds[0][idx])

    with col2:
        # Emit the whole card in one markdown call. Each st.markdown call is a
        # separate element, so a lone opening <div> cannot wrap later elements.
        st.markdown(
            '<div class="card">'
            f"<div class='result'>🔍 Prediction: {label}</div>"
            f"<div class='sub'>Confidence: {conf:.2%}</div>"
            "</div>",
            unsafe_allow_html=True,
        )
        # st.progress only accepts values in [0.0, 1.0].
        st.progress(min(max(conf, 0.0), 1.0))
        st.markdown("#### Top 3 Predictions")
        for i in np.argsort(preds[0])[::-1][:3]:
            st.markdown(f"- **{class_names[i]}**: {preds[0][i]:.2%}")

    st.markdown("### 🧠 AI Explanation (Gemma)")
    with st.spinner("Generating explanation..."):
        try:
            explanation = explain(label)
            st.info(explanation)
        except Exception as e:
            st.warning(f"Gemma explanation unavailable: {e}")

    st.markdown("### 🔥 Model Attention (GradCAM)")
    with st.spinner("Generating GradCAM..."):
        try:
            heatmap, conv_name = make_gradcam(img, cnn_model)
            st.caption(f"Layer: {conv_name}")
            superimposed = _gradcam_overlay(image, heatmap)
            st.image(superimposed, caption=f"GradCAM ({conv_name})", use_column_width=True)
        except Exception as e:
            st.warning(f"GradCAM unavailable: {e}")
else:
    st.info("📤 Upload a tomato leaf image to start prediction")