Spaces:
Build error
Robert Yaw Agyekum Addo committed on
Commit · 0ee82c6
Parent(s): a219526
Modified app.py with ablation experiments
app.py
CHANGED
@@ -27,6 +27,11 @@ from sentence_transformers import SentenceTransformer
|
| 27 |
# context_precision
|
| 28 |
#)
|
| 29 |
from datasets import Dataset
|
| 30 |
+
import pandas as pd
|
| 31 |
+
import random
|
| 32 |
+
import plotly.express as px
|
| 33 |
+
import plotly.graph_objects as go
|
| 34 |
+
from plotly.subplots import make_subplots
|
| 35 |
|
| 36 |
from typing import List, Dict
|
| 37 |
import asyncio
|
@@ -76,6 +81,468 @@ class SentenceTransformerEmbeddings(BaseRagasEmbeddings):
| 81 |
loop = asyncio.get_event_loop()
|
| 82 |
return await loop.run_in_executor(None, self.embed_documents, texts)
|
| 83 |
|
| 84 |
+
class RAGSystemVariants:
|
| 85 |
+
def __init__(self):
|
| 86 |
+
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 87 |
+
|
| 88 |
+
async def baseline_rag(self, query, top_k=3):
|
| 89 |
+
"""Your current full RAG system"""
|
| 90 |
+
chunks = retrieve_relevant_documents(query, [], top_k)
|
| 91 |
+
context = "\n".join([chunk["text"] for chunk in chunks])
|
| 92 |
+
response = await generate_groq_response(f"Context: {context}\n\nQuestion: {query}")
|
| 93 |
+
return response, context
|
| 94 |
+
|
| 95 |
+
async def no_retrieval(self, query):
|
| 96 |
+
"""Generation only - no retrieval"""
|
| 97 |
+
response = await generate_groq_response(query)
|
| 98 |
+
return response, ""
|
| 99 |
+
|
| 100 |
+
async def random_retrieval(self, query, top_k=3):
|
| 101 |
+
"""Random document selection instead of semantic retrieval"""
|
| 102 |
+
try:
|
| 103 |
+
all_docs = client.scroll(collection_name=collection_name, limit=100)[0]
|
| 104 |
+
if len(all_docs) > 0:
|
| 105 |
+
random_chunks = random.sample(all_docs, min(top_k, len(all_docs)))
|
| 106 |
+
context = "\n".join([chunk.payload["text"] for chunk in random_chunks])
|
| 107 |
+
else:
|
| 108 |
+
context = ""
|
| 109 |
+
response = await generate_groq_response(f"Context: {context}\n\nQuestion: {query}")
|
| 110 |
+
return response, context
|
| 111 |
+
except Exception as e:
|
| 112 |
+
st.error(f"Error in random retrieval: {e}")
|
| 113 |
+
return "Error in random retrieval", ""
|
| 114 |
+
|
| 115 |
+
async def different_top_k(self, query, top_k):
|
| 116 |
+
"""Test different top-k values"""
|
| 117 |
+
chunks = retrieve_relevant_documents(query, [], top_k)
|
| 118 |
+
context = "\n".join([chunk["text"] for chunk in chunks])
|
| 119 |
+
response = await generate_groq_response(f"Context: {context}\n\nQuestion: {query}")
|
| 120 |
+
return response, context
|
| 121 |
+
|
| 122 |
+
def create_test_dataset(limit=20):
|
| 123 |
+
"""Create a test dataset for RAGAS evaluation"""
|
| 124 |
+
test_cases = []
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
conn = sqlite3.connect('./db/disease_knowledge_base.db')
|
| 128 |
+
c = conn.cursor()
|
| 129 |
+
c.execute("SELECT name, cause, symptoms, treatment FROM diseases LIMIT ?", (limit,))
|
| 130 |
+
diseases = c.fetchall()
|
| 131 |
+
conn.close()
|
| 132 |
+
|
| 133 |
+
for disease_name, cause, symptoms, treatment in diseases:
|
| 134 |
+
questions_and_answers = [
|
| 135 |
+
(f"What causes {disease_name}?", cause),
|
| 136 |
+
(f"What are the symptoms of {disease_name}?", symptoms),
|
| 137 |
+
(f"How do I treat {disease_name}?", treatment),
|
| 138 |
+
(f"Tell me about {disease_name}", f"Cause: {cause}\nSymptoms: {symptoms}\nTreatment: {treatment}"),
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
for question, ground_truth in questions_and_answers:
|
| 142 |
+
test_cases.append({
|
| 143 |
+
"question": question,
|
| 144 |
+
"ground_truth": ground_truth,
|
| 145 |
+
"disease": disease_name
|
| 146 |
+
})
|
| 147 |
+
|
| 148 |
+
return test_cases[:limit]
|
| 149 |
+
|
| 150 |
+
except Exception as e:
|
| 151 |
+
st.error(f"Error creating test dataset: {e}")
|
| 152 |
+
return []
|
| 153 |
+
|
| 154 |
+
async def run_ablation_study(progress_bar, status_text, max_questions=20):
|
| 155 |
+
"""Run comprehensive ablation study with progress tracking"""
|
| 156 |
+
|
| 157 |
+
status_text.text("Creating test dataset...")
|
| 158 |
+
test_cases = create_test_dataset(limit=max_questions)
|
| 159 |
+
|
| 160 |
+
if not test_cases:
|
| 161 |
+
st.error("No test cases created. Check your database connection.")
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
rag_variants = RAGSystemVariants()
|
| 165 |
+
|
| 166 |
+
experiments = {
|
| 167 |
+
"Full_RAG_k3": lambda q: rag_variants.baseline_rag(q, top_k=3),
|
| 168 |
+
"No_Retrieval": lambda q: rag_variants.no_retrieval(q),
|
| 169 |
+
"Random_Retrieval": lambda q: rag_variants.random_retrieval(q, top_k=3),
|
| 170 |
+
"RAG_k1": lambda q: rag_variants.different_top_k(q, top_k=1),
|
| 171 |
+
"RAG_k5": lambda q: rag_variants.different_top_k(q, top_k=5),
|
| 172 |
+
"RAG_k10": lambda q: rag_variants.different_top_k(q, top_k=10),
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
all_results = []
|
| 176 |
+
total_experiments = len(experiments) * len(test_cases)
|
| 177 |
+
current_progress = 0
|
| 178 |
+
|
| 179 |
+
for exp_name, exp_func in experiments.items():
|
| 180 |
+
status_text.text(f"Running experiment: {exp_name}")
|
| 181 |
+
|
| 182 |
+
questions = []
|
| 183 |
+
answers = []
|
| 184 |
+
contexts = []
|
| 185 |
+
ground_truths = []
|
| 186 |
+
|
| 187 |
+
for test_case in test_cases:
|
| 188 |
+
try:
|
| 189 |
+
answer, context = await exp_func(test_case["question"])
|
| 190 |
+
|
| 191 |
+
questions.append(test_case["question"])
|
| 192 |
+
answers.append(answer)
|
| 193 |
+
contexts.append([context] if context else [""])
|
| 194 |
+
ground_truths.append(test_case["ground_truth"])
|
| 195 |
+
|
| 196 |
+
current_progress += 1
|
| 197 |
+
progress_bar.progress(current_progress / total_experiments)
|
| 198 |
+
|
| 199 |
+
except Exception as e:
|
| 200 |
+
st.error(f"Error in {exp_name}: {e}")
|
| 201 |
+
continue
|
| 202 |
+
|
| 203 |
+
exp_results = []
|
| 204 |
+
evaluator = LocalMetricsEvaluator()
|
| 205 |
+
|
| 206 |
+
for q, a, c, gt in zip(questions, answers, contexts, ground_truths):
|
| 207 |
+
context_str = c[0] if c and c[0] else ""
|
| 208 |
+
metrics = {
|
| 209 |
+
"question": q,
|
| 210 |
+
"answer": a,
|
| 211 |
+
"context": context_str,
|
| 212 |
+
"ground_truth": gt,
|
| 213 |
+
"experiment": exp_name,
|
| 214 |
+
"answer_relevancy": evaluator.evaluate_answer_relevancy(q, a),
|
| 215 |
+
"faithfulness": evaluator.evaluate_faithfulness(a, context_str) if context_str else 1.0,
|
| 216 |
+
"answer_correctness": evaluator.evaluate_answer_correctness(a, gt),
|
| 217 |
+
"context_precision": evaluator.evaluate_context_precision(q, context_str) if context_str else 0.0,
|
| 218 |
+
"context_recall": evaluator.evaluate_context_recall(q, context_str, gt) if context_str else 0.0
|
| 219 |
+
}
|
| 220 |
+
exp_results.append(metrics)
|
| 221 |
+
|
| 222 |
+
all_results.extend(exp_results)
|
| 223 |
+
|
| 224 |
+
return pd.DataFrame(all_results)
|
| 225 |
+
|
| 226 |
+
def visualize_ablation_results(results_df):
|
| 227 |
+
"""Create interactive visualizations for ablation study results"""
|
| 228 |
+
|
| 229 |
+
summary_stats = results_df.groupby('experiment').agg({
|
| 230 |
+
'answer_relevancy': ['mean', 'std'],
|
| 231 |
+
'faithfulness': ['mean', 'std'],
|
| 232 |
+
'answer_correctness': ['mean', 'std'],
|
| 233 |
+
'context_precision': ['mean', 'std'],
|
| 234 |
+
'context_recall': ['mean', 'std']
|
| 235 |
+
}).round(3)
|
| 236 |
+
|
| 237 |
+
summary_stats.columns = ['_'.join(col).strip() for col in summary_stats.columns.values]
|
| 238 |
+
summary_stats = summary_stats.reset_index()
|
| 239 |
+
|
| 240 |
+
metrics = ['answer_relevancy_mean', 'faithfulness_mean', 'answer_correctness_mean',
|
| 241 |
+
'context_precision_mean', 'context_recall_mean']
|
| 242 |
+
|
| 243 |
+
# Radar chart
|
| 244 |
+
fig_radar = go.Figure()
|
| 245 |
+
|
| 246 |
+
for _, row in summary_stats.iterrows():
|
| 247 |
+
fig_radar.add_trace(go.Scatterpolar(
|
| 248 |
+
r=[row[metric] for metric in metrics],
|
| 249 |
+
theta=[metric.replace('_mean', '').replace('_', ' ').title() for metric in metrics],
|
| 250 |
+
fill='toself',
|
| 251 |
+
name=row['experiment']
|
| 252 |
+
))
|
| 253 |
+
|
| 254 |
+
fig_radar.update_layout(
|
| 255 |
+
polar=dict(
|
| 256 |
+
radialaxis=dict(
|
| 257 |
+
visible=True,
|
| 258 |
+
range=[0, 1]
|
| 259 |
+
)),
|
| 260 |
+
showlegend=True,
|
| 261 |
+
title="RAGAS Metrics Comparison Across Experiments"
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Bar chart comparison
|
| 265 |
+
fig_bar = make_subplots(
|
| 266 |
+
rows=2, cols=3,
|
| 267 |
+
subplot_titles=[metric.replace('_mean', '').replace('_', ' ').title() for metric in metrics],
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
for i, metric in enumerate(metrics):
|
| 271 |
+
row = (i // 3) + 1
|
| 272 |
+
col = (i % 3) + 1
|
| 273 |
+
|
| 274 |
+
fig_bar.add_trace(
|
| 275 |
+
go.Bar(
|
| 276 |
+
x=summary_stats['experiment'],
|
| 277 |
+
y=summary_stats[metric],
|
| 278 |
+
error_y=dict(type='data', array=summary_stats[metric.replace('mean', 'std')]),
|
| 279 |
+
name=metric.replace('_mean', '').replace('_', ' ').title(),
|
| 280 |
+
showlegend=False
|
| 281 |
+
),
|
| 282 |
+
row=row, col=col
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
fig_bar.update_layout(height=800, title="Detailed Metrics Comparison")
|
| 286 |
+
|
| 287 |
+
return fig_radar, fig_bar, summary_stats
|
| 288 |
+
|
| 289 |
+
def render_research_page():
|
| 290 |
+
"""Render the research/ablation study page"""
|
| 291 |
+
|
| 292 |
+
st.title("🔬 RAG System Research Dashboard")
|
| 293 |
+
st.markdown("Systematic evaluation and ablation study of the crop disease detection RAG system")
|
| 294 |
+
|
| 295 |
+
# Initialize session state for results
|
| 296 |
+
if 'ablation_results' not in st.session_state:
|
| 297 |
+
st.session_state['ablation_results'] = None
|
| 298 |
+
|
| 299 |
+
tabs = st.tabs(["Ablation Study", "Model Comparison", "Error Analysis", "Export Results"])
|
| 300 |
+
|
| 301 |
+
with tabs[0]:
|
| 302 |
+
st.header("🧪 Ablation Study")
|
| 303 |
+
st.write("This systematically evaluates different components of the RAG system.")
|
| 304 |
+
|
| 305 |
+
col1, col2 = st.columns(2)
|
| 306 |
+
with col1:
|
| 307 |
+
max_questions = st.number_input("Number of test questions per experiment",
|
| 308 |
+
min_value=5, max_value=50, value=20)
|
| 309 |
+
with col2:
|
| 310 |
+
selected_model_research = st.selectbox(
|
| 311 |
+
"Select Model for Experiments",
|
| 312 |
+
list(SUPPORTED_MODELS.keys()),
|
| 313 |
+
key="research_model_select"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
if st.button("🚀 Start Ablation Study", type="primary"):
|
| 317 |
+
progress_bar = st.progress(0)
|
| 318 |
+
status_text = st.empty()
|
| 319 |
+
|
| 320 |
+
with st.spinner("Running ablation study..."):
|
| 321 |
+
try:
|
| 322 |
+
results_df = asyncio.run(run_ablation_study(progress_bar, status_text, max_questions))
|
| 323 |
+
|
| 324 |
+
if results_df is not None:
|
| 325 |
+
st.session_state['ablation_results'] = results_df
|
| 326 |
+
|
| 327 |
+
st.success("✅ Ablation study completed!")
|
| 328 |
+
|
| 329 |
+
# Show summary statistics
|
| 330 |
+
st.subheader("📊 Summary Statistics")
|
| 331 |
+
summary_stats = results_df.groupby('experiment').agg({
|
| 332 |
+
'answer_relevancy': 'mean',
|
| 333 |
+
'faithfulness': 'mean',
|
| 334 |
+
'answer_correctness': 'mean',
|
| 335 |
+
'context_precision': 'mean',
|
| 336 |
+
'context_recall': 'mean'
|
| 337 |
+
}).round(3)
|
| 338 |
+
st.dataframe(summary_stats, use_container_width=True)
|
| 339 |
+
|
| 340 |
+
# Key insights
|
| 341 |
+
best_overall = summary_stats.mean(axis=1).idxmax()
|
| 342 |
+
st.success(f"🏆 **Best Overall Configuration:** {best_overall}")
|
| 343 |
+
|
| 344 |
+
col1, col2, col3 = st.columns(3)
|
| 345 |
+
with col1:
|
| 346 |
+
best_relevancy = summary_stats['answer_relevancy'].idxmax()
|
| 347 |
+
st.metric("Best Answer Relevancy", best_relevancy,
|
| 348 |
+
f"{summary_stats.loc[best_relevancy, 'answer_relevancy']:.3f}")
|
| 349 |
+
with col2:
|
| 350 |
+
best_faithfulness = summary_stats['faithfulness'].idxmax()
|
| 351 |
+
st.metric("Best Faithfulness", best_faithfulness,
|
| 352 |
+
f"{summary_stats.loc[best_faithfulness, 'faithfulness']:.3f}")
|
| 353 |
+
with col3:
|
| 354 |
+
best_correctness = summary_stats['answer_correctness'].idxmax()
|
| 355 |
+
st.metric("Best Correctness", best_correctness,
|
| 356 |
+
f"{summary_stats.loc[best_correctness, 'answer_correctness']:.3f}")
|
| 357 |
+
|
| 358 |
+
# Create and display visualizations
|
| 359 |
+
fig_radar, fig_bar, summary_stats_detailed = visualize_ablation_results(results_df)
|
| 360 |
+
|
| 361 |
+
st.subheader("📈 Results Visualization")
|
| 362 |
+
viz_tab1, viz_tab2, viz_tab3 = st.tabs(["Radar Chart", "Detailed Comparison", "Raw Data"])
|
| 363 |
+
|
| 364 |
+
with viz_tab1:
|
| 365 |
+
st.plotly_chart(fig_radar, use_container_width=True)
|
| 366 |
+
st.markdown("**Interpretation:** The radar chart shows the relative performance of each experiment across all RAGAS metrics. Larger areas indicate better overall performance.")
|
| 367 |
+
|
| 368 |
+
with viz_tab2:
|
| 369 |
+
st.plotly_chart(fig_bar, use_container_width=True)
|
| 370 |
+
st.markdown("**Interpretation:** The bar charts show detailed performance with error bars indicating standard deviation across test cases.")
|
| 371 |
+
|
| 372 |
+
with viz_tab3:
|
| 373 |
+
st.dataframe(results_df, use_container_width=True)
|
| 374 |
+
|
| 375 |
+
# Download options
|
| 376 |
+
csv = results_df.to_csv(index=False)
|
| 377 |
+
st.download_button(
|
| 378 |
+
label="📥 Download Raw Results (CSV)",
|
| 379 |
+
data=csv,
|
| 380 |
+
file_name=f"ablation_study_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
| 381 |
+
mime="text/csv"
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
except Exception as e:
|
| 385 |
+
st.error(f"❌ Error running ablation study: {str(e)}")
|
| 386 |
+
st.exception(e)
|
| 387 |
+
|
| 388 |
+
with tabs[1]:
|
| 389 |
+
st.header("Model Comparison")
|
| 390 |
+
st.write("Compare different LLM models on the same test dataset.")
|
| 391 |
+
|
| 392 |
+
selected_models = st.multiselect(
|
| 393 |
+
"Select models to compare",
|
| 394 |
+
list(SUPPORTED_MODELS.keys()),
|
| 395 |
+
default=list(SUPPORTED_MODELS.keys())[:2]
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
num_questions_comp = st.number_input("Number of questions for comparison",
|
| 399 |
+
min_value=5, max_value=30, value=10)
|
| 400 |
+
|
| 401 |
+
if selected_models and st.button("🔍 Run Model Comparison"):
|
| 402 |
+
st.info("Model comparison functionality can be extended here...")
|
| 403 |
+
progress_bar_comp = st.progress(0)
|
| 404 |
+
status_text_comp = st.empty()
|
| 405 |
+
|
| 406 |
+
with st.spinner("Comparing models..."):
|
| 407 |
+
# Create a simplified comparison focusing on generation quality
|
| 408 |
+
test_cases = create_test_dataset(limit=num_questions_comp)
|
| 409 |
+
|
| 410 |
+
comparison_results = []
|
| 411 |
+
total_comparisons = len(selected_models) * len(test_cases)
|
| 412 |
+
current_progress_comp = 0
|
| 413 |
+
|
| 414 |
+
for model_name in selected_models:
|
| 415 |
+
status_text_comp.text(f"Testing model: {model_name}")
|
| 416 |
+
|
| 417 |
+
for test_case in test_cases:
|
| 418 |
+
try:
|
| 419 |
+
# Generate response with current model
|
| 420 |
+
response = asyncio.run(generate_groq_response(
|
| 421 |
+
test_case["question"],
|
| 422 |
+
model_name=SUPPORTED_MODELS[model_name]["name"]
|
| 423 |
+
))
|
| 424 |
+
|
| 425 |
+
# Evaluate
|
| 426 |
+
evaluator = LocalMetricsEvaluator()
|
| 427 |
+
comparison_results.append({
|
| 428 |
+
"model": model_name,
|
| 429 |
+
"question": test_case["question"],
|
| 430 |
+
"answer": response,
|
| 431 |
+
"ground_truth": test_case["ground_truth"],
|
| 432 |
+
"disease": test_case["disease"],
|
| 433 |
+
"answer_relevancy": evaluator.evaluate_answer_relevancy(test_case["question"], response),
|
| 434 |
+
"answer_correctness": evaluator.evaluate_answer_correctness(response, test_case["ground_truth"])
|
| 435 |
+
})
|
| 436 |
+
|
| 437 |
+
current_progress_comp += 1
|
| 438 |
+
progress_bar_comp.progress(current_progress_comp / total_comparisons)
|
| 439 |
+
|
| 440 |
+
except Exception as e:
|
| 441 |
+
st.error(f"Error testing {model_name}: {e}")
|
| 442 |
+
continue
|
| 443 |
+
|
| 444 |
+
if comparison_results:
|
| 445 |
+
comp_df = pd.DataFrame(comparison_results)
|
| 446 |
+
|
| 447 |
+
# Summary by model
|
| 448 |
+
model_summary = comp_df.groupby('model').agg({
|
| 449 |
+
'answer_relevancy': 'mean',
|
| 450 |
+
'answer_correctness': 'mean'
|
| 451 |
+
}).round(3)
|
| 452 |
+
|
| 453 |
+
st.subheader("📊 Model Performance Summary")
|
| 454 |
+
st.dataframe(model_summary, use_container_width=True)
|
| 455 |
+
|
| 456 |
+
# Visualization
|
| 457 |
+
fig_model_comp = px.bar(
|
| 458 |
+
model_summary.reset_index(),
|
| 459 |
+
x='model',
|
| 460 |
+
y=['answer_relevancy', 'answer_correctness'],
|
| 461 |
+
title="Model Performance Comparison",
|
| 462 |
+
barmode='group'
|
| 463 |
+
)
|
| 464 |
+
st.plotly_chart(fig_model_comp, use_container_width=True)
|
| 465 |
+
|
| 466 |
+
# Store results
|
| 467 |
+
st.session_state['model_comparison_results'] = comp_df
|
| 468 |
+
|
| 469 |
+
with tabs[2]:
|
| 470 |
+
st.header("Error Analysis")
|
| 471 |
+
st.write("Analyze failure cases and performance patterns.")
|
| 472 |
+
|
| 473 |
+
if st.session_state['ablation_results'] is not None:
|
| 474 |
+
results_df = st.session_state['ablation_results']
|
| 475 |
+
|
| 476 |
+
# Find worst performing cases
|
| 477 |
+
st.subheader("Worst Performing Cases")
|
| 478 |
+
worst_cases = results_df.nsmallest(10, 'answer_correctness')[['question', 'answer', 'ground_truth', 'experiment', 'answer_correctness']]
|
| 479 |
+
st.dataframe(worst_cases, use_container_width=True)
|
| 480 |
+
|
| 481 |
+
# Performance by experiment
|
| 482 |
+
st.subheader("Performance Distribution")
|
| 483 |
+
fig_box = px.box(results_df, x='experiment', y='answer_correctness',
|
| 484 |
+
title="Answer Correctness Distribution by Experiment")
|
| 485 |
+
st.plotly_chart(fig_box, use_container_width=True)
|
| 486 |
+
else:
|
| 487 |
+
st.info("Run an ablation study first to see error analysis.")
|
| 488 |
+
|
| 489 |
+
with tabs[3]:
|
| 490 |
+
st.header("Export Results")
|
| 491 |
+
st.write("Export results for research papers and further analysis.")
|
| 492 |
+
|
| 493 |
+
if st.session_state['ablation_results'] is not None:
|
| 494 |
+
results_df = st.session_state['ablation_results']
|
| 495 |
+
|
| 496 |
+
# Generate summary report
|
| 497 |
+
report = f"""
|
| 498 |
+
# RAG System Ablation Study Report
|
| 499 |
+
|
| 500 |
+
**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
| 501 |
+
**Total Experiments:** {len(results_df['experiment'].unique())}
|
| 502 |
+
**Total Test Cases:** {len(results_df)}
|
| 503 |
+
|
| 504 |
+
## Summary Statistics
|
| 505 |
+
|
| 506 |
+
{results_df.groupby('experiment').agg({
|
| 507 |
+
'answer_relevancy': ['mean', 'std'],
|
| 508 |
+
'faithfulness': ['mean', 'std'],
|
| 509 |
+
'answer_correctness': ['mean', 'std'],
|
| 510 |
+
'context_precision': ['mean', 'std'],
|
| 511 |
+
'context_recall': ['mean', 'std']
|
| 512 |
+
}).round(3).to_string()}
|
| 513 |
+
|
| 514 |
+
## Best Performing Configurations
|
| 515 |
+
|
| 516 |
+
- **Best Answer Relevancy:** {results_df.groupby('experiment')['answer_relevancy'].mean().idxmax()}
|
| 517 |
+
- **Best Faithfulness:** {results_df.groupby('experiment')['faithfulness'].mean().idxmax()}
|
| 518 |
+
- **Best Answer Correctness:** {results_df.groupby('experiment')['answer_correctness'].mean().idxmax()}
|
| 519 |
+
|
| 520 |
+
## Recommendations
|
| 521 |
+
|
| 522 |
+
Based on the ablation study results, we recommend...
|
| 523 |
+
[Add your analysis here]
|
| 524 |
+
"""
|
| 525 |
+
|
| 526 |
+
col1, col2 = st.columns(2)
|
| 527 |
+
with col1:
|
| 528 |
+
st.download_button(
|
| 529 |
+
label="📄 Download Report (Markdown)",
|
| 530 |
+
data=report,
|
| 531 |
+
file_name=f"ablation_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
|
| 532 |
+
mime="text/markdown"
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
with col2:
|
| 536 |
+
csv_data = results_df.to_csv(index=False)
|
| 537 |
+
st.download_button(
|
| 538 |
+
label="📊 Download Data (CSV)",
|
| 539 |
+
data=csv_data,
|
| 540 |
+
file_name=f"ablation_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
| 541 |
+
mime="text/csv"
|
| 542 |
+
)
|
| 543 |
+
else:
|
| 544 |
+
st.info("No results available for export. Run an ablation study first.")
|
| 545 |
+
|
| 546 |
# Database setup
|
| 547 |
conn = sqlite3.connect('users.db')
|
| 548 |
c = conn.cursor()
|
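The ablation loop above scores every (question, answer, context) triple with a LocalMetricsEvaluator that is defined elsewhere in app.py and does not appear in this diff. The sketch below shows only the interface those calls assume; the method names and argument order are taken from the hunk above, while the token-overlap scoring is an illustrative placeholder, not the evaluator the app actually uses.

# Hypothetical sketch of the evaluator interface assumed by the ablation code above.
class LocalMetricsEvaluator:
    @staticmethod
    def _overlap(a: str, b: str) -> float:
        # Jaccard overlap of lowercased tokens, clamped to [0, 1]
        ta, tb = set(a.lower().split()), set(b.lower().split())
        return len(ta & tb) / len(ta | tb) if ta | tb else 0.0

    def evaluate_answer_relevancy(self, question: str, answer: str) -> float:
        return self._overlap(question, answer)

    def evaluate_faithfulness(self, answer: str, context: str) -> float:
        return self._overlap(answer, context)

    def evaluate_answer_correctness(self, answer: str, ground_truth: str) -> float:
        return self._overlap(answer, ground_truth)

    def evaluate_context_precision(self, question: str, context: str) -> float:
        return self._overlap(question, context)

    def evaluate_context_recall(self, question: str, context: str, ground_truth: str) -> float:
        return self._overlap(context, ground_truth)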
@@ -1270,7 +1737,7 @@ if 'conversation_history' not in st.session_state:
| 1737 |
st.session_state.conversation_history = {}
|
| 1738 |
|
| 1739 |
# Load YOLOv8 model
|
| 1740 |
+
yolo_model = YOLO("/workspaces/codespaces-blank/Areo-AI/model/plantdoc_model_yolov8.pt")
|
| 1741 |
|
| 1742 |
def preprocess_image(image, target_size=(224, 224)):
|
| 1743 |
"""
|
@@ -1319,6 +1786,53 @@ def text_to_speech(text, voice="af_heart", language="en"):
| 1786 |
except Exception as e:
|
| 1787 |
st.error(f"Error generating speech: {str(e)}")
|
| 1788 |
return None
|
| 1789 |
+
|
| 1790 |
+
async def generate_rag_response_general(query, conversation_history=None):
|
| 1791 |
+
"""
|
| 1792 |
+
Generate a response using RAG for general questions (no specific detected diseases)
|
| 1793 |
+
"""
|
| 1794 |
+
# Retrieve relevant chunks based on the query
|
| 1795 |
+
relevant_chunks = retrieve_relevant_documents(query, [], top_k=5) # Empty disease list for general queries
|
| 1796 |
+
|
| 1797 |
+
# Build context from retrieved chunks
|
| 1798 |
+
context = "\n".join([chunk["text"] for chunk in relevant_chunks])
|
| 1799 |
+
|
| 1800 |
+
# Create a more general prompt for consultation
|
| 1801 |
+
consultation_prompt = f"""As an expert plant pathologist and agricultural consultant, please provide a comprehensive answer to the following question about crop diseases and plant health.
|
| 1802 |
+
|
| 1803 |
+
Context from knowledge base:
|
| 1804 |
+
{context}
|
| 1805 |
+
|
| 1806 |
+
Question: {query}
|
| 1807 |
+
|
| 1808 |
+
Please provide a detailed, practical response that includes:
|
| 1809 |
+
1. Direct answer to the question
|
| 1810 |
+
2. Relevant scientific background
|
| 1811 |
+
3. Practical recommendations
|
| 1812 |
+
4. Prevention strategies (if applicable)
|
| 1813 |
+
5. When to seek professional help (if applicable)
|
| 1814 |
+
|
| 1815 |
+
Make your response accessible to farmers and agricultural practitioners while maintaining scientific accuracy."""
|
| 1816 |
+
|
| 1817 |
+
# Generate response
|
| 1818 |
+
selected_model_name = SUPPORTED_MODELS[st.session_state.get('selected_model', 'llama-3.1-8b-instant')]["name"]
|
| 1819 |
+
response = await generate_groq_response(
|
| 1820 |
+
consultation_prompt,
|
| 1821 |
+
model_name=selected_model_name,
|
| 1822 |
+
conversation_history=conversation_history
|
| 1823 |
+
)
|
| 1824 |
+
|
| 1825 |
+
# Evaluate using local metrics (simplified for general consultation)
|
| 1826 |
+
evaluator = LocalMetricsEvaluator()
|
| 1827 |
+
ragas_result = {
|
| 1828 |
+
"answer_relevancy": evaluator.evaluate_answer_relevancy(query, response),
|
| 1829 |
+
"faithfulness": evaluator.evaluate_faithfulness(response, context),
|
| 1830 |
+
"answer_correctness": 0.8, # Placeholder since we don't have ground truth for general questions
|
| 1831 |
+
"context_precision": evaluator.evaluate_context_precision(query, context),
|
| 1832 |
+
"context_recall": 0.8 # Placeholder
|
| 1833 |
+
}
|
| 1834 |
+
|
| 1835 |
+
return response, relevant_chunks, ragas_result
|
| 1836 |
|
| 1837 |
async def generate_groq_response(prompt, model_name="mixtral-8x7b-32768", conversation_history=None):
|
| 1838 |
try:
|
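generate_rag_response_general in the hunk above, like the ablation variants earlier in this diff, relies on retrieve_relevant_documents, which is also defined outside this diff. The calls only reveal its signature (a query, a list of detected diseases, and top_k) and that each returned chunk carries "text" and "chunk_number" keys; the following self-contained sketch is a stand-in under those assumptions, with a toy keyword match in place of the real vector search.

# Hypothetical stand-in for the retriever the diff relies on; not the app's implementation.
from typing import Dict, List

_TOY_CORPUS = [
    "Apple scab is a fungal disease of apple caused by Venturia inaequalis.",
    "Late blight of tomato and potato is caused by Phytophthora infestans.",
    "Wheat rust produces orange to reddish-brown pustules on leaves and stems.",
]

def retrieve_relevant_documents(query: str, detected_diseases: List[str], top_k: int = 3) -> List[Dict]:
    # Rank toy documents by token overlap with the query plus any detected disease names.
    terms = set(query.lower().split()) | {d.lower() for d in detected_diseases}
    ranked = sorted(
        _TOY_CORPUS,
        key=lambda doc: len(terms & set(doc.lower().split())),
        reverse=True,
    )
    return [{"chunk_number": i + 1, "text": doc} for i, doc in enumerate(ranked[:top_k])]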
@@ -1431,13 +1945,27 @@ def inference(image):
| 1945 |
return infer, names_infer, classes, confidence_scores, bounding_boxes
|
| 1946 |
|
| 1947 |
# Streamlit application
|
| 1948 |
+
st.sidebar.markdown("---")
|
| 1949 |
+
st.sidebar.header("🔬 Research Tools")
|
| 1950 |
+
|
| 1951 |
+
# Page selection
|
| 1952 |
+
page_selection = st.sidebar.radio(
|
| 1953 |
+
"Navigate to:",
|
| 1954 |
+
["🏠 Main App", "🔬 Research Dashboard"],
|
| 1955 |
+
index=0
|
| 1956 |
+
)
|
| 1957 |
|
| 1958 |
+
if page_selection == "🔬 Research Dashboard":
|
| 1959 |
+
render_research_page()
|
| 1960 |
+
else:
|
| 1961 |
+
# Your existing main app code
|
| 1962 |
+
st.title("Interactive Crop Disease Detection and Analysis🌾🌿🥬☘️")
|
| 1963 |
+
st.write(f"Welcome, {st.session_state['username']}!😊")
|
| 1964 |
+
|
| 1965 |
+
# Logout button
|
| 1966 |
+
if st.button("Logout"):
|
| 1967 |
+
logout()
|
| 1968 |
+
st.rerun()
|
| 1969 |
|
| 1970 |
# Add sidebar for configuration
|
| 1971 |
with st.sidebar:
|
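The sidebar routing added above hands off to render_research_page, whose model pickers iterate over SUPPORTED_MODELS, another symbol defined outside this diff. A guess at its structure, consistent only with the lookups in this commit (list(SUPPORTED_MODELS.keys()), SUPPORTED_MODELS[...]["name"], and the 'llama-3.1-8b-instant' default key), is sketched below; the concrete entries are assumptions.

# Hypothetical shape of SUPPORTED_MODELS: model key -> config dict with the Groq model id under "name".
SUPPORTED_MODELS = {
    "llama-3.1-8b-instant": {"name": "llama-3.1-8b-instant"},
    "mixtral-8x7b-32768": {"name": "mixtral-8x7b-32768"},
}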
@@ -1496,165 +2024,344 @@ language = st.selectbox(
| 2024 |
help="Select your preferred language"
|
| 2025 |
)
|
| 2026 |
|
| 2027 |
+
tab1, tab2 = st.tabs(["🖼️ Image Analysis", "💬 General Consultation"])
|
| 2028 |
|
| 2029 |
+
with tab1:
|
| 2030 |
+
st.header("Image-Based Disease Detection")
|
| 2031 |
+
st.write("Upload images of your crops to detect diseases and get specific analysis.")
|
| 2032 |
+
|
| 2033 |
+
# Main content - Image upload and analysis
|
| 2034 |
+
uploaded_files = st.file_uploader("Upload images for disease detection", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
|
| 2035 |
+
|
| 2036 |
+
if uploaded_files:
|
| 2037 |
+
for uploaded_file in uploaded_files:
|
| 2038 |
+
file_id = uploaded_file.name
|
| 2039 |
+
|
| 2040 |
+
# Initialize conversation history for this image if it doesn't exist
|
| 2041 |
+
if file_id not in st.session_state.conversation_history:
|
| 2042 |
+
st.session_state.conversation_history[file_id] = []
|
| 2043 |
+
|
| 2044 |
+
st.subheader(f"Analysis for {file_id}")
|
| 2045 |
+
|
| 2046 |
+
# Create columns for side-by-side display
|
| 2047 |
+
col1, col2 = st.columns(2)
|
| 2048 |
+
|
| 2049 |
+
# Process image
|
| 2050 |
+
file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
|
| 2051 |
+
image = cv2.imdecode(file_bytes, 1)
|
| 2052 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 2053 |
+
|
| 2054 |
+
# Display original image
|
| 2055 |
+
with col1:
|
| 2056 |
+
st.subheader("Original Image")
|
| 2057 |
+
st.image(image)
|
| 2058 |
+
|
| 2059 |
+
# Process and display results
|
| 2060 |
+
with st.spinner("Processing image..."):
|
| 2061 |
+
infer_image, classes_in_image, classes_in_dataset, confidences, boxes = inference(image)
|
| 2062 |
+
|
| 2063 |
+
with col2:
|
| 2064 |
+
st.subheader("Detected Diseases")
|
| 2065 |
+
st.image(infer_image)
|
| 2066 |
+
|
| 2067 |
+
# Display detection details
|
| 2068 |
+
if show_confidence:
|
| 2069 |
+
st.subheader("Detection Details")
|
| 2070 |
+
for cls, conf in zip(classes_in_image, confidences):
|
| 2071 |
+
st.write(f"- {classes_in_dataset[cls]}: {conf:.2%} confidence")
|
| 2072 |
+
|
| 2073 |
+
# Display conversation history for this image
|
| 2074 |
+
if st.session_state.conversation_history[file_id]:
|
| 2075 |
+
st.subheader("Conversation History")
|
| 2076 |
+
for i, entry in enumerate(st.session_state.conversation_history[file_id]):
|
| 2077 |
+
question, response = entry[:2]
|
| 2078 |
|
| 2079 |
+
with st.expander(f"Q{i+1}: {question[:50]}...", expanded=False):
|
| 2080 |
+
st.write("**Question:**", question)
|
| 2081 |
+
st.write("**Response:**", response)
|
| 2082 |
+
|
| 2083 |
+
# Display feedback buttons and handle comment collection
|
| 2084 |
+
display_feedback_buttons(file_id, i, question, response)
|
| 2085 |
+
|
| 2086 |
+
# Audio playback option
|
| 2087 |
+
if st.button("🔊 Listen", key=f"listen_history_{file_id}_{i}"):
|
| 2088 |
+
with st.spinner("Generating audio..."):
|
| 2089 |
+
audio_bytes = text_to_speech(response, voice=selected_voice)
|
| 2090 |
+
if audio_bytes:
|
| 2091 |
+
st.audio(audio_bytes, format="audio/wav")
|
| 2092 |
+
|
| 2093 |
+
# User input for questions about the detected diseases
|
| 2094 |
+
st.subheader("Ask Questions About Detected Diseases")
|
| 2095 |
+
user_text = st.text_area(
|
| 2096 |
+
"Enter your question about the detected diseases:",
|
| 2097 |
+
placeholder="Example: What are the best treatment options for these diseases? What preventive measures should I take?",
|
| 2098 |
+
key=f"question_{file_id}"
|
| 2099 |
+
)
|
| 2100 |
+
|
| 2101 |
+
if st.button("Get Analysis", key=f"analyze_{file_id}"):
|
| 2102 |
+
with st.spinner(f"Generating analysis using {selected_model}..."):
|
| 2103 |
+
# Translate user input
|
| 2104 |
+
translated_input = asyncio.run(translator.translate(user_text, dest='en')).text
|
| 2105 |
+
st.write(f"Translated Input (to English): {translated_input}")
|
| 2106 |
+
|
| 2107 |
+
# Extract detected disease names
|
| 2108 |
+
detected_classes = [classes_in_dataset[cls] for cls in classes_in_image]
|
| 2109 |
|
| 2110 |
+
# Fetch reference answers for detected diseases
|
| 2111 |
+
reference_answers = []
|
| 2112 |
+
for disease_name in detected_classes:
|
| 2113 |
+
reference_answer = get_reference_answer(disease_name)
|
| 2114 |
+
if reference_answer:
|
| 2115 |
+
reference_answers.append(reference_answer)
|
| 2116 |
+
|
| 2117 |
+
# Combine reference answers into a single string
|
| 2118 |
+
reference_answer = "\n".join(reference_answers) if reference_answers else None
|
| 2119 |
+
|
| 2120 |
+
# Generate response with RAG
|
| 2121 |
+
response, relevant_chunks, ragas_result = asyncio.run(generate_rag_response(
|
| 2122 |
+
translated_input,
|
| 2123 |
+
st.session_state.conversation_history[file_id],
|
| 2124 |
+
reference_answer # Pass the reference answer for evaluation
|
| 2125 |
+
))
|
| 2126 |
+
print("Response:", response)
|
| 2127 |
+
|
| 2128 |
+
if response is None:
|
| 2129 |
+
st.error("Failed to generate a response. Please try again.")
|
| 2130 |
+
response = "No response generated."
|
| 2131 |
+
|
| 2132 |
+
# Move the translate function call here
|
| 2133 |
+
if response:
|
| 2134 |
+
try:
|
| 2135 |
+
translated_response = asyncio.run(translator.translate(response, dest=language)).text
|
| 2136 |
+
except Exception as e:
|
| 2137 |
+
st.error(f"Translation failed: {e}")
|
| 2138 |
+
translated_response = response # Fallback to the original response
|
| 2139 |
+
else:
|
| 2140 |
+
translated_response = response
|
| 2141 |
+
|
| 2142 |
+
st.session_state.conversation_history[file_id].append((user_text, translated_response, None))
|
| 2143 |
+
|
| 2144 |
+
# Display the response and evaluation metrics
|
| 2145 |
+
#st.markdown("### Relevant Information")
|
| 2146 |
+
#for chunk in relevant_chunks:
|
| 2147 |
+
# st.write(f"- **Chunk {chunk['chunk_number']}**: {chunk['text']}")
|
| 2148 |
+
|
| 2149 |
+
st.markdown(response)
|
| 2150 |
+
|
| 2151 |
+
# Add audio playback option for the latest response
|
| 2152 |
+
col1, col2 = st.columns([1, 4])
|
| 2153 |
+
with col1:
|
| 2154 |
+
if st.button("🔊 Listen", key=f"listen_latest_{file_id}"):
|
| 2155 |
+
with st.spinner("Generating audio..."):
|
| 2156 |
+
audio_bytes = text_to_speech(response, language)
|
| 2157 |
+
if audio_bytes:
|
| 2158 |
+
st.audio(audio_bytes, format='audio/mp3')
|
| 2159 |
+
|
| 2160 |
+
with tab2:
|
| 2161 |
+
st.header("General Disease Consultation")
|
| 2162 |
+
st.write("Ask questions about crop diseases without uploading images. Get expert advice on plant pathology topics.")
|
| 2163 |
+
|
| 2164 |
+
# Initialize general consultation history
|
| 2165 |
+
if 'general_consultation' not in st.session_state.conversation_history:
|
| 2166 |
+
st.session_state.conversation_history['general_consultation'] = []
|
| 2167 |
+
|
| 2168 |
+
# Disease selection helper
|
| 2169 |
+
st.subheader("🎯 Quick Disease Lookup")
|
| 2170 |
+
col1, col2 = st.columns([2, 1])
|
| 2171 |
+
|
| 2172 |
+
with col1:
|
| 2173 |
+
# Get list of diseases from database for quick selection
|
| 2174 |
+
try:
|
| 2175 |
+
conn = sqlite3.connect('./db/disease_knowledge_base.db')
|
| 2176 |
+
c = conn.cursor()
|
| 2177 |
+
c.execute("SELECT DISTINCT name FROM diseases ORDER BY name")
|
| 2178 |
+
available_diseases = [row[0] for row in c.fetchall()]
|
| 2179 |
+
conn.close()
|
| 2180 |
+
except:
|
| 2181 |
+
available_diseases = ["Corn Leaf Blight", "Apple Scab", "Tomato Late Blight", "Wheat Rust"]
|
| 2182 |
|
| 2183 |
+
selected_disease = st.selectbox(
|
| 2184 |
+
"Select a specific disease for quick information:",
|
| 2185 |
+
[""] + available_diseases,
|
| 2186 |
+
help="Choose a disease to get instant information about it"
|
| 2187 |
)
|
| 2188 |
+
|
| 2189 |
+
with col2:
|
| 2190 |
+
if selected_disease and st.button("Get Disease Info", key="quick_disease_info"):
|
| 2191 |
+
with st.spinner("Retrieving disease information..."):
|
| 2192 |
+
quick_query = f"Tell me about {selected_disease} - its causes, symptoms, and treatment options."
|
| 2193 |
+
|
| 2194 |
+
# Generate response using RAG
|
| 2195 |
+
response, relevant_chunks, ragas_result = asyncio.run(generate_rag_response_general(
|
| 2196 |
+
quick_query,
|
| 2197 |
+
st.session_state.conversation_history['general_consultation']
|
| 2198 |
+
))
|
| 2199 |
+
|
| 2200 |
+
# Translate if needed
|
| 2201 |
+
if language != 'en':
|
| 2202 |
+
try:
|
| 2203 |
+
translated_response = translator.translate(response, dest=language).text
|
| 2204 |
+
except:
|
| 2205 |
+
translated_response = response
|
| 2206 |
+
else:
|
| 2207 |
+
translated_response = response
|
| 2208 |
|
| 2209 |
+
# Add to conversation history
|
| 2210 |
+
st.session_state.conversation_history['general_consultation'].append((quick_query, translated_response))
|
| 2211 |
|
| 2212 |
+
st.markdown("### Disease Information")
|
| 2213 |
+
st.markdown(translated_response)
|
| 2214 |
+
|
| 2215 |
+
# Audio option
|
| 2216 |
+
if st.button("🔊 Listen to Response", key="listen_quick_disease"):
|
| 2217 |
+
with st.spinner("Generating audio..."):
|
| 2218 |
+
audio_bytes = text_to_speech(translated_response, voice=selected_voice)
|
| 2219 |
+
if audio_bytes:
|
| 2220 |
+
st.audio(audio_bytes, format="audio/wav")
|
| 2221 |
+
|
| 2222 |
+
# General question input
|
| 2223 |
+
st.subheader("💡 Ask Any Question About Crop Diseases")
|
| 2224 |
+
|
| 2225 |
+
# Provide example questions
|
| 2226 |
+
example_questions = [
|
| 2227 |
+
"What are the most common fungal diseases in tomatoes?",
|
| 2228 |
+
"How can I prevent wheat rust in my field?",
|
| 2229 |
+
"What's the difference between bacterial and viral plant diseases?",
|
| 2230 |
+
"Which organic treatments work best for aphid control?",
|
| 2231 |
+
"What are the early signs of nutrient deficiency in corn?",
|
| 2232 |
+
"How do weather conditions affect plant disease development?",
|
| 2233 |
+
]
|
| 2234 |
+
|
| 2235 |
+
with st.expander("💡 Example Questions", expanded=False):
|
| 2236 |
+
for i, example in enumerate(example_questions):
|
| 2237 |
+
if st.button(example, key=f"example_{i}"):
|
| 2238 |
+
st.session_state[f"general_question_input"] = example
|
| 2239 |
+
|
| 2240 |
+
general_question = st.text_area(
|
| 2241 |
+
"Enter your question about crop diseases, plant pathology, or agricultural practices:",
|
| 2242 |
+
placeholder="Example: What are the most effective organic methods to control powdery mildew in grapes?",
|
| 2243 |
+
key="general_question_input",
|
| 2244 |
+
height=100
|
| 2245 |
+
)
|
| 2246 |
+
|
| 2247 |
+
# Topic categories for better organization
|
| 2248 |
+
st.subheader("🏷️ Question Categories")
|
| 2249 |
+
col1, col2, col3 = st.columns(3)
|
| 2250 |
+
|
| 2251 |
+
with col1:
|
| 2252 |
+
if st.button("🦠 Disease Identification", key="cat_identification"):
|
| 2253 |
+
st.session_state["general_question_input"] = "How can I identify different types of plant diseases based on symptoms?"
|
| 2254 |
+
|
| 2255 |
+
with col2:
|
| 2256 |
+
if st.button("💊 Treatment Options", key="cat_treatment"):
|
| 2257 |
+
st.session_state["general_question_input"] = "What are the most effective treatment options for fungal plant diseases?"
|
| 2258 |
+
|
| 2259 |
+
with col3:
|
| 2260 |
+
if st.button("🛡️ Prevention Methods", key="cat_prevention"):
|
| 2261 |
+
st.session_state["general_question_input"] = "What preventive measures can I take to protect my crops from diseases?"
|
| 2262 |
+
|
| 2263 |
+
if st.button("Get Expert Answer", key="general_analyze", type="primary"):
|
| 2264 |
+
if general_question.strip():
|
| 2265 |
+
with st.spinner(f"Consulting plant pathology expert using {selected_model}..."):
|
| 2266 |
+
# Translate user input if needed
|
| 2267 |
+
if language != 'en':
|
| 2268 |
+
try:
|
| 2269 |
+
translated_input = translator.translate(general_question, dest='en').text
|
| 2270 |
+
st.info(f"Translated to English: {translated_input}")
|
| 2271 |
+
except:
|
| 2272 |
+
translated_input = general_question
|
| 2273 |
+
else:
|
| 2274 |
+
translated_input = general_question
|
| 2275 |
|
| 2276 |
+
# Generate response using RAG for general consultation
|
| 2277 |
+
response, relevant_chunks, ragas_result = asyncio.run(generate_rag_response_general(
|
| 2278 |
translated_input,
|
| 2279 |
+
st.session_state.conversation_history['general_consultation']
|
| 2280 |
))
|
| 2281 |
|
| 2282 |
if response:
|
| 2283 |
+
# Translate response back to user's language
|
| 2284 |
+
if language != 'en':
|
| 2285 |
+
try:
|
| 2286 |
+
translated_response = translator.translate(response, dest=language).text
|
| 2287 |
+
except Exception as e:
|
| 2288 |
+
st.error(f"Translation failed: {e}")
|
| 2289 |
+
translated_response = response
|
| 2290 |
+
else:
|
| 2291 |
+
translated_response = response
|
| 2292 |
|
| 2293 |
+
# Add to conversation history
|
| 2294 |
+
st.session_state.conversation_history['general_consultation'].append((general_question, translated_response))
|
| 2295 |
|
| 2296 |
+
# Display response
|
| 2297 |
+
st.markdown("### Expert Response")
|
| 2298 |
+
st.markdown(translated_response)
|
| 2299 |
+
|
| 2300 |
+
# Show relevant sources if available
|
| 2301 |
+
if relevant_chunks:
|
| 2302 |
+
with st.expander("📚 Information Sources", expanded=False):
|
| 2303 |
+
for i, chunk in enumerate(relevant_chunks[:3]): # Show top 3 sources
|
| 2304 |
+
st.write(f"**Source {i+1}:** {chunk['text'][:200]}...")
|
| 2305 |
|
| 2306 |
+
# Audio playback option
|
| 2307 |
+
col1, col2 = st.columns([1, 4])
|
| 2308 |
+
with col1:
|
| 2309 |
+
if st.button("🔊 Listen", key="listen_general_latest"):
|
| 2310 |
+
with st.spinner("Generating audio..."):
|
| 2311 |
+
audio_bytes = text_to_speech(translated_response, voice=selected_voice)
|
| 2312 |
+
if audio_bytes:
|
| 2313 |
+
st.audio(audio_bytes, format="audio/wav")
|
| 2314 |
else:
|
| 2315 |
+
st.error("Failed to generate a response. Please try again.")
|
| 2316 |
+
else:
|
| 2317 |
+
st.warning("Please enter a question before submitting.")
|
| 2318 |
+
|
| 2319 |
+
# Display general consultation history
|
| 2320 |
+
if st.session_state.conversation_history['general_consultation']:
|
| 2321 |
+
st.subheader("📝 Consultation History")
|
| 2322 |
+
for i, entry in enumerate(st.session_state.conversation_history['general_consultation']):
|
| 2323 |
+
question, response = entry[:2]
|
| 2324 |
+
|
| 2325 |
+
with st.expander(f"Q{i+1}: {question[:60]}...", expanded=False):
|
| 2326 |
+
st.write("**Question:**", question)
|
| 2327 |
+
st.write("**Response:**", response)
|
| 2328 |
|
| 2329 |
+
# Feedback buttons for general consultation
|
| 2330 |
+
display_feedback_buttons('general_consultation', i, question, response)
|
| 2331 |
+
|
| 2332 |
+
# Audio playback for history
|
| 2333 |
+
if st.button("🔊 Listen", key=f"listen_general_history_{i}"):
|
| 2334 |
+
with st.spinner("Generating audio..."):
|
| 2335 |
+
audio_bytes = text_to_speech(response, voice=selected_voice)
|
| 2336 |
+
if audio_bytes:
|
| 2337 |
+
st.audio(audio_bytes, format="audio/wav")
|
| 2338 |
+
|
| 2339 |
+
# Export general consultation
|
| 2340 |
+
if st.session_state.conversation_history['general_consultation']:
|
| 2341 |
+
if st.button("📄 Export Consultation", key="export_general"):
|
| 2342 |
+
consultation_text = f"""
|
| 2343 |
+
# General Crop Disease Consultation Report
|
| 2344 |
+
|
| 2345 |
+
## Consultation Information
|
| 2346 |
+
- Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
| 2347 |
+
- Language: {language}
|
| 2348 |
+
- Model Used: {selected_model}
|
| 2349 |
+
|
| 2350 |
+
## Consultation History
|
| 2351 |
+
"""
|
| 2352 |
+
|
| 2353 |
+
for i, entry in enumerate(st.session_state.conversation_history['general_consultation']):
|
| 2354 |
+
question, response = entry[:2]
|
| 2355 |
+
consultation_text += f"\n### Question {i+1}:\n{question}\n\n### Expert Response {i+1}:\n{response}\n\n---\n"
|
| 2356 |
|
| 2357 |
st.download_button(
|
| 2358 |
+
label="📥 Download Consultation Report",
|
| 2359 |
+
data=consultation_text,
|
| 2360 |
+
file_name=f"crop_disease_consultation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
|
| 2361 |
+
mime="text/markdown",
|
| 2362 |
+
key="download_general"
|
| 2363 |
)
|
| 2364 |
+
|
| 2365 |
# Add a footer with clear instructions
|
| 2366 |
st.markdown("""
|
| 2367 |
---
|