Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import faiss | |
| import torch | |
| import re | |
| from openai import OpenAI | |
| from transformers import AutoTokenizer, AutoModel | |
# Constants
# Available OpenAI models: gpt-3.5-turbo, gpt-4, gpt-4-turbo
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
OPENAI_MODEL = "gpt-3.5-turbo"  # default
MAX_TOKENS = 300
TOP_K_RESULTS = 3
DATA_DIR = "data"
# Reuse DATA_DIR so there is a single configuration point for the data folder
# (previously the directory name was duplicated as a literal here).
LOG_FILE = os.path.join(DATA_DIR, "missing_info_log.txt")

# Environment check and OpenAI init — fail fast before loading heavy models.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
client = OpenAI(api_key=openai_api_key)

# Load embedding model (tokenizer + transformer used by embed_text below).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
transformer_model = AutoModel.from_pretrained(MODEL_NAME)

# File check: both the FAISS index and the document map must be present.
_required_files = ["plant_index.faiss", "plant_doc_map.csv"]
if not all(os.path.exists(os.path.join(DATA_DIR, f)) for f in _required_files):
    raise FileNotFoundError(f"Missing data files in {DATA_DIR} directory")

# Load FAISS index and doc map (os.path.join for portability, matching the
# existence check above instead of hand-built f-string paths).
index = faiss.read_index(os.path.join(DATA_DIR, "plant_index.faiss"))
doc_map = pd.read_csv(os.path.join(DATA_DIR, "plant_doc_map.csv"))
# Mean pooling to ignore padding tokens
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over real tokens only.

    The attention mask (1 = real token, 0 = padding) is broadcast across
    the hidden dimension so padded positions contribute nothing to the sum,
    and the divisor counts only the unmasked tokens.
    """
    hidden = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against divide-by-zero
    return summed / counts
# Embed text using mean pooling
def embed_text(texts):
    """Encode a list of strings into mean-pooled sentence embeddings.

    Returns a numpy array (one row per input string) built from the
    module-level tokenizer/transformer pair, with gradients disabled.
    """
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = transformer_model(**encoded)
    pooled = mean_pooling(outputs, encoded["attention_mask"])
    return pooled.cpu().numpy()
# Extract image or gallery URLs
def extract_images_from_text(text):
    """Pull image-related URLs out of free text.

    Direct image files (.jpg/.jpeg/.png/.gif) are preferred; when none are
    present, known gallery/database hosts are returned as a fallback.
    Duplicates are removed (order is unspecified).
    """
    urls = re.findall(r'https?://[^\s)]+', text)
    image_pattern = re.compile(r'\.(jpg|jpeg|png|gif)$', re.IGNORECASE)
    direct = [u for u in urls if image_pattern.search(u)]
    if direct:
        return list(set(direct))
    galleries = [u for u in urls if "invasive.org" in u or "plants.usda.gov" in u]
    return list(set(galleries))
# Format for display with distinction between galleries and images
def format_image_response(image_links):
    """Render extracted links as markdown lines.

    Direct images and detail pages get a camera link, other URLs a generic
    resource link, and a gallery (subthumb.cfm) link — if any — is appended
    last together with a usage hint.
    """
    if not image_links:
        return "I couldn't find specific images for this plant. But we are working on it! In the meantime, feel free to use the `Invasive Plant Image Search` bar under Tools."
    image_pattern = re.compile(r'\.(jpg|jpeg|png|gif)$', re.IGNORECASE)
    lines = []
    gallery = None
    for url in image_links:
        if "subthumb.cfm" in url:
            gallery = url  # defer: galleries are listed after all images
        elif "detail.cfm" in url or image_pattern.search(url):
            lines.append(f"[📷 View Image]({url})")
        else:
            lines.append(f"[🔗 View Resource]({url})")
    if gallery is not None:
        lines.append(f"[🖼️ View Gallery]({gallery})")
        lines.append("You can also check the full gallery above for more pictures.")
    return "\n".join(lines)
# Improved text filtering
def filter_text_response(text):
    """Return only the first paragraph of *text*, stripped of whitespace.

    Falsy input (empty string / None) yields a fixed fallback message.
    """
    if not text:
        return "I don't have information about that specific query."
    first_paragraph, *_ = re.split(r'\n\n+', text)
    return first_paragraph.strip()
# Log missing information to a file
def log_missing_info(message, response, log_path=LOG_FILE):
    """Log queries the model could not answer so content gaps can be filled.

    The system prompt instructs the model to reply "I don't have that
    information" when unsure — but that prompt spells the phrase with a
    typographic apostrophe (’), which the previous straight-apostrophe
    substring check never matched, so logging silently failed to trigger.
    Apostrophes are normalized before matching to catch both variants.

    Returns the response to show the user: the original response, or a
    "your request has been logged" message when the marker phrase was found.
    """
    # Normalize U+2019 (right single quotation mark) to a straight apostrophe.
    normalized = response.replace("\u2019", "'")
    if "I don't have that information" in normalized:
        # Explicit UTF-8 so user queries with non-ASCII plant names log safely.
        with open(log_path, "a", encoding="utf-8") as log_file:
            log_file.write(f"USER QUERY: {message.strip()}\n\n")
        return "I don't have that information. But your request has been logged and will be provided in a future update."
    return response
# RAG chatbot logic
# Cache keyed by (message, model) — keying on the message alone returned a
# response generated by one model even after the user switched to another.
cache = {}

def respond_with_model(message, chat_history, model_name):
    """Answer *message* via retrieval-augmented generation.

    Embeds the query, retrieves TOP_K_RESULTS documents from the FAISS
    index, builds a prompt including the last three chat exchanges, and
    calls the selected OpenAI chat model. Image requests are answered with
    formatted markdown links instead of prose.

    Returns ("", updated_chat_history) matching Gradio's (textbox, chatbot)
    output pair; the empty string clears the input box.
    """
    cache_key = (message, model_name)
    if cache_key in cache:
        chat_history.append((message, cache[cache_key]))
        return "", chat_history

    # Retrieve the nearest documents for the query.
    query_embedding = embed_text([message])[0]
    distances, indices = index.search(
        np.array([query_embedding]).astype(np.float32), k=TOP_K_RESULTS
    )
    retrieved_docs = [doc_map.iloc[i]["rag_document"] for i in indices[0]]

    # Image requests get link context; text requests get trimmed paragraphs
    # from the top two documents.
    is_image_request = any(
        word in message.lower() for word in ["image", "picture", "photo", "show me"]
    )
    if is_image_request:
        image_links = []
        for doc in retrieved_docs:
            image_links.extend(extract_images_from_text(doc))
        context = format_image_response(image_links)
    else:
        context = "\n\n".join(filter_text_response(doc) for doc in retrieved_docs[:2])

    messages = [
        {"role": "system", "content": """You are a highly precise botany expert answering questions about invasive plants. Adhere strictly to these guidelines:
1. Answer exactly what is asked—no extra details, explanations, or tangents.
2. For image requests, provide ONLY markdown image links.
3. Be factual and concise: Use minimal words without sacrificing accuracy. Cite sources if available.
4. Uncertainty response: If you lack verified information, say, "I don’t have that information." No speculation.
5. Never invent facts: Prioritize accuracy over completeness."""}
    ]
    # Include the last three exchanges for conversational continuity.
    for user_msg, bot_msg in chat_history[-3:]:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": f"{message}\n\nRelevant context:\n{context}"})

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=MAX_TOKENS,   # use the module constant (was a drifting 512 literal)
        temperature=0.3,         # lower values make outputs more deterministic
        top_p=0.9,               # nucleus sampling: top 90% probability mass
        frequency_penalty=0.2,   # discourage word repetition
        presence_penalty=0.1,    # discourage revisiting topics in long responses
    )
    bot_message = response.choices[0].message.content

    if is_image_request:
        bot_message = format_image_response(extract_images_from_text(bot_message))
    else:
        # Strip any inline markdown images the model slipped into a text answer.
        bot_message = re.sub(r'!\[.*?\]\(.*?\)', '', bot_message).strip()

    bot_message = log_missing_info(message, bot_message)
    cache[cache_key] = bot_message
    chat_history.append((message, bot_message))
    return "", chat_history
# Gradio UI
def create_interface():
    """Build the Gradio Blocks app: model selector, chatbot, example
    question tabs, and external search-tool links.

    Returns the (unlaunched) gr.Blocks demo object.
    """
    with gr.Blocks(
        theme=gr.themes.Default(
            primary_hue="emerald",
            font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"]
        ),
        title="\U0001F33F Plant Identification & Invasiveness Checker"
    ) as demo:
        with gr.Row():
            gr.Markdown("""
            <div style='text-align: center; margin-bottom: 10px;'>
            <h1 style='margin-bottom: 5px;'>\U0001F33F Conservation Made Easy: Identify Invasive Plants</h1>
            <p>Provide either the common name or scientific name of a plant to receive precise details about its ecological impact,
            growth characteristics, and management recommendations.
            </p>
            </div>
            """)
        with gr.Row():
            with gr.Column(scale=3):
                model_selector = gr.Dropdown(
                    choices=["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo"],
                    value="gpt-3.5-turbo",
                    label="Choose OpenAI Model"
                )
                chatbot = gr.Chatbot(height=400, bubble_full_width=False, show_label=False)
                msg = gr.Textbox(placeholder="Ask questions about any plant (e.g. 'Show me pictures of wine grape')", label="Plant Question", container=False)
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    clear_btn = gr.Button("Clear")
            with gr.Column(scale=1):
                gr.Markdown("### \U0001F4CB Question Templates")
                with gr.Tab("Identification"):
                    gr.Examples([
                        ["What is the common name of Hypericum calycinum?"],
                        ["What's the scientific name for japanese honeysuckle?"]
                    ], inputs=msg, label="Name/Symbol Questions")
                with gr.Tab("Appearance"):
                    gr.Examples([
                        ["Describe the appearance of african mustard."],
                        ["Describe yellowtuft's foliage."],
                        ["What color are korean lespedeza flowers?"],
                        ["Does carrotwood produce fruit?"],
                        ["What is the growth habit of vinegartree?"],
                        ["Can I see images of scots pine?"]
                    ], inputs=msg, label="Physical Characteristics")
                with gr.Tab("Status"):
                    gr.Examples([
                        ["Is tawny daylily invasive in Michigan?"],
                        ["Noxious weed status for giant hogweed"],
                        ["Ecological threat of african asparagus fern"]
                    ], inputs=msg, label="Invasiveness Questions")
                gr.Markdown("---")
                with gr.Accordion("\U0001F50D Search Tools", open=False):
                    gr.Button("\U0001F33F USDA Plants Database", link="https://plants.usda.gov/home")
                    gr.Button("\U0001F4F8 Invasive Plant Image Search", link="https://www.invasive.org/images.cfm")
        # Pass the handler directly — the previous lambdas merely forwarded
        # the same three arguments and added nothing.
        msg.submit(respond_with_model, [msg, chatbot, model_selector], [msg, chatbot])
        submit_btn.click(respond_with_model, [msg, chatbot, model_selector], [msg, chatbot])
        # Clearing the chatbot just resets its value to None.
        clear_btn.click(lambda: None, None, chatbot, queue=False)
    return demo
if __name__ == "__main__":
    app = create_interface()
    # Listen on all interfaces at port 7860 (the standard Hugging Face
    # Spaces port). NOTE(review): despite the old comment claiming public
    # sharing was enabled, no share=True is passed — add it to launch()
    # if a public gradio.live link is actually wanted.
    app.launch(server_name="0.0.0.0", server_port=7860)