Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from io import BytesIO | |
| from ultralytics import YOLO | |
| #import ollama | |
| from datetime import datetime | |
| import tempfile | |
| import os | |
| import base64 | |
| import bcrypt | |
| import sqlite3 | |
| import time | |
| from kokoro import KPipeline | |
| import soundfile as sf | |
| from IPython.display import Audio | |
| import torch | |
| from googletrans import Translator | |
| from sentence_transformers import SentenceTransformer | |
| #from ragas.metrics import AnswerRelevancy, Faithfulness, AnswerCorrectness, ContextPrecision | |
| #from ragas import evaluate | |
| #from ragas.metrics import ( | |
| # answer_relevancy, | |
| # faithfulness, | |
| # answer_correctness, | |
| # context_precision | |
| #) | |
| from datasets import Dataset | |
| import pandas as pd | |
| import random | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from typing import List, Dict | |
| import asyncio | |
| from ragas.embeddings.base import BaseRagasEmbeddings | |
| from dotenv import load_dotenv | |
| from groq import Groq | |
# Read settings from a local .env file into the process environment.
load_dotenv()

# Module-wide Groq API client. The key is mandatory: a missing GROQ_API_KEY
# raises KeyError here, at import time.
groq_client = Groq(
    api_key=os.environ["GROQ_API_KEY"],
)
class SentenceTransformerEmbeddings(BaseRagasEmbeddings):
    """Adapt a SentenceTransformer model to the BaseRagasEmbeddings interface.

    Implements the synchronous ``embed_query`` / ``embed_documents`` methods
    and their asynchronous ``aembed_*`` counterparts. The asynchronous
    versions hand the synchronous method to the running loop's default
    executor and await the result.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load the SentenceTransformer model named ``model_name``."""
        # Let the base class set up whatever state it maintains.
        super().__init__()
        self.model = SentenceTransformer(model_name)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text (synchronous) and return it as a list of floats."""
        return self.model.encode(text).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts (synchronous), one vector per text, in input order."""
        if not texts:
            return []
        # One batched encode() call instead of one model invocation per text.
        return self.model.encode(list(texts)).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single text without blocking the event loop."""
        # Inside a coroutine there is always a running loop; get_running_loop()
        # returns it without the deprecated implicit-loop behaviour of
        # get_event_loop().
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_query, text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts without blocking the event loop."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_documents, texts)
class RAGSystemVariants:
    """RAG pipeline variants compared against each other in the ablation study.

    Every public method is a coroutine that takes a question and returns a
    ``(response, context)`` tuple, where ``context`` is the newline-joined
    text handed to the LLM (an empty string when no context was used).
    """

    def __init__(self):
        # None of the methods in this class read this attribute; it is kept
        # so that any external user of ``embedding_model`` keeps working.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

    @staticmethod
    def _build_prompt(context, query):
        """Format the prompt shared by every context-based variant."""
        return f"Context: {context}\n\nQuestion: {query}"

    async def _retrieve_and_answer(self, query, top_k):
        """Retrieve ``top_k`` chunks for ``query`` and generate an answer from them."""
        chunks = retrieve_relevant_documents(query, [], top_k)
        context = "\n".join(chunk["text"] for chunk in chunks)
        response = await generate_groq_response(self._build_prompt(context, query))
        return response, context

    async def baseline_rag(self, query, top_k=3):
        """Your current full RAG system"""
        return await self._retrieve_and_answer(query, top_k)

    async def no_retrieval(self, query):
        """Generation only - no retrieval"""
        response = await generate_groq_response(query)
        return response, ""

    async def random_retrieval(self, query, top_k=3):
        """Random document selection instead of semantic retrieval.

        Samples up to ``top_k`` documents from the first 100 scrolled out of
        the collection. Any failure is reported through ``st.error`` and
        turned into an error-message response with empty context.
        """
        try:
            all_docs = client.scroll(collection_name=collection_name, limit=100)[0]
            if all_docs:
                random_chunks = random.sample(all_docs, min(top_k, len(all_docs)))
                context = "\n".join(chunk.payload["text"] for chunk in random_chunks)
            else:
                context = ""
            response = await generate_groq_response(self._build_prompt(context, query))
            return response, context
        except Exception as e:
            st.error(f"Error in random retrieval: {e}")
            return "Error in random retrieval", ""

    async def different_top_k(self, query, top_k):
        """Test different top-k values"""
        return await self._retrieve_and_answer(query, top_k)
def create_test_dataset(limit=20):
    """Create a test dataset for RAGAS evaluation.

    Reads up to ``limit`` rows from the ``diseases`` table of
    ``./db/disease_knowledge_base.db`` and derives four question/answer pairs
    per disease (cause, symptoms, treatment, overview).

    Args:
        limit: Maximum number of disease rows to read and maximum number of
            test cases to return.

    Returns:
        A list of at most ``limit`` dicts with keys ``question``,
        ``ground_truth`` and ``disease``. On any error the problem is
        reported through ``st.error`` and an empty list is returned.
    """
    test_cases = []
    try:
        conn = sqlite3.connect('./db/disease_knowledge_base.db')
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name, cause, symptoms, treatment FROM diseases LIMIT ?", (limit,))
            diseases = cursor.fetchall()
        finally:
            # Close the connection even when the query fails (e.g. missing table).
            conn.close()

        for disease_name, cause, symptoms, treatment in diseases:
            questions_and_answers = [
                (f"What causes {disease_name}?", cause),
                (f"What are the symptoms of {disease_name}?", symptoms),
                (f"How do I treat {disease_name}?", treatment),
                (f"Tell me about {disease_name}", f"Cause: {cause}\nSymptoms: {symptoms}\nTreatment: {treatment}"),
            ]
            test_cases.extend(
                {
                    "question": question,
                    "ground_truth": ground_truth,
                    "disease": disease_name,
                }
                for question, ground_truth in questions_and_answers
            )
        return test_cases[:limit]
    except Exception as e:
        st.error(f"Error creating test dataset: {e}")
        return []
async def run_ablation_study(progress_bar, status_text, max_questions=20):
    """Run comprehensive ablation study with progress tracking.

    Runs every RAG variant on every test case and scores each answer with
    ``LocalMetricsEvaluator``.

    Args:
        progress_bar: Object whose ``progress(fraction)`` method is called
            after every attempted test case.
        status_text: Object whose ``text(message)`` method receives status
            messages.
        max_questions: Passed to ``create_test_dataset`` as ``limit``.

    Returns:
        A ``pandas.DataFrame`` with one row per successfully answered
        (experiment, question) pair, or ``None`` when no test cases could be
        created.
    """
    status_text.text("Creating test dataset...")
    test_cases = create_test_dataset(limit=max_questions)
    if not test_cases:
        st.error("No test cases created. Check your database connection.")
        return None

    rag_variants = RAGSystemVariants()
    experiments = {
        "Full_RAG_k3": lambda q: rag_variants.baseline_rag(q, top_k=3),
        "No_Retrieval": lambda q: rag_variants.no_retrieval(q),
        "Random_Retrieval": lambda q: rag_variants.random_retrieval(q, top_k=3),
        "RAG_k1": lambda q: rag_variants.different_top_k(q, top_k=1),
        "RAG_k5": lambda q: rag_variants.different_top_k(q, top_k=5),
        "RAG_k10": lambda q: rag_variants.different_top_k(q, top_k=10),
    }

    all_results = []
    total_experiments = len(experiments) * len(test_cases)
    current_progress = 0

    for exp_name, exp_func in experiments.items():
        status_text.text(f"Running experiment: {exp_name}")
        evaluator = LocalMetricsEvaluator()
        for test_case in test_cases:
            question = test_case["question"]
            ground_truth = test_case["ground_truth"]
            try:
                answer, context = await exp_func(question)
            except Exception as e:
                # A failed case is reported and skipped; it yields no result row.
                st.error(f"Error in {exp_name}: {e}")
                continue
            finally:
                # Count failed cases too, so the bar reaches 100% even when
                # some generations raise.
                current_progress += 1
                progress_bar.progress(current_progress / total_experiments)

            # Variants without context return ""; normalise any falsy value.
            context_str = context or ""
            all_results.append({
                "question": question,
                "answer": answer,
                "context": context_str,
                "ground_truth": ground_truth,
                "experiment": exp_name,
                "answer_relevancy": evaluator.evaluate_answer_relevancy(question, answer),
                # With no context the context-dependent metrics are not
                # computed and fall back to fixed defaults.
                "faithfulness": evaluator.evaluate_faithfulness(answer, context_str) if context_str else 1.0,
                "answer_correctness": evaluator.evaluate_answer_correctness(answer, ground_truth),
                "context_precision": evaluator.evaluate_context_precision(question, context_str) if context_str else 0.0,
                "context_recall": evaluator.evaluate_context_recall(question, context_str, ground_truth) if context_str else 0.0,
            })

    return pd.DataFrame(all_results)
def visualize_ablation_results(results_df):
    """Build the summary table and Plotly figures for an ablation run.

    Args:
        results_df: DataFrame with an ``experiment`` column and the five
            metric columns produced by the ablation study.

    Returns:
        ``(fig_radar, fig_bar, summary_stats)``: a radar chart of per-experiment
        metric means, a 2x3 grid of bar charts (mean with std error bars), and
        the per-experiment summary table with ``<metric>_mean`` /
        ``<metric>_std`` columns.
    """
    metric_names = [
        'answer_relevancy',
        'faithfulness',
        'answer_correctness',
        'context_precision',
        'context_recall',
    ]

    # Mean and standard deviation of every metric, per experiment.
    aggregations = {name: ['mean', 'std'] for name in metric_names}
    summary_stats = results_df.groupby('experiment').agg(aggregations).round(3)
    # Flatten the (metric, statistic) column pairs into "metric_statistic".
    summary_stats.columns = ['_'.join(pair).strip() for pair in summary_stats.columns.values]
    summary_stats = summary_stats.reset_index()

    mean_columns = [f"{name}_mean" for name in metric_names]
    axis_labels = [name.replace('_', ' ').title() for name in metric_names]

    # Radar chart: one closed trace per experiment.
    fig_radar = go.Figure()
    for _, stats in summary_stats.iterrows():
        trace = go.Scatterpolar(
            r=[stats[column] for column in mean_columns],
            theta=list(axis_labels),
            fill='toself',
            name=stats['experiment'],
        )
        fig_radar.add_trace(trace)
    fig_radar.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
        showlegend=True,
        title="RAGAS Metrics Comparison Across Experiments",
    )

    # Bar charts: one subplot per metric, filled row by row.
    fig_bar = make_subplots(rows=2, cols=3, subplot_titles=list(axis_labels))
    for position, (name, label) in enumerate(zip(metric_names, axis_labels)):
        grid_row, grid_col = divmod(position, 3)
        bars = go.Bar(
            x=summary_stats['experiment'],
            y=summary_stats[f"{name}_mean"],
            error_y=dict(type='data', array=summary_stats[f"{name}_std"]),
            name=label,
            showlegend=False,
        )
        fig_bar.add_trace(bars, row=grid_row + 1, col=grid_col + 1)
    fig_bar.update_layout(height=800, title="Detailed Metrics Comparison")

    return fig_radar, fig_bar, summary_stats
def render_research_page():
    """Render the research/ablation study page.

    Lays out a Streamlit page with four tabs:

    * Ablation Study: runs ``run_ablation_study`` on demand, stores the
      resulting DataFrame in ``st.session_state['ablation_results']`` and
      shows summary tables, charts and a CSV download.
    * Model Comparison: asks each selected model the same test questions and
      compares answer relevancy / correctness; results are stored in
      ``st.session_state['model_comparison_results']``.
    * Error Analysis: shows the lowest-scoring rows and a box plot from the
      stored ablation results.
    * Export Results: offers a Markdown report and a CSV of the stored
      ablation results.
    """
    st.title("🔬 RAG System Research Dashboard")
    st.markdown("Systematic evaluation and ablation study of the crop disease detection RAG system")
    # Initialize session state for results
    if 'ablation_results' not in st.session_state:
        st.session_state['ablation_results'] = None
    tabs = st.tabs(["Ablation Study", "Model Comparison", "Error Analysis", "Export Results"])
    with tabs[0]:
        st.header("🧪 Ablation Study")
        st.write("This systematically evaluates different components of the RAG system.")
        col1, col2 = st.columns(2)
        with col1:
            max_questions = st.number_input("Number of test questions per experiment",
                                            min_value=5, max_value=50, value=20)
        with col2:
            # NOTE(review): the selected model is never read below, and
            # run_ablation_study() takes no model argument - confirm whether
            # this selector is meant to have an effect.
            selected_model_research = st.selectbox(
                "Select Model for Experiments",
                list(SUPPORTED_MODELS.keys()),
                key="research_model_select"
            )
        if st.button("🚀 Start Ablation Study", type="primary"):
            progress_bar = st.progress(0)
            status_text = st.empty()
            with st.spinner("Running ablation study..."):
                try:
                    # run_ablation_study is a coroutine; drive it to completion here.
                    results_df = asyncio.run(run_ablation_study(progress_bar, status_text, max_questions))
                    # None means no test cases could be created.
                    if results_df is not None:
                        # Keep the results so the other tabs can use them on later reruns.
                        st.session_state['ablation_results'] = results_df
                        st.success("✅ Ablation study completed!")
                        # Show summary statistics
                        st.subheader("📊 Summary Statistics")
                        summary_stats = results_df.groupby('experiment').agg({
                            'answer_relevancy': 'mean',
                            'faithfulness': 'mean',
                            'answer_correctness': 'mean',
                            'context_precision': 'mean',
                            'context_recall': 'mean'
                        }).round(3)
                        st.dataframe(summary_stats, use_container_width=True)
                        # Key insights
                        # "Best overall" = highest unweighted mean of the five metric means.
                        best_overall = summary_stats.mean(axis=1).idxmax()
                        st.success(f"🏆 **Best Overall Configuration:** {best_overall}")
                        col1, col2, col3 = st.columns(3)
                        with col1:
                            best_relevancy = summary_stats['answer_relevancy'].idxmax()
                            st.metric("Best Answer Relevancy", best_relevancy,
                                      f"{summary_stats.loc[best_relevancy, 'answer_relevancy']:.3f}")
                        with col2:
                            best_faithfulness = summary_stats['faithfulness'].idxmax()
                            st.metric("Best Faithfulness", best_faithfulness,
                                      f"{summary_stats.loc[best_faithfulness, 'faithfulness']:.3f}")
                        with col3:
                            best_correctness = summary_stats['answer_correctness'].idxmax()
                            st.metric("Best Correctness", best_correctness,
                                      f"{summary_stats.loc[best_correctness, 'answer_correctness']:.3f}")
                        # Create and display visualizations
                        fig_radar, fig_bar, summary_stats_detailed = visualize_ablation_results(results_df)
                        st.subheader("📈 Results Visualization")
                        viz_tab1, viz_tab2, viz_tab3 = st.tabs(["Radar Chart", "Detailed Comparison", "Raw Data"])
                        with viz_tab1:
                            st.plotly_chart(fig_radar, use_container_width=True)
                            st.markdown("**Interpretation:** The radar chart shows the relative performance of each experiment across all RAGAS metrics. Larger areas indicate better overall performance.")
                        with viz_tab2:
                            st.plotly_chart(fig_bar, use_container_width=True)
                            st.markdown("**Interpretation:** The bar charts show detailed performance with error bars indicating standard deviation across test cases.")
                        with viz_tab3:
                            st.dataframe(results_df, use_container_width=True)
                        # Download options
                        # NOTE(review): nesting reconstructed; the download button is
                        # assumed to sit below the tabs rather than inside
                        # "Raw Data" - confirm against the original file.
                        csv = results_df.to_csv(index=False)
                        st.download_button(
                            label="📥 Download Raw Results (CSV)",
                            data=csv,
                            file_name=f"ablation_study_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                            mime="text/csv"
                        )
                except Exception as e:
                    # Show both a short message and the full traceback in the page.
                    st.error(f"❌ Error running ablation study: {str(e)}")
                    st.exception(e)
    with tabs[1]:
        st.header("Model Comparison")
        st.write("Compare different LLM models on the same test dataset.")
        selected_models = st.multiselect(
            "Select models to compare",
            list(SUPPORTED_MODELS.keys()),
            default=list(SUPPORTED_MODELS.keys())[:2]
        )
        num_questions_comp = st.number_input("Number of questions for comparison",
                                             min_value=5, max_value=30, value=10)
        if selected_models and st.button("🔍 Run Model Comparison"):
            st.info("Model comparison functionality can be extended here...")
            progress_bar_comp = st.progress(0)
            status_text_comp = st.empty()
            with st.spinner("Comparing models..."):
                # Create a simplified comparison focusing on generation quality
                test_cases = create_test_dataset(limit=num_questions_comp)
                comparison_results = []
                total_comparisons = len(selected_models) * len(test_cases)
                current_progress_comp = 0
                for model_name in selected_models:
                    status_text_comp.text(f"Testing model: {model_name}")
                    for test_case in test_cases:
                        try:
                            # Generate response with current model
                            # (no retrieved context is passed: the bare question is the prompt)
                            response = asyncio.run(generate_groq_response(
                                test_case["question"],
                                model_name=SUPPORTED_MODELS[model_name]["name"]
                            ))
                            # Evaluate
                            evaluator = LocalMetricsEvaluator()
                            comparison_results.append({
                                "model": model_name,
                                "question": test_case["question"],
                                "answer": response,
                                "ground_truth": test_case["ground_truth"],
                                "disease": test_case["disease"],
                                "answer_relevancy": evaluator.evaluate_answer_relevancy(test_case["question"], response),
                                "answer_correctness": evaluator.evaluate_answer_correctness(response, test_case["ground_truth"])
                            })
                            # Progress only advances for cases that succeeded.
                            current_progress_comp += 1
                            progress_bar_comp.progress(current_progress_comp / total_comparisons)
                        except Exception as e:
                            st.error(f"Error testing {model_name}: {e}")
                            continue
                if comparison_results:
                    comp_df = pd.DataFrame(comparison_results)
                    # Summary by model
                    model_summary = comp_df.groupby('model').agg({
                        'answer_relevancy': 'mean',
                        'answer_correctness': 'mean'
                    }).round(3)
                    st.subheader("📊 Model Performance Summary")
                    st.dataframe(model_summary, use_container_width=True)
                    # Visualization
                    fig_model_comp = px.bar(
                        model_summary.reset_index(),
                        x='model',
                        y=['answer_relevancy', 'answer_correctness'],
                        title="Model Performance Comparison",
                        barmode='group'
                    )
                    st.plotly_chart(fig_model_comp, use_container_width=True)
                    # Store results
                    st.session_state['model_comparison_results'] = comp_df
    with tabs[2]:
        st.header("Error Analysis")
        st.write("Analyze failure cases and performance patterns.")
        if st.session_state['ablation_results'] is not None:
            results_df = st.session_state['ablation_results']
            # Find worst performing cases
            # (the 10 rows with the lowest answer_correctness across all experiments)
            st.subheader("Worst Performing Cases")
            worst_cases = results_df.nsmallest(10, 'answer_correctness')[['question', 'answer', 'ground_truth', 'experiment', 'answer_correctness']]
            st.dataframe(worst_cases, use_container_width=True)
            # Performance by experiment
            st.subheader("Performance Distribution")
            fig_box = px.box(results_df, x='experiment', y='answer_correctness',
                             title="Answer Correctness Distribution by Experiment")
            st.plotly_chart(fig_box, use_container_width=True)
        else:
            st.info("Run an ablation study first to see error analysis.")
    with tabs[3]:
        st.header("Export Results")
        st.write("Export results for research papers and further analysis.")
        if st.session_state['ablation_results'] is not None:
            results_df = st.session_state['ablation_results']
            # Generate summary report
            # The report text is kept flush-left because it is written to the
            # downloaded Markdown file exactly as it appears here.
            report = f"""
