# app.py — Gradio RAG chatbot for invasive-plant identification
import os
import gradio as gr
import pandas as pd
import numpy as np
import faiss
import torch
import re
from openai import OpenAI
from transformers import AutoTokenizer, AutoModel
# Constants
# Available OpenAI models: gpt-3.5-turbo, gpt-4, gpt-4-turbo
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # sentence-embedding backbone
OPENAI_MODEL = "gpt-3.5-turbo"  # default
MAX_TOKENS = 300
TOP_K_RESULTS = 3  # documents retrieved from FAISS per query
DATA_DIR = "data"
# Use DATA_DIR instead of the hard-coded "data" literal so the directory is defined once.
LOG_FILE = os.path.join(DATA_DIR, "missing_info_log.txt")

# Environment check and OpenAI init — fail fast before loading heavy models.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
client = OpenAI(api_key=openai_api_key)

# Load embedding model (tokenizer + transformer used by embed_text below).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
transformer_model = AutoModel.from_pretrained(MODEL_NAME)

# File check: both the prebuilt FAISS index and its row->document map must exist.
if not all(os.path.exists(os.path.join(DATA_DIR, f"plant_{f}")) for f in ["index.faiss", "doc_map.csv"]):
    raise FileNotFoundError(f"Missing data files in {DATA_DIR} directory")

