Spaces:
Sleeping
Sleeping
| import os | |
| import faiss | |
| import numpy as np | |
| from rank_bm25 import BM25Okapi | |
| import torch | |
| import pandas as pd | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
| import google.generativeai as genai | |
| import logging | |
# Logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# Cache directory for Hugging Face models (SciBERT only).
# NOTE(review): this assignment runs after `transformers` was imported at the
# top of the file; libraries that read HF_HOME at import time may not see it —
# confirm. The model loaders below pass `cache_dir` explicitly as well.
os.environ["HF_HOME"] = "/tmp/huggingface"

# The Gemini API key is read from the environment (a Spaces secret).
# Start-up is aborted when it is missing or empty.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if GEMINI_API_KEY:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        logger.info("Gemini API configured")
    except Exception as exc:
        logger.error(f"Failed to configure Gemini API: {exc}")
        raise
else:
    logger.error("GEMINI_API_KEY not set. Please set it in Hugging Face Spaces secrets.")
    raise ValueError("GEMINI_API_KEY is required for Gemini API access.")

# Load the paper dataset from the current working directory; any failure is
# logged and re-raised so the app does not start without data.
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
try:
    dataset_present = os.path.exists(DATASET_PATH)
    if not dataset_present:
        raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
    df = pd.read_json(DATASET_PATH)
    logger.info("Dataset loaded successfully")
except Exception as exc:
    logger.error(f"Failed to load dataset: {exc}")
    raise
# Clean text
def clean_text(text):
    """Return *text* stripped of surrounding whitespace and lower-cased.

    Anything that is not a ``str`` (e.g. ``None`` or a float) yields ``""``.
    """
    if not isinstance(text, str):
        return ""
    trimmed = text.strip()
    return trimmed.lower()
# Normalise every abstract into the "cleaned_abstract" column.
try:
    df["cleaned_abstract"] = df["abstract"].map(clean_text)
    logger.info("Text cleaning completed")
except Exception as exc:
    logger.error(f"Error during cleaning abstracts: {exc}")
    raise

# Precompute the BM25 index over whitespace-tokenised cleaned abstracts.
try:
    tokenized_corpus = list(map(str.split, df["cleaned_abstract"]))
    bm25 = BM25Okapi(tokenized_corpus)
    logger.info("BM25 index created")
except Exception as exc:
    logger.error(f"BM25 index creation failed: {exc}")
    raise

# Load SciBERT (tokenizer + encoder) for embeddings, on GPU when available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
try:
    sci_bert_tokenizer = AutoTokenizer.from_pretrained(
        "allenai/scibert_scivocab_uncased",
        cache_dir="/tmp/huggingface",
    )
    sci_bert_model = AutoModel.from_pretrained(
        "allenai/scibert_scivocab_uncased",
        cache_dir="/tmp/huggingface",
    )
    # Move to the chosen device and switch to inference mode.
    sci_bert_model.to(device)
    sci_bert_model.eval()
    logger.info("SciBERT loaded successfully")
except Exception as exc:
    logger.error(f"Model loading failed: {exc}")
    raise