# RAG System Ablation Study Report
**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Total Experiments:** {len(results_df['experiment'].unique())}
**Total Test Cases:** {len(results_df)}
## Summary Statistics
{results_df.groupby('experiment').agg({
'answer_relevancy': ['mean', 'std'],
'faithfulness': ['mean', 'std'],
'answer_correctness': ['mean', 'std'],
'context_precision': ['mean', 'std'],
'context_recall': ['mean', 'std']
}).round(3).to_string()}
## Best Performing Configurations
- **Best Answer Relevancy:** {results_df.groupby('experiment')['answer_relevancy'].mean().idxmax()}
- **Best Faithfulness:** {results_df.groupby('experiment')['faithfulness'].mean().idxmax()}
- **Best Answer Correctness:** {results_df.groupby('experiment')['answer_correctness'].mean().idxmax()}
## Recommendations
Based on the ablation study results, we recommend...
[Add your analysis here]
"""
            col1, col2 = st.columns(2)
            with col1:
                st.download_button(
                    label="📄 Download Report (Markdown)",
                    data=report,
                    file_name=f"ablation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
                    mime="text/markdown"
                )
            with col2:
                csv_data = results_df.to_csv(index=False)
                st.download_button(
                    label="📊 Download Data (CSV)",
                    data=csv_data,
                    file_name=f"ablation_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                    mime="text/csv"
                )
        else:
            st.info("No results available for export. Run an ablation study first.")
# Database setup
# Schema of the single table backing registration and login.
_USERS_TABLE_DDL = '''CREATE TABLE IF NOT EXISTS users
             (id INTEGER PRIMARY KEY AUTOINCREMENT,
              username TEXT UNIQUE,
              password_hash TEXT)'''

# Module-level connection and cursor used by the user helpers in this file.
conn = sqlite3.connect('users.db')
c = conn.cursor()
c.execute(_USERS_TABLE_DDL)
conn.commit()
# Password hashing and verification
def hash_password(password):
    """Return the bcrypt hash of ``password``, salted with a newly generated salt."""
    secret = password.encode()
    salt = bcrypt.gensalt()
    return bcrypt.hashpw(secret, salt)
def verify_password(password, hashed_password):
    """Return True when ``password`` matches the bcrypt hash ``hashed_password``.

    Args:
        password: Plain-text password (``str``).
        hashed_password: Stored bcrypt hash, as ``bytes`` or ``str``.
    """
    # The hash lives in a TEXT column, so it can come back as str;
    # bcrypt.checkpw only accepts bytes.
    if isinstance(hashed_password, str):
        hashed_password = hashed_password.encode()
    return bcrypt.checkpw(password.encode(), hashed_password)
# Add a user
def add_user(username, password):
    """Register a new user.

    Args:
        username: Desired username.
        password: Plain-text password; only its bcrypt hash is stored.

    Returns:
        True when the user was inserted, False when the username is already
        taken.
    """
    # Check if username already exists
    c.execute("SELECT id FROM users WHERE username = ?", (username,))
    if c.fetchone():
        return False  # Username already exists

    # Hash the password and insert the new user
    password_hash = hash_password(password)
    try:
        c.execute("INSERT INTO users (username, password_hash) VALUES (?, ?)",
                  (username, password_hash))
    except sqlite3.IntegrityError:
        # The UNIQUE constraint fired: the same username was inserted between
        # the check above and this insert. Report it like any taken name.
        conn.rollback()
        return False
    conn.commit()
    return True
# Verify a user
def verify_user(username, password):
    """Return True when ``username`` exists and ``password`` matches its stored hash."""
    c.execute("SELECT password_hash FROM users WHERE username = ?", (username,))
    row = c.fetchone()
    if not row:
        # Unknown username.
        return False
    stored_hash = row[0]
    return verify_password(password, stored_hash)
# Login and logout
def login(username, password):
    """Authenticate the user and record the login in ``st.session_state``.

    Shows an error and returns False when a credential is missing or wrong;
    otherwise marks the session authenticated, stores the username and the
    current time as last activity, and returns True.
    """
    if not (username and password):
        st.error("Username and password are required.")
        return False

    if not verify_user(username, password):
        st.error("Invalid username or password.")
        return False

    st.session_state['authenticated'] = True
    st.session_state['username'] = username
    st.session_state['last_activity'] = time.time()
    return True
def logout():
    """Mark the session as unauthenticated and forget the username."""
    for key, value in (('authenticated', False), ('username', None)):
        st.session_state[key] = value
# Page-wide CSS, injected before the landing page is drawn.
def local_css():
    """Inject the custom stylesheet into the page.

    Emits one ``<style>`` element through ``st.markdown`` (with
    ``unsafe_allow_html=True``) that styles the buttons, the auth box, the
    hero banner, the feature cards/carousel and the dividers.
    """
    # NOTE(review): known oddities in this stylesheet, left untouched because
    # fixing them would change how the page renders:
    #   * The two drafts that start with "/*.feature-container" contain inner
    #     "/* ... */" comments. CSS comments do not nest, so each draft's
    #     comment ends at the first inner "*/" and the remainder of the draft
    #     is parsed as (partly malformed) live CSS.
    #   * The second ".hero-section" rule declares "animation" twice; the
    #     later "fadeIn" declaration overrides "gradientShift".
    # The only change made here is the closing brace of the final
    # ".feature-container::after" rule, which was missing before "</style>".
    st.markdown("""
    <style>
    .stButton>button {
        width: 100%;
        border-radius: 5px;
        height: 3em;
        margin-top: 10px;
    }
    .auth-container {
        max-width: 400px;
        margin: auto;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        background-color: white;
    }
    .auth-title {
        text-align: center;
        font-size: 24px;
        margin-bottom: 20px;
        color: #1f1f1f;
    }
    .auth-subtitle {
        text-align: center;
        font-size: 16px;
        margin-bottom: 20px;
        color: #666;
    }
    .hero-section {
        text-align: center;
        padding: 40px 20px;
        background: linear-gradient(to right, #4f46e5, #3b82f6);
        color: white;
        margin-bottom: 30px;
    }
    .feature-container {
        max-width: 1200px;
        margin: auto;
        padding: 20px;
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
        gap: 20px;
        margin-bottom: 40px;
    }
    .feature-card {
        background: white;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
        transition: transform 0.3s ease, box-shadow 0.3s ease;
    }
    .feature-card:hover {
        transform: scale(1.05);
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }
    .feature-title {
        color: #1f1f1f;
        font-size: 18px;
        margin-bottom: 10px;
        font-weight: bold;
    }
    .feature-text {
        color: #666;
        font-size: 14px;
    }
    .divider {
        text-align: center;
        margin: 20px 0;
        position: relative;
    }
    .divider:before {
        content: "";
        position: absolute;
        top: 50%;
        left: 0;
        right: 0;
        height: 1px;
        background-color: #e0e0e0;
        z-index: -1;
    }
    .divider span {
        background-color: white;
        padding: 0 10px;
        color: #666;
        font-size: 14px;
    }
    @keyframes typing {
        0% {
            width: 0;
        }
        50% {
            width: 100%;
        }
        60% {
            width: 100%;
        }
        100% {
            width: 0;
        }
    }
    @keyframes blink {
        50% {
            border-color: transparent;
        }
    }
    .hero-title{
        display: inline-block;
        font-size: 2.5em;
        white-space: nowrap;
        overflow: hidden;
        border-right: 2px solid white;
        width: 0;
        animation: typing 6s steps(40, end) infinite, blink 0.5s step-end infinite;
    }
    .hero-section {
        text-align: center;
        padding: 40px 20px;
        background: linear-gradient(45deg, #4f46e5, #3b82f6);
        background-size: 300% 300%;
        animation: gradientShift 8s ease infinite;
        color: white;
        margin-bottom: 30px;
        opacity: 0;
        animation: fadeIn 2s ease-in-out forwards;
    }
    @keyframes fadeIn {
        from {
            opacity: 0;
        }
        to {
            opacity: 1;
        }
    }
    @keyframes gradientShift {
        0% {
            background-position: 0% 50%;
        }
        50% {
            background-position: 100% 50%;
        }
        100% {
            background-position: 0% 50%;
        }
    }
    /*.feature-container {
        display: flex;
        justify-content: center;
        align-items: center;
        gap: 20px;
        position: relative;
        width: 100%;
        height: 300px;
        animation: rotate 20s linear infinite; /* Rotate the container */
    }
    .feature-card {
        background: white;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
        transition: transform 0.3s ease, box-shadow 0.3s ease;
        flex-shrink: 0;
        width: 250px;
    }
    .feature-card:hover {
        transform: scale(1.1);
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }
    @keyframes rotate {
        from {
            transform: rotate(0deg);
        }
        to {
            transform: rotate(-360deg);
        }
    */}
    /*.feature-container {
        display: flex;
        justify-content: center;
        align-items: center;
        overflow: hidden;
        position: relative;
        width: 100%;
        height: 300px;
    }
    .feature-track {
        display: flex;
        animation: circularMove 15s linear infinite;
    }
    .feature-card {
        flex: 0 0 300px; /* Fixed width for each card */
        margin: 0 20px;
        background: white;
        color: #333; /* Text color */
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
        text-align: center; /* Center-align the text */
        overflow: hidden; /* Prevent overflow issues */
    }
    .feature-card h3 {
        font-size: 1.2em;
        margin-bottom: 10px;
        text-align: center;
    }
    .feature-card p {
        font-size: 0.9em;
        line-height: 1.4;
        text-align: center;
        font-weight: bold;
    }
    .feature-card:hover {
        transform: scale(1.1);
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }
    @keyframes circularMove {
        0% {
            transform: translateX(0);
        }
        100% {
            transform: translateX(-100%);
        }
    */}
    .feature-container {
        display: flex;
        justify-content: center;
        align-items: center;
        height: 400px;
        perspective: 1000px;
        perspective-origin: 50% 50%;
        background: linear-gradient(to bottom, #1e293b, #0f172a); /* Dark blue gradient background */
        overflow: hidden;
        position: relative;
        padding: 40px 0;
    }
    .feature-track {
        position: relative;
        width: 100%;
        height: 100%;
        display: flex;
        transform-style: preserve-3d;
        animation: carousel 15s linear infinite;
    }
    .feature-card {
        position: absolute;
        width: 300px;
        padding: 50px;
        background: white;
        border-radius: 15px;
        box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3); /* Enhanced shadow for better contrast */
        backface-visibility: hidden;
        transform-origin: center center;
        transition: all 0.5s ease;
    }
    .feature-card h3 {
        color: #1e293b;
        font-size: 1.5em;
        margin-bottom: 1rem;
        font-weight: bold;
    }
    .feature-card p {
        color: #475569;
        line-height: 1.6;
    }
    /* Position and animate cards */
    .feature-card:nth-child(1) {
        transform: rotateY(0deg) translateZ(400px) translateX(0px);
    }
    .feature-card:nth-child(2) {
        transform: rotateY(60deg) translateZ(400px) translateX(0px);
    }
    .feature-card:nth-child(3) {
        transform: rotateY(120deg) translateZ(400px) translateX(0px);
    }
    .feature-card:nth-child(4) {
        transform: rotateY(180deg) translateZ(400px) translateX(0px);
    }
    .feature-card:nth-child(5) {
        transform: rotateY(240deg) translateZ(400px) translateX(0px);
    }
    .feature-card:nth-child(6) {
        transform: rotateY(300deg) translateZ(400px) translateX(0px);
    }
    @keyframes carousel {
        0% {
            transform: translateZ(-400px) rotateY(0deg);
        }
        100% {
            transform: translateZ(-400px) rotateY(-360deg);
        }
    }
    /* Enhanced hover effect with glow */
    .feature-card:hover {
        transform: scale(1.1) translateZ(450px);
        box-shadow: 0 8px 30px rgba(255, 255, 255, 0.1); /* Glowing effect */
        z-index: 1;
    }
    /* Gradient overlays for depth effect */
    .feature-container::before,
    .feature-container::after {
        content: '';
        position: absolute;
        width: 100%;
        height: 100px;
        z-index: 2;
        pointer-events: none;
    }
    .feature-container::before {
        top: 0;
        background: linear-gradient(to bottom, #1e293b, rgba(30, 41, 59, 0));
    }
    .feature-container::after {
        bottom: 0;
        background: linear-gradient(to top, #1e293b, rgba(30, 41, 59, 0));
    }
    </style>
    """, unsafe_allow_html=True)
# Check session expiry
# Only applies to an authenticated session; .get() also covers the case where
# the key has not been set yet.
if st.session_state.get('authenticated'):
    idle_seconds = time.time() - st.session_state['last_activity']
    if idle_seconds > 1800:  # 30 minutes
        logout()
        st.rerun()
    # Each script run by an authenticated user counts as activity.
    st.session_state['last_activity'] = time.time()

# Initialize session state for registration form visibility
if 'show_register_form' not in st.session_state:
    st.session_state['show_register_form'] = False
| # Replace your login/registration section with this: | |
| if 'authenticated' not in st.session_state or not st.session_state['authenticated']: | |
| local_css() | |
| # Landing page hero section | |
| st.markdown(""" | |
| <div class="hero-section"> | |
| <h1 class="hero-title" style="font-size: 2.5em; margin-bottom: 20px;">Crop Disease Detection System</h1> | |
| <p style="font-size: 1.2em; max-width: 800px; margin: 0 auto;"> | |
| An advanced AI-powered system that helps farmers and agricultural experts identify and manage crop diseases effectively | |
| </p> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Features section using Streamlit columns | |
| st.subheader("Key Features") | |
| col1, col2, col3 = st.columns(3) | |
| st.markdown(""" | |
| <div class="feature-container"> | |
| <div class="feature-track"> | |
| <div class="feature-card"> | |
| <h3>🔍 Instant Detection</h3> | |
| <p>Upload images of your crops and get immediate disease detection results using state-of-the-art AI technology.</p> | |
| </div> | |
| <div class="feature-card"> | |
| <h3>💡 Expert Analysis</h3> | |
| <p>Receive detailed analysis and recommendations from our plant pathology expert system.</p> | |
| </div> | |
| <div class="feature-card"> | |
| <h3>📊 Detailed Reports</h3> | |
| <p>Generate comprehensive reports with treatment recommendations and preventive measures.</p> | |
| </div> | |
| <div class="feature-card"> | |
| <h3>🔍 Instant Detection</h3> | |
| <p>Upload images of your crops and get immediate disease detection results using state-of-the-art AI technology.</p> | |
| </div> | |
| <div class="feature-card"> | |
| <h3>💡 Expert Analysis</h3> | |
| <p>Receive detailed analysis and recommendations from our plant pathology expert system.</p> | |
| </div> | |
| <div class="feature-card"> | |
| <h3>📊 Detailed Reports</h3> | |
| <p>Generate comprehensive reports with treatment recommendations and preventive measures.</p> | |
| </div> | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Crop carousel section | |
| st.markdown(""" | |
| <div class="crop-carousel-container"> | |
| <div class="crop-carousel-track"> | |
| <div class="crop-card"> | |
| <img src="https://github.com/ROBERT-ADDO-ASANTE-DARKO/AI-powered-crop-disease-detection/blob/main/images/b034333ddcc732299d45abf753f3fa71f6ff48ffa3338bfecd615bc2.jpg?raw=true" alt="Crop 1"> | |
| <h4>Corn Leaf Blight</h4> | |
| <p>Corn leaf blight is a fungal disease caused primarily by Exserohilum turcicum (Northern corn leaf blight) and Bipolaris maydis (Southern corn leaf blight).</p> | |
| </div> | |
| <div class="crop-card"> | |
| <img src="https://github.com/ROBERT-ADDO-ASANTE-DARKO/AI-powered-crop-disease-detection/blob/main/images/apple.jpg?raw=true" alt="Crop 2"> | |
| <h4>Apple Scab Leaf</h4> | |
| <p>Apple scab is a fungal disease caused by Venturia inaequalis. It primarily affects apple and crabapple trees.</p> | |
| </div> | |
| <div class="crop-card"> | |
| <img src="https://github.com/ROBERT-ADDO-ASANTE-DARKO/AI-powered-crop-disease-detection/blob/main/images/tomato.jpg?raw=true" alt="Crop 3"> | |
| <h4>Tomato Leaf Late Blight</h4> | |
| <p>Late blight of tomato is caused by the oomycete pathogen Phytophthora infestans. It is characterized by dark, water-soaked lesions on leaves, stems, and fruit.</p> | |
| </div> | |
| <div class="crop-card"> | |
| <img src="https://github.com/ROBERT-ADDO-ASANTE-DARKO/AI-powered-crop-disease-detection/blob/main/images/918d1d7a3dda5ce8fbdabf92e5bf38f104efd129ee09adcc6d1ad46c.jpg?raw=true" alt="Crop 4"> | |
| <h4>Tomato Leaf Yellow Virus</h4> | |
| <p>Tomato leaf yellow virus (often referred to as Tomato yellow leaf curl virus, or TYLCV) is a viral disease transmitted by whiteflies. It causes yellowing and curling of tomato leaves.</p> | |
| </div> | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| st.markdown(""" | |
| <style> | |
| .crop-carousel-container { | |
| width: 100%; | |
| max-width: 800px; | |
| margin: auto; | |
| overflow: hidden; | |
| position: relative; | |
| } | |
| .crop-carousel-track { | |
| display: flex; | |
| animation: moveLeft 20s linear infinite; /* Move right to left */ | |
| } | |
| .crop-card { | |
| flex: 0 0 300px; | |
| margin: 0 20px; | |
| background: white; | |
| color: #333; | |
| padding: 20px; | |
| border-radius: 10px; | |
| box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); | |
| text-align: center; | |
| overflow: hidden; | |
| } | |
| .crop-card img { | |
| width: 100%; | |
| height: 150px; | |
| object-fit: cover; | |
| border-radius: 10px; | |
| margin-bottom: 10px; | |
| } | |
| .crop-card h4 { | |
| font-size: 1.2em; | |
| margin: 10px 0; | |
| } | |
| .crop-card p { | |
| font-size: 0.9em; | |
| line-height: 1.4; | |
| color: #555; | |
| } | |
| @keyframes moveLeft { | |
| 0% { | |
| transform: translateX(100%); | |
| } | |
| 100% { | |
| transform: translateX(-100%); | |
| } | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # Add some spacing | |
| st.markdown("<br>", unsafe_allow_html=True) | |
| # Authentication container | |
| st.markdown('<div class="auth-container">', unsafe_allow_html=True) | |
| # Initialize password reset state | |
| if 'show_reset_form' not in st.session_state: | |
| st.session_state['show_reset_form'] = False | |
| # Update password function | |
def update_password(username, new_password):
    """Replace the stored password hash of an existing user in ``users.db``.

    Args:
        username: Account whose password should be replaced.
        new_password: New plain-text password; it is hashed with bcrypt
            before being written.

    Returns:
        True when the password was updated, False when no row in the
        ``users`` table matches ``username``.
    """
    conn = sqlite3.connect('users.db')
    try:
        c = conn.cursor()
        # Check if username exists
        c.execute("SELECT id FROM users WHERE username = ?", (username,))
        if c.fetchone() is None:
            return False
        # Update password
        password_hash = bcrypt.hashpw(new_password.encode(), bcrypt.gensalt())
        c.execute("UPDATE users SET password_hash = ? WHERE username = ?",
                  (password_hash, username))
        conn.commit()
        return True
    finally:
        # Release the connection on every path, including the early
        # "username not found" return and any database error.
        conn.close()
| # Update the authentication container section | |
| if not st.session_state.get('authenticated', False): | |
| st.markdown('<div class="auth-container">', unsafe_allow_html=True) | |
| # Reset Password Form | |
| if st.session_state.get('show_reset_form', False): | |
| st.markdown('<h1 class="auth-title">Reset Password</h1>', unsafe_allow_html=True) | |
| st.markdown('<p class="auth-subtitle">Enter your username and new password</p>', unsafe_allow_html=True) | |
| with st.form("reset_form"): | |
| username = st.text_input("Username") | |
| new_password = st.text_input("New Password", type="password") | |
| confirm_password = st.text_input("Confirm Password", type="password") | |
| submit = st.form_submit_button("Reset Password") | |
| if submit: | |
| if not username or not new_password or not confirm_password: | |
| st.error("All fields are required.") | |
| elif new_password != confirm_password: | |
| st.error("Passwords do not match.") | |
| elif update_password(username, new_password): | |
| st.success("Password updated successfully!") | |
| st.session_state['show_reset_form'] = False | |
| time.sleep(1) | |
| st.rerun() | |
| else: | |
| st.error("Username not found.") | |
| if st.button("Back to Login"): | |
| st.session_state['show_reset_form'] = False | |
| st.rerun() | |
| # Registration Form | |
| elif st.session_state.get('show_register_form', False): | |
| st.markdown('<h1 class="auth-title">Create Account</h1>', unsafe_allow_html=True) | |
| st.markdown('<p class="auth-subtitle">Sign up to get started</p>', unsafe_allow_html=True) | |
| with st.form("register_form"): | |
| new_username = st.text_input("Username") | |
| new_password = st.text_input("Password", type="password") | |
| submit_button = st.form_submit_button("Create Account") | |
| if submit_button: | |
| if new_username and new_password: | |
| if add_user(new_username, new_password): | |
| st.success("Account created successfully!") | |
| st.session_state['show_register_form'] = False | |
| time.sleep(1) | |
| st.rerun() | |
| else: | |
| st.error("Username already exists.") | |
| else: | |
| st.error("Username and password are required.") | |
| st.markdown('<div class="divider"><span>OR</span></div>', unsafe_allow_html=True) | |
| if st.button("Back to Login"): | |
| st.session_state['show_register_form'] = False | |
| st.rerun() | |
| # Login Form (default) | |
| else: | |
| st.markdown('<h1 class="auth-title">Welcome Back</h1>', unsafe_allow_html=True) | |
| st.markdown('<p class="auth-subtitle">Sign in to your account</p>', unsafe_allow_html=True) | |
| with st.form("login_form"): | |
| username = st.text_input("Username") | |
| password = st.text_input("Password", type="password") | |
| cols = st.columns([1, 1]) | |
| submit_button = cols[0].form_submit_button("Sign In") | |
| forgot_password = cols[1].form_submit_button("Forgot Password?") | |
| if submit_button: | |
| if login(username, password): | |
| st.success("Logged in successfully!") | |
| time.sleep(1) | |
| st.rerun() | |
| elif forgot_password: | |
| st.session_state['show_reset_form'] = True | |
| st.rerun() | |
| st.markdown('<div class="divider"><span>OR</span></div>', unsafe_allow_html=True) | |
| if st.button("Create New Account"): | |
| st.session_state['show_register_form'] = True | |
| st.rerun() | |
| st.markdown('</div>', unsafe_allow_html=True) | |
| # Update the footer section (replace the existing footer with this) | |
| st.markdown(""" | |
| <div style="background: linear-gradient(to right, #1e293b, #334155); color: white; padding: 40px 0; margin-top: 40px;"> | |
| <div style="max-width: 1200px; margin: auto; padding: 0 20px;"> | |
| <div style="display: flex; flex-wrap: wrap; justify-content: space-between; gap: 40px;"> | |
| <!-- About Section --> | |
| <div style="flex: 1; min-width: 250px;"> | |
| <h3 style="color: #60a5fa; font-size: 1.5em; margin-bottom: 20px;">About Our Platform</h3> | |
| <p style="color: #e2e8f0; line-height: 1.6; margin-bottom: 20px;"> | |
| Our AI-powered platform revolutionizes crop disease detection and management. | |
| We combine cutting-edge technology with agricultural expertise to protect your crops | |
| and maximize your yield. | |
| </p> | |
| </div> | |
| <div style="flex: 1; min-width: 250px;"> | |
| <h3 style="color: #60a5fa; font-size: 1.5em; margin-bottom: 20px;">Key Features</h3> | |
| <ul style="list-style: none; padding: 0; color: #e2e8f0;"> | |
| <li style="margin-bottom: 10px; display: flex; align-items: center;"> | |
| <span style="color: #60a5fa; margin-right: 10px;">✓</span> Real-time Disease Detection | |
| </li> | |
| <li style="margin-bottom: 10px; display: flex; align-items: center;"> | |
| <span style="color: #60a5fa; margin-right: 10px;">✓</span> Multi-language Support | |
| </li> | |
| <li style="margin-bottom: 10px; display: flex; align-items: center;"> | |
| <span style="color: #60a5fa; margin-right: 10px;">✓</span> Expert Analysis Reports | |
| </li> | |
| <li style="margin-bottom: 10px; display: flex; align-items: center;"> | |
| <span style="color: #60a5fa; margin-right: 10px;">✓</span> Treatment Recommendations | |
| </li> | |
| </ul> | |
| </div> | |
| <div style="flex: 1; min-width: 250px;"> | |
| <h3 style="color: #60a5fa; font-size: 1.5em; margin-bottom: 20px;">Contact Us</h3> | |
| <p style="color: #e2e8f0; line-height: 1.6; margin-bottom: 10px;"> | |
| <span style="color: #60a5fa;">Email:</span> support@crophealth.ai | |
| </p> | |
| <p style="color: #e2e8f0; line-height: 1.6; margin-bottom: 20px;"> | |
| <span style="color: #60a5fa;">Phone:</span> +1 (234) 567-8900 | |
| </p> | |
| <div style="display: flex; gap: 15px; margin-top: 20px;"> | |
| <a href="#" style="color: #60a5fa; text-decoration: none; font-size: 1.2em;"> | |
| <span>📱</span> | |
| </a> | |
| <a href="#" style="color: #60a5fa; text-decoration: none; font-size: 1.2em;"> | |
| <span>💬</span> | |
| </a> | |
| <a href="#" style="color: #60a5fa; text-decoration: none; font-size: 1.2em;"> | |
| <span>📨</span> | |
| </a> | |
| </div> | |
| </div> | |
| </div> | |
| <div style="border-top: 1px solid #4b5563; margin-top: 40px; padding-top: 20px; text-align: center;"> | |
| <p style="color: #e2e8f0; font-size: 0.9em;"> | |
| © 2025 Crop Disease Detection System. All rights reserved. | |
| </p> | |
| <div style="margin-top: 10px;"> | |
| <a href="#" style="color: #e2e8f0; text-decoration: none; margin: 0 10px; font-size: 0.9em;">Privacy Policy</a> | |
| <a href="#" style="color: #e2e8f0; text-decoration: none; margin: 0 10px; font-size: 0.9em;">Terms of Service</a> | |
| <a href="#" style="color: #e2e8f0; text-decoration: none; margin: 0 10px; font-size: 0.9em;">FAQ</a> | |
| </div> | |
| </div> | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| st.stop() | |
| # Update database schema to include comments | |
def setup_feedback_db():
    """Open ``customer_feedback.db`` and make sure the feedback table exists.

    Returns:
        A ``(connection, cursor)`` tuple. The caller owns the connection and
        is responsible for closing it.
    """
    connection = sqlite3.connect('customer_feedback.db')
    cursor = connection.cursor()
    # IF NOT EXISTS makes this safe to run before every write.
    cursor.execute('''CREATE TABLE IF NOT EXISTS customer_feedback
(id INTEGER PRIMARY KEY AUTOINCREMENT,
question TEXT,
response TEXT,
feedback_type TEXT,
comment_type TEXT,
custom_comment TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
    connection.commit()
    return connection, cursor
def save_feedback(question, response, feedback_type, comment_type=None, custom_comment=None):
    """Insert one feedback row into the ``customer_feedback`` table.

    Args:
        question: The question the user asked.
        response: The answer that was shown to the user.
        feedback_type: The rating given to the answer.
        comment_type: Optional category of the complaint.
        custom_comment: Optional free-text comment.

    Returns:
        True when the row was committed; False when the insert failed, in
        which case the error is also reported through ``st.error``.
    """
    connection, cursor = setup_feedback_db()
    row = (question, response, feedback_type, comment_type, custom_comment)
    try:
        cursor.execute("""INSERT INTO customer_feedback
(question, response, feedback_type, comment_type, custom_comment)
VALUES (?, ?, ?, ?, ?)""",
                       row)
        connection.commit()
    except Exception as e:
        st.error(f"Error saving feedback: {e}")
        saved = False
    else:
        saved = True
    finally:
        connection.close()
    return saved
| # Update the conversation display section | |
def display_feedback_buttons(file_id, index, question, response):
    """Render thumbs-up / thumbs-down feedback controls for one answer.

    A thumbs-up is saved immediately. A thumbs-down reveals a follow-up form
    (issue category plus, for "Other", a free-text field) that is saved when
    the user presses "Submit Feedback".

    Args:
        file_id: Identifier of the uploaded file the answer belongs to; used
            to build unique widget and session-state keys.
        index: Position of the answer in the conversation; also part of the
            widget and session-state keys.
        question: The question that was asked.
        response: The answer being rated.

    Returns:
        None.
    """
    # Suggested comments
    SUGGESTED_COMMENTS = [
        "Inaccurate information",
        "Unclear explanation",
        "Missing details",
        "Not relevant to question",
        "Technical error",
        "Other"
    ]
    state_key = f"feedback_{file_id}_{index}"
    # Initialize session state for feedback if it doesn't exist
    if state_key not in st.session_state:
        st.session_state[state_key] = {
            "feedback_type": None,  # Stores "👍" or "👎"
            "comment": None,        # Stores the user's comment
            "submitted": False      # Tracks whether feedback has been submitted
        }
    feedback_state = st.session_state[state_key]

    col1, col2 = st.columns([1, 4])
    with col1:
        if st.button("👍", key=f"helpful_{file_id}_{index}"):
            # Save positive feedback immediately. save_feedback() reports its
            # own errors, so only confirm success when the row was written.
            if save_feedback(question, response, "👍"):
                st.success("Feedback saved!")
                feedback_state["submitted"] = True
            return
    with col2:
        if st.button("👎", key=f"not_helpful_{file_id}_{index}"):
            # Remember the rating across reruns so the follow-up form stays
            # visible while the user fills it in.
            feedback_state["feedback_type"] = "👎"

        # Only show the comment inputs once a thumbs-down has been recorded.
        if feedback_state.get("feedback_type") == "👎":
            selected_comment = st.selectbox(
                "What was the issue?",
                options=SUGGESTED_COMMENTS,
                key=f"suggested_comment_{file_id}_{index}"
            )
            # If the user selects "Other", allow them to provide a custom comment
            custom_comment = None
            if selected_comment == "Other":
                custom_comment = st.text_area(
                    "Please describe the issue:",
                    key=f"custom_comment_{file_id}_{index}"
                )
            if st.button("Submit Feedback", key=f"submit_{file_id}_{index}"):
                # Store the chosen category in comment_type and the free text
                # in custom_comment, matching the table's two columns.
                saved = save_feedback(
                    question,
                    response,
                    feedback_state["feedback_type"],
                    comment_type=selected_comment,
                    custom_comment=custom_comment
                )
                if saved:
                    st.success("Thank you for your feedback!")
                    feedback_state["submitted"] = True
                return
import re

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# Connect to the hosted Qdrant cluster; the API key is read from the environment.
client = QdrantClient(
    url="https://f8a5b65d-191d-4a67-8536-ffd96c2f49c6.us-east4-0.gcp.cloud.qdrant.io:6333",
    api_key=os.environ['QDRANT_API_KEY']
)
# Collection name
collection_name = "crop_disease_embeddings"
# Check if the collection already exists
existing_collections = client.get_collections()
collection_names = [col.name for col in existing_collections.collections]
if collection_name not in collection_names:
    # The existence check above already guards this call, so a plain
    # create is enough; recreate_collection is deprecated in the client.
    client.create_collection(
        collection_name=collection_name,
        # 384 is the vector size configured for this collection.
        # NOTE(review): presumably the output size of all-MiniLM-L6-v2 — confirm.
        vectors_config=VectorParams(size=384, distance=Distance.COSINE)
    )
    st.write(f"Created new collection: {collection_name}")
else:
    st.write(f"Collection {collection_name} already exists. Skipping creation.")
# Load a pre-trained model (e.g., 'all-MiniLM-L6-v2')
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Read the knowledge-base text with an explicit encoding so the result does
# not depend on the platform default.
with open("./docs/disease_docs.txt", "r", encoding="utf-8") as file:
    text = file.read()
# Split the text into sentences: break after '.', '!' or '?' followed by spaces.
sentences = re.split(r'(?<=[.!?]) +', text)
# One document per sentence; ids are 1-based positions in the sentence list.
documents = [
    {
        "id": idx + 1,  # Unique ID for each chunk
        "text": chunk.strip(),  # Remove leading/trailing whitespace
        "metadata": {
            "source": "docs/disease_docs.txt",
            "chunk_number": idx + 1
        }
    }
    for idx, chunk in enumerate(sentences)
]
# Embed and upsert in batches rather than one network round trip (and one
# model call) per sentence.
UPSERT_BATCH_SIZE = 64
for start in range(0, len(documents), UPSERT_BATCH_SIZE):
    batch = documents[start:start + UPSERT_BATCH_SIZE]
    embeddings = embedding_model.encode([doc["text"] for doc in batch])
    client.upsert(
        collection_name=collection_name,
        points=[
            {
                "id": doc["id"],
                "vector": embedding.tolist(),
                "payload": {
                    "text": doc["text"],  # Include the text in the payload
                    "source": doc["metadata"]["source"],
                    "chunk_number": doc["metadata"]["chunk_number"]
                }
            }
            for doc, embedding in zip(batch, embeddings)
        ]
    )
def retrieve_relevant_documents(query, detected_diseases, top_k=3):
    """Search the Qdrant collection for chunks related to a question.

    The question and the comma-joined disease names are embedded together as
    a single search string.

    Args:
        query: The user's question.
        detected_diseases: Disease names to append to the search string; may
            be empty.
        top_k: Maximum number of hits to request.

    Returns:
        A list with the payload of each hit, in the order returned by the
        search.
    """
    disease_terms = ', '.join(detected_diseases)
    search_text = f"{query} {disease_terms}"
    search_vector = embedding_model.encode(search_text).tolist()
    hits = client.search(
        collection_name=collection_name,
        query_vector=search_vector,
        limit=top_k
    )
    payloads = []
    for hit in hits:
        payloads.append(hit.payload)
    return payloads
def get_reference_answer(disease_name):
    """Look up the reference answer for a disease in the SQLite knowledge base.

    Args:
        disease_name: Value matched exactly against the ``name`` column of
            the ``diseases`` table.

    Returns:
        A three-line string ``"Cause: …\\nSymptoms: …\\nTreatment: …"`` built
        from the matching row, or None when no row matches.
    """
    conn = sqlite3.connect('./db/disease_knowledge_base.db')  # Replace with your database path
    try:
        c = conn.cursor()
        # Query the database for cause, symptoms, and treatment
        c.execute("SELECT cause, symptoms, treatment FROM diseases WHERE name = ?", (disease_name,))
        result = c.fetchone()
    finally:
        # Close the connection even if the query raises (e.g. missing table).
        conn.close()
    if result is None:
        return None  # No matching disease
    cause, symptoms, treatment = result
    return (
        f"Cause: {cause}\n"
        f"Symptoms: {symptoms}\n"
        f"Treatment: {treatment}"
    )
def filter_relevant_chunks(chunks, detected_diseases):
    """Keep only the chunks whose text mentions one of the detected diseases.

    Matching is a case-insensitive substring test of each disease name
    against the chunk's ``"text"`` value.

    Args:
        chunks (list): Retrieved chunks, each a dictionary with a "text" key.
        detected_diseases (list): Detected disease names.

    Returns:
        list: The matching chunks, in their original order.
    """
    return [
        chunk
        for chunk in chunks
        if any(disease.lower() in chunk["text"].lower() for disease in detected_diseases)
    ]
async def generate_rag_response(query, conversation_history=None, reference_answer=None):
    """Answer a question with retrieval-augmented generation and score it.

    Reads the module-level names ``detected_classes`` (used as the disease
    list for retrieval and filtering) and ``selected_model`` (passed as the
    Groq model name); both are defined outside this function.

    Args:
        query: The user's question.
        conversation_history: Optional earlier exchanges forwarded to the
            model.
        reference_answer: Optional ground-truth answer used by the
            evaluation helpers.

    Returns:
        A ``(response, filtered_chunks, ragas_result)`` tuple: the generated
        text, the retrieved chunks that mention a detected disease, and the
        metrics dictionary of the full RAG pipeline.
    """
    # Retrieve relevant chunks
    relevant_chunks = retrieve_relevant_documents(query, detected_classes)
    # Filter the retrieved chunks to include only those relevant to the detected diseases
    filtered_chunks = filter_relevant_chunks(relevant_chunks, detected_classes)
    # Build context from filtered chunks
    context = "\n".join(chunk["text"] for chunk in filtered_chunks)
    # Generate response using Groq
    response = await generate_groq_response(
        f"Context: {context}\n\nQuestion: {query}",
        model_name=selected_model,
        conversation_history=conversation_history
    )
    # Evaluate using the local embedding-based implementation
    ragas_result = evaluate_ragas(query, response, context, reference_answer)
    # evaluate_generation_only() ends up calling asyncio.run(), which raises
    # RuntimeError when invoked on a thread whose event loop is already
    # running (as it is inside this coroutine). Run it on a worker thread.
    loop = asyncio.get_running_loop()
    generation_only_result = await loop.run_in_executor(
        None, evaluate_generation_only, query, reference_answer
    )
    print("Full RAG System Results:")
    print(f"Answer Relevancy: {ragas_result['answer_relevancy']:.2f}")
    print(f"Faithfulness: {ragas_result['faithfulness']:.2f}")
    print(f"Answer Correctness: {ragas_result['answer_correctness']:.2f}")
    print(f"Context Precision: {ragas_result['context_precision']:.2f}")
    print(f"Context Recall: {ragas_result['context_recall']:.2f}")
    print("Generation-Only System Results:")
    print(f"Answer Relevancy: {generation_only_result['answer_relevancy']:.2f}")
    print(f"Faithfulness: {generation_only_result['faithfulness']:.2f}")
    if generation_only_result['answer_correctness'] is not None:
        print(f"Answer Correctness: {generation_only_result['answer_correctness']:.2f}")
    else:
        print("Answer Correctness: N/A (No reference answer provided)")
    # Display metrics
    st.markdown(format_evaluation_results(ragas_result))
    return response, filtered_chunks, ragas_result
def generate_answer_without_retrieval(query, model_name="llama2"):
    """Generate an answer using only the LLM (no retrieval).

    Synchronous wrapper around the ``generate_groq_response`` coroutine.

    Args:
        query: The question to send to the model.
        model_name: Model identifier forwarded to ``generate_groq_response``.
            NOTE(review): the default "llama2" is not one of the names listed
            in SUPPORTED_MODELS — confirm it is a valid Groq model id.

    Returns:
        Whatever ``generate_groq_response`` returns for the query.
    """
    def _run_in_fresh_loop():
        return asyncio.run(generate_groq_response(query, model_name=model_name))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running on this thread: asyncio.run() is safe.
        return _run_in_fresh_loop()

    # asyncio.run() refuses to start while a loop is already running on the
    # current thread, so give the coroutine its own loop on a worker thread
    # and wait for the result.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(_run_in_fresh_loop).result()
def evaluate_generation_only(query, reference_answer=None):
    """Score an answer produced by the LLM alone, without retrieved context.

    Args:
        query: The user's question.
        reference_answer: Optional ground-truth answer; correctness is only
            computed when it is truthy.

    Returns:
        A dict with ``answer_relevancy``, ``faithfulness`` (always 1.0, since
        there is no context to compare against) and ``answer_correctness``
        (None when no reference answer was given).
    """
    generated = generate_answer_without_retrieval(query)
    scorer = LocalMetricsEvaluator()
    relevancy = scorer.evaluate_answer_relevancy(query, generated)
    if reference_answer:
        correctness = scorer.evaluate_answer_correctness(generated, reference_answer)
    else:
        correctness = None
    return {
        "answer_relevancy": relevancy,
        "faithfulness": 1.0,  # No context to compare, assume perfect faithfulness
        "answer_correctness": correctness
    }
class LocalMetricsEvaluator:
    """Embedding-based stand-ins for RAG evaluation metrics.

    Every metric is the cosine similarity between the embeddings of two
    texts; the methods differ only in which pair of texts is compared.
    """

    # Loaded models keyed by model name, shared by all instances so that
    # creating an evaluator repeatedly does not reload the same model.
    _model_cache = {}

    def __init__(self, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load (or reuse) the SentenceTransformer model named ``embedding_model``."""
        cache = LocalMetricsEvaluator._model_cache
        if embedding_model not in cache:
            cache[embedding_model] = SentenceTransformer(embedding_model)
        self.embeddings = cache[embedding_model]

    def calculate_semantic_similarity(self, text1: str, text2: str) -> float:
        """Return the cosine similarity of the two texts' embeddings.

        Returns 0.0 when either embedding has zero norm, instead of dividing
        by zero and producing NaN.
        """
        emb1 = self.embeddings.encode(text1)
        emb2 = self.embeddings.encode(text2)
        denominator = np.linalg.norm(emb1) * np.linalg.norm(emb2)
        if denominator == 0:
            return 0.0
        return float(np.dot(emb1, emb2) / denominator)

    def evaluate_answer_relevancy(self, question: str, answer: str) -> float:
        """Similarity between the question and the answer."""
        return self.calculate_semantic_similarity(question, answer)

    def evaluate_faithfulness(self, answer: str, context: str) -> float:
        """Similarity between the answer and the context it was generated from."""
        return self.calculate_semantic_similarity(answer, context)

    def evaluate_answer_correctness(self, answer: str, reference: str) -> float:
        """Similarity between the answer and the reference answer."""
        return self.calculate_semantic_similarity(answer, reference)

    def evaluate_context_precision(self, question: str, context: str) -> float:
        """Similarity between the question and the retrieved context."""
        return self.calculate_semantic_similarity(question, context)

    def evaluate_context_recall(self, question: str, retrieved_context: str, reference: str) -> float:
        """Similarity between the retrieved context and the reference text.

        Args:
            question (str): The user's question. Accepted for interface
                symmetry; it does not influence the score.
            retrieved_context (str): The context retrieved by the RAG system.
            reference (str): The ground-truth text to compare against.

        Returns:
            float: Cosine similarity of the two embeddings.
        """
        return self.calculate_semantic_similarity(retrieved_context, reference)
def evaluate_ragas(query: str, response: str, context: str, reference_answer: str = None):
    """Score a RAG answer with local embedding similarities.

    Args:
        query (str): The user's question.
        response (str): The generated response.
        context (str): The context used to generate the response.
        reference_answer (str, optional): Ground-truth answer. When None, the
            response itself is used as the reference.

    Returns:
        dict: ``answer_relevancy``, ``faithfulness``, ``answer_correctness``,
        ``context_precision`` and ``context_recall`` scores.
    """
    scorer = LocalMetricsEvaluator()
    # Fall back to the response when there is no ground truth to compare with.
    ground_truth = response if reference_answer is None else reference_answer
    return {
        "answer_relevancy": scorer.evaluate_answer_relevancy(query, response),
        "faithfulness": scorer.evaluate_faithfulness(response, context),
        "answer_correctness": scorer.evaluate_answer_correctness(response, ground_truth),
        "context_precision": scorer.evaluate_context_precision(query, context),
        "context_recall": scorer.evaluate_context_recall(query, context, ground_truth)
    }
def format_evaluation_results(metrics: Dict[str, float]) -> str:
    """Render the evaluation metrics as a bulleted, newline-separated string.

    Each score is shown with three decimal places, in a fixed order, under a
    single heading line.
    """
    labelled_keys = (
        ("Answer Relevancy", "answer_relevancy"),
        ("Faithfulness", "faithfulness"),
        ("Answer Correctness", "answer_correctness"),
        ("Context Precision", "context_precision"),
        ("Context Recall", "context_recall"),
    )
    lines = ["📊 RAGAS Evaluation Results:"]
    for label, key in labelled_keys:
        lines.append(f"• {label}: {metrics[key]:.3f}")
    return "\n".join(lines)
| # Model configuration | |
# Models offered in the UI. Each entry maps the identifier shown to the user
# to the configuration used for requests: "name" is the model id sent to the
# API, and all entries share the same system prompt and are text-only.
# The "mistral-saba-24b" entry previously carried the name
# "mixtral-8x7b-32768", so selecting it silently requested a different model;
# every entry now uses its own key as the model id.
_PLANT_EXPERT_SYSTEM_PROMPT = "You are a helpful plant pathology expert assistant."
SUPPORTED_MODELS = {
    model_id: {
        "name": model_id,
        "system_prompt": _PLANT_EXPERT_SYSTEM_PROMPT,
        "supports_vision": False
    }
    for model_id in (
        "deepseek-r1-distill-llama-70b",
        "mistral-saba-24b",
        "qwen/qwen3-32b",
        "gemma2-9b-it",
        "llama-3.1-8b-instant",
        "llama-3.3-70b-versatile",
        "llama3-70b-8192",
    )
}
| # Initialize session state for conversation history if it doesn't exist | |
# Make sure the per-session conversation store exists before anything reads it.
if 'conversation_history' not in st.session_state:
    st.session_state['conversation_history'] = {}
# Load the YOLOv8 detection weights shipped with the app.
YOLO_WEIGHTS_PATH = "./model/plantdoc_model_yolov8.pt"
yolo_model = YOLO(YOLO_WEIGHTS_PATH)
def preprocess_image(image, target_size=(224, 224)):
    """Convert an array image to a PIL image resized to ``target_size``.

    Args:
        image: Array accepted by ``Image.fromarray``.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        The resized PIL image.
    """
    return Image.fromarray(image).resize(target_size)
def text_to_speech(text, voice="af_heart", language="en"):
    """Convert text to WAV audio bytes using Kokoro TTS with local voice files.

    Args:
        text: Text to synthesize; it is split into segments on newlines.
        voice: Voice name. A file ``<voice>.pt`` must exist in ``VOICES_DIR``.
        language: Currently unused; the pipeline is always created with
            ``lang_code="a"``. NOTE(review): confirm whether this should
            select the pipeline language.

    Returns:
        The encoded WAV data (24000 Hz) as bytes, or None when the voice file
        is missing or synthesis fails; in both cases the error is shown with
        ``st.error``.
    """
    try:
        # Fail fast on a missing voice before paying for the pipeline load.
        voice_path = os.path.join(VOICES_DIR, f"{voice}.pt")
        print(f"Loading voice file: {voice_path}")  # Debugging step
        if not os.path.exists(voice_path):
            raise FileNotFoundError(f"Voice file '{voice_path}' not found. Please check your Kokoro directory.")
        # Initialize Kokoro pipeline
        pipeline = KPipeline(lang_code="a")  # 'a' for American English, 'b' for British
        # Generate speech
        generator = pipeline(text, voice=voice, speed=1, split_pattern=r"\n+")
        # Collect each segment as one float32 array and join them once,
        # instead of appending the audio sample by sample.
        segments = [np.asarray(audio, dtype=np.float32) for _, _, audio in generator]
        if segments:
            audio_array = np.concatenate(segments)
        else:
            audio_array = np.zeros(0, dtype=np.float32)
        # Encode in memory: no temporary file to leak on errors, and no file
        # reopened by name while its handle is still open.
        wav_buffer = BytesIO()
        sf.write(wav_buffer, audio_array, 24000, format="WAV")
        return wav_buffer.getvalue()
    except FileNotFoundError as e:
        st.error(f"Error: {str(e)}")
        return None
    except Exception as e:
        st.error(f"Error generating speech: {str(e)}")
        return None
async def generate_rag_response_general(query, conversation_history=None):
    """Answer a general crop-health question with retrieval-augmented generation.

    Retrieval uses only the question (no detected diseases) and asks for the
    top five chunks. The model is the one stored under ``selected_model`` in
    the Streamlit session, defaulting to ``llama-3.1-8b-instant``.

    Args:
        query: The user's question.
        conversation_history: Optional earlier exchanges forwarded to the
            model.

    Returns:
        A ``(response, relevant_chunks, ragas_result)`` tuple. In the metrics
        dict, ``answer_correctness`` and ``context_recall`` are fixed 0.8
        placeholders because no ground truth exists for general questions.
    """
    # Empty disease list: the search string is built from the question alone.
    retrieved = retrieve_relevant_documents(query, [], top_k=5)
    context = "\n".join(item["text"] for item in retrieved)
    # Create a more general prompt for consultation
    consultation_prompt = f"""As an expert plant pathologist and agricultural consultant, please provide a comprehensive answer to the following question about crop diseases and plant health.
Context from knowledge base:
{context}
Question: {query}
Please provide a detailed, practical response that includes:
1. Direct answer to the question
2. Relevant scientific background
3. Practical recommendations
4. Prevention strategies (if applicable)
5. When to seek professional help (if applicable)
Make your response accessible to farmers and agricultural practitioners while maintaining scientific accuracy."""
    # Resolve the UI selection to the model id sent to the API.
    model_key = st.session_state.get('selected_model', 'llama-3.1-8b-instant')
    model_id = SUPPORTED_MODELS[model_key]["name"]
    answer = await generate_groq_response(
        consultation_prompt,
        model_name=model_id,
        conversation_history=conversation_history
    )
    # Evaluate using local metrics (simplified for general consultation)
    scorer = LocalMetricsEvaluator()
    metrics = {
        "answer_relevancy": scorer.evaluate_answer_relevancy(query, answer),
        "faithfulness": scorer.evaluate_faithfulness(answer, context),
        "answer_correctness": 0.8,  # Placeholder: no ground truth for general questions
        "context_precision": scorer.evaluate_context_precision(query, context),
        "context_recall": 0.8  # Placeholder
    }
    return answer, retrieved, metrics
async def generate_groq_response(prompt, model_name="mixtral-8x7b-32768", conversation_history=None):
    """Send a chat request to Groq and return the reply text.

    Args:
        prompt: The current user message.
        model_name: Model id passed to the chat-completions call.
        conversation_history: Optional sequence of earlier exchanges. Only
            the first two items of each entry (question, response) are used;
            entries with fewer than two items are ignored.

    Returns:
        The content of the first completion choice, or a string starting
        with "Error connecting to Groq: " when anything raises.
    """
    try:
        messages = [
            {"role": "system", "content": "You are a helpful plant pathology expert assistant."}
        ]
        # Replay earlier exchanges as alternating user / assistant turns.
        for entry in conversation_history or []:
            if len(entry) < 2:
                continue
            question, response = entry[:2]  # Entries may carry a third value
            messages.append({"role": "user", "content": question})
            messages.append({"role": "assistant", "content": response})
        # Add the current prompt
        messages.append({"role": "user", "content": prompt})
        completion = groq_client.chat.completions.create(
            model=model_name,
            messages=messages
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error connecting to Groq: {str(e)}"
def generate_improved_description(detected_classes, class_names, user_text, image_details=None, conversation_history=None):
    """
    Build a plant-pathology prompt for the detected diseases and return the
    Groq model's reply.

    Args:
        detected_classes: Iterable of class ids; each is looked up in
            ``class_names``.
        class_names: Mapping from class id to disease name.
        user_text: The user's question. For a first question an empty value is
            replaced by "Provide a general analysis".
        image_details: Accepted for interface compatibility; not used here.
        conversation_history: Earlier turns. When empty or None the structured
            first-analysis prompt is used, otherwise the follow-up prompt.

    Returns:
        Whatever ``generate_groq_response`` returns for the built prompt.
    """
    detected_objects = [class_names[cls] for cls in detected_classes]
    disease_list = ', '.join(detected_objects)

    # Create base context about detected diseases
    disease_context = f"Detected diseases: {disease_list}"

    # Different prompt structure for initial vs. follow-up questions
    if not conversation_history:
        base_prompt = f"""As an expert plant pathologist, analyze the following crop diseases detected in the provided image: {disease_list}.
For each detected disease, provide a structured analysis following this format:
1. Disease Name: [Name]
- Pathogen: [Causative organism]
- Severity Level: [Based on visual symptoms]
- Key Symptoms:
* [Symptom 1]
* [Symptom 2]
- Economic Impact:
* [Brief description of potential crop losses]
- Treatment Options:
* Immediate actions: [Short-term solutions]
* Long-term management: [Preventive measures]
- Environmental Conditions:
* Favorable conditions for disease development
* Risk factors
2. Recommendations:
- Immediate Steps:
* [Action items for immediate control]
- Prevention Strategy:
* [Long-term prevention measures]
- Monitoring Protocol:
* [What to watch for]
Initial Question/Context: {user_text if user_text else "Provide a general analysis"}
"""
    else:
        base_prompt = f"""Context: {disease_context}
Previous conversation context has been provided above. Please address the following follow-up question while maintaining consistency with previous responses:
{user_text}
Provide a detailed response that builds upon the previous context and specifically addresses this question."""

    # The session stores a SUPPORTED_MODELS key; resolve it to the model name
    # sent to Groq, using the same fallback as the general-consultation path.
    # An unknown key is passed through unchanged.
    selected_key = st.session_state.get('selected_model', 'llama-3.1-8b-instant')
    model_name = SUPPORTED_MODELS.get(selected_key, {}).get("name", selected_key)

    return asyncio.run(generate_groq_response(
        base_prompt,
        model_name=model_name,
        conversation_history=conversation_history,
    ))
def inference(image, conf=0.4):
    """
    Run the YOLO model on ``image`` and collect its detections.

    Args:
        image: Array-like image with a ``shape`` attribute, passed straight to
            ``yolo_model``.
        conf: Confidence threshold forwarded to the model. Defaults to 0.4,
            the value that was previously fixed inside this function.

    Returns:
        A 5-tuple ``(annotated, class_ids, class_names, confidences, boxes)``:
        the plotted result image (an all-zero image of the input's shape when
        the model yields no result), the detected class ids, the model's
        id-to-name mapping, the per-detection confidences and the ``xyxy``
        boxes. If the model yields several results, only the last is kept.
    """
    results = yolo_model(image, conf=conf)

    # Defaults returned when the model produces no result object at all.
    infer = np.zeros(image.shape, dtype=np.uint8)
    classes = dict()
    names_infer = []
    confidence_scores = []
    bounding_boxes = []

    for r in results:
        infer = r.plot()
        classes = r.names
        names_infer = r.boxes.cls.tolist()
        confidence_scores = r.boxes.conf.tolist()
        bounding_boxes = r.boxes.xyxy.tolist()

    return infer, names_infer, classes, confidence_scores, bounding_boxes
# Streamlit application
# NOTE(review): the nesting below was rebuilt from a listing that had lost its
# indentation; confirm widget placement (sidebar vs. main area) against the app.


def _translate_text(translator, text, dest):
    """Translate ``text`` into language ``dest`` and return the translated string.

    ``translator.translate`` may hand back either a coroutine or a finished
    result object (presumably depending on the installed googletrans release;
    TODO confirm); both are handled. Exceptions propagate to the caller.
    """
    outcome = translator.translate(text, dest=dest)
    if asyncio.iscoroutine(outcome):
        outcome = asyncio.run(outcome)
    return outcome.text


st.sidebar.markdown("---")
st.sidebar.header("🔬 Research Tools")

# Page selection
page_selection = st.sidebar.radio(
    "Navigate to:",
    ["🏠 Main App", "🔬 Research Dashboard"],
    index=0
)

if page_selection == "🔬 Research Dashboard":
    render_research_page()
else:
    # Main application page
    st.title("Interactive Crop Disease Detection and Analysis🌾🌿🥬☘️")
    st.write(f"Welcome, {st.session_state['username']}!😊")

    # Logout button
    if st.button("Logout"):
        logout()
        st.rerun()

    # Add sidebar for configuration
    with st.sidebar:
        st.header("Settings")
        selected_model = st.selectbox(
            "Select LLM Model",
            list(SUPPORTED_MODELS.keys()),
            help="Choose the Ollama model to use for analysis"
        )
        # Store the selected model in session state
        st.session_state['selected_model'] = selected_model
        if SUPPORTED_MODELS[selected_model]["supports_vision"]:
            st.info("This model supports vision capabilities and can analyze images directly.")

        # NOTE(review): confidence_threshold and show_bbox are collected here
        # but nothing below reads them; inference(image) is called without a
        # threshold argument.
        confidence_threshold = st.slider("Detection Confidence Threshold", 0.0, 1.0, 0.4)
        show_confidence = st.checkbox("Show Confidence Scores", value=True)
        show_bbox = st.checkbox("Show Bounding Boxes", value=True)

        # TTS Settings
        # Path to your Kokoro repository
        KOKORO_DIR = "./Kokoro-82M"
        VOICES_DIR = os.path.join(KOKORO_DIR, "voices")

        # Automatically list available voices; a missing directory is treated
        # like an empty one so the fallback voice below still applies.
        if os.path.isdir(VOICES_DIR):
            available_voices = [f.replace(".pt", "") for f in os.listdir(VOICES_DIR) if f.endswith(".pt")]
        else:
            available_voices = []

        # Ensure there are voices available
        if not available_voices:
            available_voices = ["af_heart"]  # Default fallback voice

        # Streamlit voice selection dropdown
        selected_voice = st.sidebar.selectbox("Choose a Voice", available_voices, index=0)

        # Add option to clear conversation history
        if st.button("Clear All Conversations"):
            st.session_state.conversation_history = {}
            st.success("Conversation history cleared!")

    # Initialize translator
    translator = Translator()

    # Language selection
    language = st.selectbox(
        "Select Language",
        options=['en', 'es', 'fr', 'de', 'ak', 'gaa', 'ee'],  # Add more languages as needed
        format_func=lambda x: {
            'en': 'English',
            'es': 'Spanish',
            'fr': 'French',
            'de': 'German',
            'ak': 'Twi',
            'gaa': 'Ga',
            'ee': 'Ewe'
        }[x],
        help="Select your preferred language"
    )

    tab1, tab2 = st.tabs(["🖼️ Image Analysis", "💬 General Consultation"])

    with tab1:
        st.header("Image-Based Disease Detection")
        st.write("Upload images of your crops to detect diseases and get specific analysis.")

        # Main content - Image upload and analysis
        uploaded_files = st.file_uploader("Upload images for disease detection", type=["jpg", "jpeg", "png"], accept_multiple_files=True)

        if uploaded_files:
            for uploaded_file in uploaded_files:
                file_id = uploaded_file.name

                # Initialize conversation history for this image if it doesn't exist
                if file_id not in st.session_state.conversation_history:
                    st.session_state.conversation_history[file_id] = []

                st.subheader(f"Analysis for {file_id}")

                # Create columns for side-by-side display
                col1, col2 = st.columns(2)

                # Process image
                file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
                image = cv2.imdecode(file_bytes, 1)
                if image is None:
                    # imdecode returns None for data it cannot decode.
                    st.error(f"Could not read {file_id} as an image.")
                    continue
                # NOTE(review): the array handed to inference() below is RGB;
                # Ultralytics documents BGR for NumPy input. Confirm.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                # Display original image
                with col1:
                    st.subheader("Original Image")
                    st.image(image)

                # Process and display results
                with st.spinner("Processing image..."):
                    infer_image, classes_in_image, classes_in_dataset, confidences, boxes = inference(image)

                    with col2:
                        st.subheader("Detected Diseases")
                        st.image(infer_image)

                        # Display detection details
                        if show_confidence:
                            st.subheader("Detection Details")
                            for cls, conf in zip(classes_in_image, confidences):
                                st.write(f"- {classes_in_dataset[cls]}: {conf:.2%} confidence")

                # Display conversation history for this image
                if st.session_state.conversation_history[file_id]:
                    st.subheader("Conversation History")
                    for i, entry in enumerate(st.session_state.conversation_history[file_id]):
                        question, response = entry[:2]
                        with st.expander(f"Q{i+1}: {question[:50]}...", expanded=False):
                            st.write("**Question:**", question)
                            st.write("**Response:**", response)

                            # Display feedback buttons and handle comment collection
                            display_feedback_buttons(file_id, i, question, response)

                            # Audio playback option
                            if st.button("🔊 Listen", key=f"listen_history_{file_id}_{i}"):
                                with st.spinner("Generating audio..."):
                                    audio_bytes = text_to_speech(response, voice=selected_voice)
                                    if audio_bytes:
                                        st.audio(audio_bytes, format="audio/wav")

                # User input for questions about the detected diseases
                st.subheader("Ask Questions About Detected Diseases")
                user_text = st.text_area(
                    "Enter your question about the detected diseases:",
                    placeholder="Example: What are the best treatment options for these diseases? What preventive measures should I take?",
                    key=f"question_{file_id}"
                )

                if st.button("Get Analysis", key=f"analyze_{file_id}"):
                    with st.spinner(f"Generating analysis using {selected_model}..."):
                        # Translate user input; fall back to the raw text on failure
                        try:
                            translated_input = _translate_text(translator, user_text, 'en')
                        except Exception as e:
                            st.error(f"Translation failed: {e}")
                            translated_input = user_text
                        st.write(f"Translated Input (to English): {translated_input}")

                        # Extract detected disease names
                        detected_classes = [classes_in_dataset[cls] for cls in classes_in_image]

                        # Fetch reference answers for detected diseases
                        reference_answers = []
                        for disease_name in detected_classes:
                            reference_answer = get_reference_answer(disease_name)
                            if reference_answer:
                                reference_answers.append(reference_answer)

                        # Combine reference answers into a single string
                        reference_answer = "\n".join(reference_answers) if reference_answers else None

                        # Generate response with RAG
                        response, relevant_chunks, ragas_result = asyncio.run(generate_rag_response(
                            translated_input,
                            st.session_state.conversation_history[file_id],
                            reference_answer  # Pass the reference answer for evaluation
                        ))
                        print("Response:", response)

                        if response is None:
                            st.error("Failed to generate a response. Please try again.")
                            response = "No response generated."

                        # Translate the answer into the selected language
                        if response:
                            try:
                                translated_response = _translate_text(translator, response, language)
                            except Exception as e:
                                st.error(f"Translation failed: {e}")
                                translated_response = response  # Fallback to the original response
                        else:
                            translated_response = response

                        st.session_state.conversation_history[file_id].append((user_text, translated_response, None))

                        # Show the same text that was stored in the history
                        st.markdown(translated_response)

                        # Add audio playback option for the latest response
                        col1, col2 = st.columns([1, 4])
                        with col1:
                            if st.button("🔊 Listen", key=f"listen_latest_{file_id}"):
                                with st.spinner("Generating audio..."):
                                    audio_bytes = text_to_speech(translated_response, voice=selected_voice)
                                    if audio_bytes:
                                        st.audio(audio_bytes, format="audio/wav")

    with tab2:
        st.header("General Disease Consultation")
        st.write("Ask questions about crop diseases without uploading images. Get expert advice on plant pathology topics.")

        # Initialize general consultation history
        if 'general_consultation' not in st.session_state.conversation_history:
            st.session_state.conversation_history['general_consultation'] = []

        # Disease selection helper
        st.subheader("🎯 Quick Disease Lookup")
        col1, col2 = st.columns([2, 1])

        with col1:
            # Get list of diseases from database for quick selection;
            # any database error falls back to a short built-in list.
            try:
                conn = sqlite3.connect('./db/disease_knowledge_base.db')
                try:
                    c = conn.cursor()
                    c.execute("SELECT DISTINCT name FROM diseases ORDER BY name")
                    available_diseases = [row[0] for row in c.fetchall()]
                finally:
                    conn.close()
            except sqlite3.Error:
                available_diseases = ["Corn Leaf Blight", "Apple Scab", "Tomato Late Blight", "Wheat Rust"]

            selected_disease = st.selectbox(
                "Select a specific disease for quick information:",
                [""] + available_diseases,
                help="Choose a disease to get instant information about it"
            )

        with col2:
            if selected_disease and st.button("Get Disease Info", key="quick_disease_info"):
                with st.spinner("Retrieving disease information..."):
                    quick_query = f"Tell me about {selected_disease} - its causes, symptoms, and treatment options."

                    # Generate response using RAG
                    response, relevant_chunks, ragas_result = asyncio.run(generate_rag_response_general(
                        quick_query,
                        st.session_state.conversation_history['general_consultation']
                    ))

                    # Translate if needed
                    if language != 'en':
                        try:
                            translated_response = _translate_text(translator, response, language)
                        except Exception as e:
                            st.error(f"Translation failed: {e}")
                            translated_response = response
                    else:
                        translated_response = response

                    # Add to conversation history
                    st.session_state.conversation_history['general_consultation'].append((quick_query, translated_response))

                    st.markdown("### Disease Information")
                    st.markdown(translated_response)

                    # Audio option
                    if st.button("🔊 Listen to Response", key="listen_quick_disease"):
                        with st.spinner("Generating audio..."):
                            audio_bytes = text_to_speech(translated_response, voice=selected_voice)
                            if audio_bytes:
                                st.audio(audio_bytes, format="audio/wav")

        # General question input
        st.subheader("💡 Ask Any Question About Crop Diseases")

        # Provide example questions
        example_questions = [
            "What are the most common fungal diseases in tomatoes?",
            "How can I prevent wheat rust in my field?",
            "What's the difference between bacterial and viral plant diseases?",
            "Which organic treatments work best for aphid control?",
            "What are the early signs of nutrient deficiency in corn?",
            "How do weather conditions affect plant disease development?",
        ]

        with st.expander("💡 Example Questions", expanded=False):
            for i, example in enumerate(example_questions):
                if st.button(example, key=f"example_{i}"):
                    # Written before the text area below is created, so the
                    # widget picks the value up on this run.
                    st.session_state["general_question_input"] = example

        general_question = st.text_area(
            "Enter your question about crop diseases, plant pathology, or agricultural practices:",
            placeholder="Example: What are the most effective organic methods to control powdery mildew in grapes?",
            key="general_question_input",
            height=100
        )

        # Topic categories for better organization
        st.subheader("🏷️ Question Categories")
        col1, col2, col3 = st.columns(3)

        # Callbacks that pre-fill the question box with a canned question
        def set_treatment_question():
            st.session_state.general_question_input = "What are the most effective treatment options for fungal plant diseases?"

        def set_identification_question():
            st.session_state.general_question_input = "How can I identify different types of plant diseases based on symptoms?"

        def set_prevention_question():
            st.session_state.general_question_input = "What preventive measures can I take to protect my crops from diseases?"

        with col1:
            st.button("🦠 Disease Identification", key="cat_identification", on_click=set_identification_question)
        with col2:
            st.button("💊 Treatment Options", key="cat_treatment", on_click=set_treatment_question)
        with col3:
            st.button("🛡️ Prevention Methods", key="cat_prevention", on_click=set_prevention_question)

        if st.button("Get Expert Answer", key="general_analyze", type="primary"):
            if general_question.strip():
                with st.spinner(f"Consulting plant pathology expert using {selected_model}..."):
                    # Translate user input if needed
                    if language != 'en':
                        try:
                            translated_input = _translate_text(translator, general_question, 'en')
                            st.info(f"Translated to English: {translated_input}")
                        except Exception as e:
                            st.error(f"Translation failed: {e}")
                            translated_input = general_question
                    else:
                        translated_input = general_question

                    # Generate response using RAG for general consultation
                    response, relevant_chunks, ragas_result = asyncio.run(generate_rag_response_general(
                        translated_input,
                        st.session_state.conversation_history['general_consultation']
                    ))

                    if response:
                        # Translate response back to user's language
                        if language != 'en':
                            try:
                                translated_response = _translate_text(translator, response, language)
                            except Exception as e:
                                st.error(f"Translation failed: {e}")
                                translated_response = response
                        else:
                            translated_response = response

                        # Add to conversation history
                        st.session_state.conversation_history['general_consultation'].append((general_question, translated_response))

                        # Display response
                        st.markdown("### Expert Response")
                        st.markdown(translated_response)

                        # Show relevant sources if available
                        if relevant_chunks:
                            with st.expander("📚 Information Sources", expanded=False):
                                for i, chunk in enumerate(relevant_chunks[:3]):  # Show top 3 sources
                                    st.write(f"**Source {i+1}:** {chunk['text'][:200]}...")

                        # Audio playback option
                        col1, col2 = st.columns([1, 4])
                        with col1:
                            if st.button("🔊 Listen", key="listen_general_latest"):
                                with st.spinner("Generating audio..."):
                                    audio_bytes = text_to_speech(translated_response, voice=selected_voice)
                                    if audio_bytes:
                                        st.audio(audio_bytes, format="audio/wav")
                    else:
                        st.error("Failed to generate a response. Please try again.")
            else:
                st.warning("Please enter a question before submitting.")

        # Display general consultation history
        if st.session_state.conversation_history['general_consultation']:
            st.subheader("📝 Consultation History")
            for i, entry in enumerate(st.session_state.conversation_history['general_consultation']):
                question, response = entry[:2]
                with st.expander(f"Q{i+1}: {question[:60]}...", expanded=False):
                    st.write("**Question:**", question)
                    st.write("**Response:**", response)

                    # Feedback buttons for general consultation
                    display_feedback_buttons('general_consultation', i, question, response)

                    # Audio playback for history
                    if st.button("🔊 Listen", key=f"listen_general_history_{i}"):
                        with st.spinner("Generating audio..."):
                            audio_bytes = text_to_speech(response, voice=selected_voice)
                            if audio_bytes:
                                st.audio(audio_bytes, format="audio/wav")

        # Export general consultation
        if st.session_state.conversation_history['general_consultation']:
            if st.button("📄 Export Consultation", key="export_general"):
                report_parts = [f"""
# General Crop Disease Consultation Report
## Consultation Information
- Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
- Language: {language}
- Model Used: {selected_model}
## Consultation History
"""]
                for i, entry in enumerate(st.session_state.conversation_history['general_consultation']):
                    question, response = entry[:2]
                    report_parts.append(
                        f"\n### Question {i+1}:\n{question}\n\n### Expert Response {i+1}:\n{response}\n\n---\n"
                    )
                consultation_text = "".join(report_parts)

                st.download_button(
                    label="📥 Download Consultation Report",
                    data=consultation_text,
                    file_name=f"crop_disease_consultation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
                    mime="text/markdown",
                    key="download_general"
                )

    # Add a footer with clear instructions
    st.markdown("""
---
### How to Use
1. Upload one or more images of crops with potential diseases
2. View the detected diseases and their confidence scores
3. Ask questions about the diseases, treatments, or prevention
4. Use the 🔊 Listen button to hear the responses
5. View previous questions and answers in the conversation history
6. Export the entire conversation for future reference
7. Use the sidebar to adjust settings or clear conversation history
""")