# Load FAISS index and doc map (paths built consistently with the check above).
index = faiss.read_index(os.path.join(DATA_DIR, "plant_index.faiss"))
doc_map = pd.read_csv(os.path.join(DATA_DIR, "plant_doc_map.csv"))
def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings into one vector per sequence, ignoring padding.

    model_output: HF model output exposing `.last_hidden_state`
        (presumably shaped (batch, seq, dim) — standard for AutoModel).
    attention_mask: (batch, seq) tensor with 1 for real tokens, 0 for padding.
    Returns a (batch, dim) tensor of masked mean embeddings.
    """
    hidden = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = torch.sum(hidden * mask, dim=1)
    # Clamp guards against division by zero for an all-padding row.
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts
def embed_text(texts):
    """Encode a list of strings into mean-pooled embeddings.

    Uses the module-level `tokenizer` / `transformer_model` and returns a
    numpy array of shape (len(texts), hidden_dim).
    """
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Inference only: no gradients needed.
    with torch.no_grad():
        output = transformer_model(**encoded)
    pooled = mean_pooling(output, encoded["attention_mask"])
    return pooled.cpu().numpy()
def extract_images_from_text(text):
    """Extract image or gallery URLs from a retrieved document.

    Prefers direct image links (.jpg/.jpeg/.png/.gif). When none are found,
    falls back to known gallery/profile hosts (invasive.org, plants.usda.gov).

    Fix: the original de-duplicated with list(set(...)), which produced a
    nondeterministic ordering that leaked into the rendered markdown; here
    dict.fromkeys de-duplicates while preserving first-seen order.
    """
    urls = re.findall(r'https?://[^\s)]+', text)
    direct_images = [u for u in urls if re.search(r'\.(jpg|jpeg|png|gif)$', u, re.IGNORECASE)]
    if direct_images:
        return list(dict.fromkeys(direct_images))
    gallery_links = [u for u in urls if "invasive.org" in u or "plants.usda.gov" in u]
    return list(dict.fromkeys(gallery_links))
def format_image_response(image_links):
    """Render extracted links as markdown lines, distinguishing galleries from images.

    subthumb.cfm links are treated as gallery pages and appended last (only the
    last one seen is kept); detail.cfm links and direct image files become
    image links; anything else becomes a generic resource link. Returns a
    fallback message when the list is empty.
    """
    if not image_links:
        return "I couldn't find specific images for this plant. But we are working on it! In the meantime, feel free to use the `Invasive Plant Image Search` bar under Tools."
    image_pattern = re.compile(r'\.(jpg|jpeg|png|gif)$', re.IGNORECASE)
    rendered = []
    gallery_url = None
    for url in image_links:
        if "subthumb.cfm" in url:
            gallery_url = url
        elif "detail.cfm" in url or image_pattern.search(url):
            rendered.append(f"[📷 View Image]({url})")
        else:
            rendered.append(f"[🔗 View Resource]({url})")
    if gallery_url:
        rendered.append(f"[🖼️ View Gallery]({gallery_url})")
        rendered.append("You can also check the full gallery above for more pictures.")
    return "\n".join(rendered)  # fixed newline join
def filter_text_response(text):
    """Keep only the first paragraph of a retrieved document.

    Empty or falsy input yields a generic "no information" message; otherwise
    the text is split on blank lines and the first paragraph is returned,
    stripped of surrounding whitespace.
    """
    if not text:
        return "I don't have information about that specific query."
    first_paragraph, *_ = re.split(r'\n\n+', text)
    return first_paragraph.strip()
def log_missing_info(message, response, log_path=None):
    """Log queries the model could not answer so the data can be added later.

    Bug fix: the system prompt tells the model to reply with
    "I don’t have that information." using a curly apostrophe (U+2019), but
    the original check only matched the ASCII apostrophe, so prompt-compliant
    replies were never logged. Apostrophes are normalized before matching.

    Returns the canned "logged" message when the marker is found, otherwise
    the response unchanged. `log_path` defaults to the module-level LOG_FILE
    (resolved late so the default tracks the constant).
    """
    if log_path is None:
        log_path = LOG_FILE
    normalized = response.replace("\u2019", "'")  # curly -> ASCII apostrophe
    if "I don't have that information" in normalized:
        # Explicit encoding: queries may contain non-ASCII plant names.
        with open(log_path, "a", encoding="utf-8") as log_file:
            log_file.write(f"USER QUERY: {message.strip()}\n\n")
        return "I don't have that information. But your request has been logged and will be provided in a future update."
    return response
# RAG chatbot logic
# Bug fix: cache is now keyed on (message, model_name). The original keyed on
# message alone, so switching the model dropdown returned a stale answer that
# was generated by a different model.
cache = {}

def respond_with_model(message, chat_history, model_name):
    """Answer a plant question via RAG: retrieve from FAISS, then ask OpenAI.

    message: the user's question.
    chat_history: list of (user, bot) tuples (the gr.Chatbot state).
    model_name: OpenAI model id selected in the UI dropdown.
    Returns ("", updated chat_history) to clear the textbox and refresh the chat.
    """
    cache_key = (message, model_name)
    if cache_key in cache:
        chat_history.append((message, cache[cache_key]))
        return "", chat_history

    # Retrieve the TOP_K_RESULTS most similar documents from the index.
    query_embedding = embed_text([message])[0]
    distances, indices = index.search(np.array([query_embedding]).astype(np.float32), k=TOP_K_RESULTS)
    retrieved_docs = [doc_map.iloc[i]["rag_document"] for i in indices[0]]

    # Crude intent detection: image requests get link context instead of text.
    is_image_request = any(word in message.lower() for word in ["image", "picture", "photo", "show me"])
    if is_image_request:
        image_links = []
        for doc in retrieved_docs:
            image_links.extend(extract_images_from_text(doc))
        context = format_image_response(image_links)
    else:
        # Only the top 2 docs, first paragraph each, to keep the prompt short.
        context = "\n\n".join(filter_text_response(doc) for doc in retrieved_docs[:2])

    messages = [
        {"role": "system", "content": """You are a highly precise botany expert answering questions about invasive plants. Adhere strictly to these guidelines:
1. Answer exactly what is asked—no extra details, explanations, or tangents.
2. For image requests, provide ONLY markdown image links.
3. Be factual and concise: Use minimal words without sacrificing accuracy. Cite sources if available.
4. Uncertainty response: If you lack verified information, say, "I don’t have that information." No speculation.
5. Never invent facts: Prioritize accuracy over completeness."""}
    ]
    # Include the last 3 exchanges for conversational continuity.
    for user_msg, bot_msg in chat_history[-3:]:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": f"{message}\n\nRelevant context:\n{context}"})

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=512,         # Caps response length
        temperature=0.3,        # Controls randomness. Lower values make outputs more deterministic
        top_p=0.9,              # Nucleus sampling: Limits token selection to a probability mass (e.g., top 90% likely tokens)
        frequency_penalty=0.2,  # Slightly increase (e.g., 0.1–0.5) to discourage word repetition
        presence_penalty=0.1    # Increase (e.g., 0.2) to discourage revisiting topics in long responses
    )
    bot_message = response.choices[0].message.content
    if is_image_request:
        # Re-extract links from the model's reply so only valid markdown links remain.
        bot_message = format_image_response(extract_images_from_text(bot_message))
    else:
        # Strip any markdown images the model slipped into a text answer.
        bot_message = re.sub(r'!\[.*?\]\(.*?\)', '', bot_message).strip()
    bot_message = log_missing_info(message, bot_message)
    cache[cache_key] = bot_message
    chat_history.append((message, bot_message))
    return "", chat_history
# Gradio UI
def create_interface():
    """Build the Gradio Blocks app: chat panel, model selector, example tabs, tools."""
    with gr.Blocks(
        theme=gr.themes.Default(
            primary_hue="emerald",
            font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"]
        ),
        title="\U0001F33F Plant Identification & Invasiveness Checker"
    ) as demo:
        with gr.Row():
            gr.Markdown("""
<div style='text-align: center; margin-bottom: 10px;'>
<h1 style='margin-bottom: 5px;'>\U0001F33F Conservation Made Easy: Identify Invasive Plants</h1>
<p>Provide either the common name or scientific name of a plant to receive precise details about its ecological impact,
growth characteristics, and management recommendations.
</p>
</div>
""")
        with gr.Row():
            # Left column: chat surface and controls.
            with gr.Column(scale=3):
                model_choice = gr.Dropdown(
                    choices=["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo"],
                    value="gpt-3.5-turbo",
                    label="Choose OpenAI Model"
                )
                chat_window = gr.Chatbot(height=400, bubble_full_width=False, show_label=False)
                question_box = gr.Textbox(placeholder="Ask questions about any plant (e.g. 'Show me pictures of wine grape')", label="Plant Question", container=False)
                with gr.Row():
                    send_button = gr.Button("Submit", variant="primary")
                    reset_button = gr.Button("Clear")
            # Right column: question templates and external tools.
            with gr.Column(scale=1):
                gr.Markdown("### \U0001F4CB Question Templates")
                with gr.Tab("Identification"):
                    gr.Examples([
                        ["What is the common name of Hypericum calycinum?"],
                        ["What's the scientific name for japanese honeysuckle?"]
                    ], inputs=question_box, label="Name/Symbol Questions")
                with gr.Tab("Appearance"):
                    gr.Examples([
                        ["Describe the appearance of african mustard."],
                        ["Describe yellowtuft's foliage."],
                        ["What color are korean lespedeza flowers?"],
                        ["Does carrotwood produce fruit?"],
                        ["What is the growth habit of vinegartree?"],
                        ["Can I see images of scots pine?"]
                    ], inputs=question_box, label="Physical Characteristics")
                with gr.Tab("Status"):
                    gr.Examples([
                        ["Is tawny daylily invasive in Michigan?"],
                        ["Noxious weed status for giant hogweed"],
                        ["Ecological threat of african asparagus fern"]
                    ], inputs=question_box, label="Invasiveness Questions")
                gr.Markdown("---")
                with gr.Accordion("\U0001F50D Search Tools", open=False):
                    gr.Button("\U0001F33F USDA Plants Database", link="https://plants.usda.gov/home")
                    gr.Button("\U0001F4F8 Invasive Plant Image Search", link="https://www.invasive.org/images.cfm")

        # respond_with_model already takes (message, history, model) and returns
        # (textbox, history), so no wrapping lambda is needed.
        question_box.submit(respond_with_model, [question_box, chat_window, model_choice], [question_box, chat_window])
        send_button.click(respond_with_model, [question_box, chat_window, model_choice], [question_box, chat_window])
        reset_button.click(lambda: None, None, chat_window, queue=False)
    return demo
if __name__ == "__main__":
    demo_app = create_interface()
    # Bind to all interfaces on port 7860 (standard for containerized Spaces).
    demo_app.launch(server_name="0.0.0.0", server_port=7860)  # Enabled public sharing, share = True