# Generate SciBERT embeddings
def generate_embeddings_sci_bert(texts, batch_size=32):
    """Embed *texts* with SciBERT using attention-masked mean pooling.

    Each text is tokenised (truncated to 512 tokens), run through the
    module-level ``sci_bert_model`` and reduced to one vector by averaging
    the last hidden states of its real tokens only. Padding positions are
    excluded, so a text's embedding does not depend on which other texts
    share its batch.

    Args:
        texts: Sequence of strings to embed.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        A 2-D ``numpy`` array with one row per text. For empty input an
        array of shape ``(0, 768)`` is returned. If anything fails, the
        error is logged and an all-zero array of shape ``(len(texts), 768)``
        is returned instead of raising (best-effort contract kept from the
        original implementation).
    """
    hidden_size = 768  # width used for the empty / failure fallbacks
    all_embeddings = []
    try:
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            inputs = sci_bert_tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            inputs = {key: val.to(device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = sci_bert_model(**inputs)
            hidden = outputs.last_hidden_state
            # Weight each token by its attention mask so padded positions
            # contribute nothing to the average.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1.0)
            embeddings = (hidden * mask).sum(dim=1) / token_counts
            all_embeddings.append(embeddings.cpu().numpy())
            # Releasing cached blocks is only meaningful on a CUDA device.
            if device.type == "cuda":
                torch.cuda.empty_cache()
        if not all_embeddings:
            return np.zeros((0, hidden_size))
        return np.concatenate(all_embeddings, axis=0)
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception(f"Embedding generation failed: {e}")
        return np.zeros((len(texts), hidden_size))
# Build the FAISS L2 index over the SciBERT embeddings of all cleaned abstracts.
try:
    abstracts = df["cleaned_abstract"].tolist()
    embeddings = generate_embeddings_sci_bert(abstracts)
    if embeddings.shape[0] != len(abstracts):
        logger.warning("Embeddings count does not match abstracts count")
    # generate_embeddings_sci_bert returns an all-zero matrix when it fails;
    # make that visible instead of silently indexing meaningless vectors.
    if embeddings.size and not np.any(embeddings):
        logger.error(
            "All document embeddings are zero; embedding generation likely "
            "failed and semantic search results will be unreliable"
        )
    dimension = embeddings.shape[1] if embeddings.size else 768
    faiss_index = faiss.IndexFlatL2(dimension)
    if embeddings.size:
        # FAISS expects a C-contiguous float32 matrix.
        faiss_index.add(np.ascontiguousarray(embeddings, dtype=np.float32))
        logger.info("FAISS index created")
    else:
        logger.warning("No embeddings to index")
except Exception as e:
    logger.error(f"FAISS index creation failed: {e}")
    raise
def get_relevant_papers(query, top_k=5):
    """Search the corpus with FAISS (semantic) and BM25 (lexical) retrieval.

    The top ``top_k`` hits of each retriever are merged, de-duplicated and
    ordered by descending BM25 score (ties broken by ascending row position).

    Args:
        query: Free-text search query.
        top_k: Number of hits taken from each retriever and the maximum
            number of results returned. Defaults to 5.

    Returns:
        A 3-tuple ``(papers, indices, status)``:
        * ``papers``: display labels ``"<rank>. <title> - Abstract: <snippet>"``
          where ``<rank>`` is the 1-based position in ``indices``;
        * ``indices``: row positions in ``df`` as plain ``int`` values;
        * ``status``: human-readable status message.
        An invalid query or any failure yields ``([], [], <message>)``.
    """
    if not isinstance(query, str) or not query.strip():
        return [], [], "Please enter a valid search query."
    try:
        query_embedding = generate_embeddings_sci_bert([query])
        _, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
        # FAISS pads the result with -1 when fewer than k vectors are
        # available; without this filter df.iloc[-1] would silently return
        # the last paper.
        semantic_hits = {int(idx) for idx in indices[0] if 0 <= idx < len(df)}

        bm25_scores = bm25.get_scores(query.lower().split())
        lexical_hits = {int(idx) for idx in np.argsort(bm25_scores)[::-1][:top_k]}

        # Sort by position first so the stable score sort is deterministic.
        candidates = sorted(semantic_hits | lexical_hits)
        ranked_results = sorted(candidates, key=lambda idx: -bm25_scores[idx])[:top_k]

        papers = []
        for i, idx in enumerate(ranked_results):
            try:
                row = df.iloc[idx]
                title = row['title']
                abstract = row['abstract']
                # Missing abstracts (None/NaN) are shown as empty text
                # instead of dropping the paper.
                if not isinstance(abstract, str):
                    abstract = ""
                abstract_snip = abstract[:200] + "..." if len(abstract) > 200 else abstract
                papers.append(f"{i+1}. {title} - Abstract: {abstract_snip}")
            except Exception as e:
                logger.error(f"Error accessing paper at index {idx}: {e}")
        return papers, ranked_results, "Search completed."
    except Exception as e:
        logger.error(f"Search failed: {e}")
        return [], [], "Search failed. Please try again."
def answer_question(selected_index, question, history):
    """Answer *question* about the selected paper using the Gemini API.

    Args:
        selected_index: Row position of the chosen paper in ``df``, or
            ``None`` when no paper is selected.
        question: The user's question. ``"exit"`` or ``"done"`` (any case,
            surrounding whitespace ignored) ends the conversation.
        history: List of ``(question, answer)`` pairs from earlier turns.
            The passed-in list is never mutated.

    Returns:
        A 2-tuple ``(chat_display, new_history)``. On validation problems
        ``chat_display`` is the existing history plus a hint and
        ``new_history`` is the unchanged history. On "exit"/"done" the
        history is reset to ``[]``. Otherwise both elements are the history
        extended with the new turn. Errors while answering are logged and
        reported to the user as an apology turn instead of being raised.
    """
    # Work on a copy so the caller's list (session state) is not modified.
    history = list(history) if history else []

    # Validation hints are appended to the visible conversation rather than
    # replacing it; they are not stored in the history.
    if selected_index is None:
        return history + [(question, "Please select a paper first!")], history
    if not isinstance(question, str) or not question.strip():
        return history + [(question, "Please ask a question!")], history
    if question.strip().lower() in ("exit", "done"):
        return [("Conversation ended.", "Select a new paper or search again!")], []

    try:
        paper_data = df.iloc[selected_index]
        title = paper_data.get("title", "Unknown Title")
        abstract = paper_data.get("abstract", "Abstract not available.")
        authors_list = paper_data.get("authors", [])
        if isinstance(authors_list, (list, tuple)):
            # Entries are stringified so non-string items cannot break join.
            authors = ", ".join(str(author) for author in authors_list)
        else:
            authors = str(authors_list)
        doi = paper_data.get("doi", "No DOI")

        prompt_parts = [
            "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
            "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's details. "
            "When asked about tech stacks or methods, follow these guidelines:\n"
            "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
            "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
            "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
            "4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
            "5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
            "6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
            f"Here’s the paper:\n"
            f"Title: {title}\n"
            f"Authors: {authors}\n"
            f"Abstract: {abstract}\n"
            f"DOI: {doi}\n\n"
        ]
        if history:
            prompt_parts.append("Previous conversation (use for context):\n")
            # Only the two most recent turns are replayed to the model.
            for user_q, bot_a in history[-2:]:
                prompt_parts.append(f"User: {user_q}\nAssistant: {bot_a}\n")
        prompt_parts.append(f"Now, answer this question: {question}")
        prompt = "".join(prompt_parts)
        logger.info(f"Prompt sent to Gemini API (truncated): {prompt[:500]}...")

        # Updated to use valid model name
        model = genai.GenerativeModel("models/gemini-2.5-flash")
        response = model.generate_content(prompt)
        answer = (getattr(response, 'text', '') or '').strip() if response else ""
        if not answer:
            # Only a truly empty reply triggers the fallback; short but valid
            # answers (e.g. "Yes.") are kept. The fallback no longer invents
            # a tech-stack answer unrelated to the question asked.
            logger.warning("Received empty answer from Gemini API, applying fallback.")
            answer = (
                "Sorry, the model returned an empty answer. "
                "Please rephrase your question and try again."
            )
        history.append((question, answer))
        return history, history
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception(f"QA failed: {e}")
        history.append((question, "Sorry, I couldn’t process that. Try again!"))
        return history, history
# Gradio UI: a search sidebar on the left and a per-paper chat on the right.
with gr.Blocks(
    css="""
    .chatbot {height: 600px; overflow-y: auto;}
    .sidebar {width: 300px;}
    #main {display: flex; flex-direction: row;}
    """,
    theme=gr.themes.Default(primary_hue="blue")
) as demo:
    gr.Markdown("# ResearchGPT - Paper Search & Chat")
    with gr.Row(elem_id="main"):
        # Sidebar for search
        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
            gr.Markdown("### Search Papers")
            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
            search_btn = gr.Button("Search")
            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
            search_status = gr.Textbox(label="Search Status", interactive=False)
            # States to store paper choices and indices
            paper_choices_state = gr.State([])
            paper_indices_state = gr.State([])
            # Run the search, then push the resulting labels into the
            # dropdown and clear any previous selection.
            search_btn.click(
                fn=get_relevant_papers,
                inputs=query_input,
                outputs=[paper_choices_state, paper_indices_state, search_status]
            ).then(
                fn=lambda choices: gr.update(choices=choices, value=None),
                inputs=paper_choices_state,
                outputs=paper_dropdown
            )
        # Main chat area
        with gr.Column(scale=3):
            gr.Markdown("### Chat with Selected Paper")
            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
            chat_btn = gr.Button("Send")
            # State to store conversation history and selected index
            history_state = gr.State([])
            selected_index_state = gr.State(None)

            # Update selected paper and index
            def update_selected_paper(choice, indices):
                """Map a dropdown label back to its dataset row position.

                The label starts with the 1-based rank ("1. <title> ..."),
                which indexes into *indices*. Returns ``(label, row_index)``,
                or ``("", None)`` when nothing is selected or the label
                cannot be resolved.
                """
                if not choice:
                    return "", None
                try:
                    rank = int(choice.split(".")[0])  # Extract rank (e.g., "1." -> 1)
                    # Reject ranks outside the list; 0 or negative values
                    # would otherwise index from the end.
                    if not 1 <= rank <= len(indices):
                        raise IndexError(f"rank {rank} outside 1..{len(indices)}")
                    selected_idx = indices[rank - 1]
                except Exception as e:
                    logger.error(f"Error updating selected paper: {e}")
                    return "", None
                return choice, selected_idx

            paper_dropdown.change(
                fn=update_selected_paper,
                inputs=[paper_dropdown, paper_indices_state],
                outputs=[selected_paper, selected_index_state]
            ).then(
                # Reset the visible chat AND the stored history, so the
                # previous paper's conversation is not sent as context for
                # the newly selected paper.
                fn=lambda: ([], []),
                inputs=None,
                outputs=[chatbot, history_state]
            )
            # Handle chat: answer the question, then clear the input box.
            chat_btn.click(
                fn=answer_question,
                inputs=[selected_index_state, question_input, history_state],
                outputs=[chatbot, history_state]
            ).then(
                fn=lambda: "",
                inputs=None,
                outputs=question_input
            )
# Launch the app
demo.launch(server_name="0.0.0.0", server_port=7860)