import json
import os
import re
from typing import Any, Dict, Iterator, List, Optional

import gradio as gr
from openai import OpenAI
from openai.types.chat import ChatCompletionChunk


def load_abstracts_content() -> str:
    """Return the text of ``abstracts.md``, or a placeholder if it is missing.

    Called once at import time so the file is not re-read on every request.
    """
    try:
        with open("abstracts.md", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "Abstracts database not found."


# Abstracts text, loaded once at startup.
ABSTRACTS_CONTENT = load_abstracts_content()


def load_paper_texts() -> Dict[str, str]:
    """Read every ``.txt`` file directly inside the ``Papers`` directory.

    Returns:
        A dict mapping each filename (including the ``.txt`` extension) to the
        file's text. A file that cannot be read maps to an
        ``"Error loading paper: ..."`` message instead of raising, so one bad
        file does not prevent startup. Returns an empty dict when ``Papers``
        is missing or is not a directory.
    """
    papers_dir = "Papers"
    if not os.path.isdir(papers_dir):
        return {}

    papers: Dict[str, str] = {}
    # os.listdir() returns entries in arbitrary order; sorting makes every
    # "first matching paper" lookup over this dict deterministic.
    for filename in sorted(os.listdir(papers_dir)):
        if not filename.endswith(".txt"):
            continue
        filepath = os.path.join(papers_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                papers[filename] = f.read()
        except (OSError, UnicodeDecodeError) as e:
            papers[filename] = f"Error loading paper: {e}"
    return papers


# Full paper texts keyed by filename, loaded once at startup.
PAPER_TEXTS = load_paper_texts()


def normalize_filename(filename: str) -> str:
    """Normalize a filename (or free text) for fuzzy comparison.

    Drops a trailing ``.txt``, lowercases, removes every character that is
    neither a word character nor whitespace, and collapses whitespace runs
    to single spaces.
    """
    if filename.endswith(".txt"):
        filename = filename[:-4]
    filename = re.sub(r"[^\w\s]", "", filename.lower())
    return " ".join(filename.split())


def find_matching_paper_file(query_terms, papers_dict):
    """Return the filename in ``papers_dict`` that best matches the query.

    Scoring per candidate filename (both sides normalized with
    ``normalize_filename``):

    * +10 if either normalized string contains the other,
    * +2 for every whole word the two share,
    * +1 for every (query word, filename word) pair where one contains
      the other.

    Args:
        query_terms: Iterable of query strings; they are joined with spaces.
        papers_dict: Mapping whose keys are candidate filenames.

    Returns:
        The highest-scoring filename (the earliest one wins ties), or
        ``None`` when nothing scores above zero or the query is empty
        after normalization.
    """
    query_normalized = normalize_filename(" ".join(query_terms))
    # An empty string is a substring of everything, so without this guard an
    # empty query would "match" the first file with a score of 10.
    if not query_normalized:
        return None
    query_words = set(query_normalized.split())

    best_match = None
    best_score = 0
    for filename in papers_dict:
        filename_normalized = normalize_filename(filename)
        # Same reasoning as above: a name that normalizes to "" (e.g. only
        # punctuation) must not be treated as contained in every query.
        if not filename_normalized:
            continue
        filename_words = set(filename_normalized.split())

        score = 0
        if (query_normalized in filename_normalized
                or filename_normalized in query_normalized):
            score += 10
        score += 2 * len(query_words & filename_words)
        score += sum(
            1
            for query_word in query_words
            for filename_word in filename_words
            if query_word in filename_word or filename_word in query_word
        )

        if score > best_score:
            best_score = score
            best_match = filename

    return best_match if best_score > 0 else None


def get_relevant_papers_content(query, max_papers=5):
    """Rank the loaded papers against a free-text query.

    Each whitespace-separated, lowercased query term adds 2 to a paper's
    score when it occurs in the paper's title (filename without ``.txt``)
    and 1 when it occurs in the paper's text.

    Args:
        query: The user's query string.
        max_papers: Maximum number of results to return.

    Returns:
        Up to ``max_papers`` tuples ``(filename, content, score)`` with a
        positive score, highest score first.
    """
    query_terms = query.lower().split()
    if not query_terms:
        return []

    relevant_papers = []
    for filename, content in PAPER_TEXTS.items():
        title = filename[:-4] if filename.endswith(".txt") else filename
        title_lower = title.lower()
        # Lowercase the (potentially large) paper text once, not per term.
        content_lower = content.lower()

        score = 0
        for term in query_terms:
            if term in title_lower:
                score += 2
            if term in content_lower:
                score += 1

        if score > 0:
            relevant_papers.append((filename, content, score))

    relevant_papers.sort(key=lambda item: item[2], reverse=True)
    return relevant_papers[:max_papers]


def get_full_paper_content(title, max_chars=12000):
    """Return the text of the first paper whose filename matches ``title``.

    A paper matches when the lowercased title is a substring of the
    lowercased filename or vice versa.

    Args:
        title: Paper title (or filename) to look up.
        max_chars: Texts longer than this are cut to ``max_chars`` characters
            and suffixed with ``"..."``.

    Returns:
        The (possibly truncated) paper text, or ``"Paper not found."`` when
        no filename matches or the title is blank.
    """
    needle = title.lower()
    # A blank title is a substring of every filename; treat it as "no match"
    # instead of returning an arbitrary paper.
    if not needle.strip():
        return "Paper not found."

    for filename, content in PAPER_TEXTS.items():
        filename_lower = filename.lower()
        if needle in filename_lower or filename_lower in needle:
            if len(content) > max_chars:
                return content[:max_chars] + "..."
            return content
    return "Paper not found."
def get_paper_summary(title):
    """Build a Markdown summary of a paper, grouped by detected section.

    The paper text comes from ``get_full_paper_content(title)``. Section
    detection is a plain substring heuristic: any line containing one of the
    header keywords switches the current section, and every non-blank line
    (including the header line itself) is added to the current section.

    Args:
        title: Paper title used both for the lookup and the top heading.

    Returns:
        A Markdown string with a ``# title`` heading followed by one
        ``## Section`` block per non-empty section, or ``"Paper not found."``
        when the lookup fails.
    """
    content = get_full_paper_content(title)
    if content == "Paper not found.":
        return content

    # Insertion order here is the order sections appear in the summary.
    sections = {
        "abstract": [],
        "introduction": [],
        "methodology": [],
        "results": [],
        "conclusions": [],
    }
    header_keywords = (
        "abstract", "introduction", "method", "methodology", "results",
        "conclusion",
    )

    current_section = None
    for line in content.split("\n"):
        line_lower = line.lower().strip()

        if any(keyword in line_lower for keyword in header_keywords):
            if "abstract" in line_lower:
                current_section = "abstract"
            elif "introduction" in line_lower:
                current_section = "introduction"
            elif "method" in line_lower:
                current_section = "methodology"
            elif "result" in line_lower:
                current_section = "results"
            elif "conclusion" in line_lower:
                current_section = "conclusions"

        if current_section and line.strip():
            sections[current_section].append(line)

    parts = [f"# {title}\n\n"]
    for section, section_lines in sections.items():
        section_text = "\n".join(section_lines).strip()
        if section_text:
            parts.append(f"## {section.title()}\n{section_text}\n\n")
    return "".join(parts)


# Create the OpenAI client only when a key is configured; respond() reports a
# friendly error when client is None.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    print("⚠️ Warning: OPENAI_API_KEY environment variable not set!")
    client = None
else:
    client = OpenAI(
        api_key=api_key,
        timeout=60.0,
        max_retries=3,
    )

# Display name (shown in the UI dropdown) -> OpenAI model id.
AVAILABLE_MODELS = {
    "GPT-4o-mini": "gpt-4o-mini",
    "GPT-4o": "gpt-4o",
    "GPT-3.5 Turbo": "gpt-3.5-turbo",
}

# Tool schema advertised to the model for fetching full paper texts.
FETCH_PAPERS_TOOL = {
    "type": "function",
    "function": {
        "name": "fetch_papers",
        "description": "Fetch full text content of research papers by their filenames. Use this when you need detailed information, full text, conclusions, methodology, or specific quotes from papers.",
        "parameters": {
            "type": "object",
            "properties": {
                "filenames": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    },
                    "description": "List of paper filenames to fetch (e.g., ['The Labor Market Effects of Generativ.txt', 'AI Companions Reduce Loneliness.txt'])"
                }
            },
            "required": ["filenames"]
        }
    }
}


def fetch_papers(filenames: List[str]) -> Dict[str, str]:
    """Fetch full paper texts by filename from the ``Papers`` directory.

    Args:
        filenames: Paper filenames; ``.txt`` is appended when missing. A
            single string is accepted and treated as a one-item list.

    Returns:
        A dict mapping each requested filename (with ``.txt``) to its text,
        to ``"Paper not found: <name>"`` or to ``"Error loading paper: ..."``.
        Returns ``{"error": "Papers directory not found"}`` when the
        directory does not exist.
    """
    papers_dir = "Papers"
    if not os.path.exists(papers_dir):
        return {"error": "Papers directory not found"}

    # Tolerate a model that sends a bare string instead of a list; iterating
    # the string would otherwise look up one "file" per character.
    if isinstance(filenames, str):
        filenames = [filenames]

    papers_root = os.path.abspath(papers_dir)
    papers: Dict[str, str] = {}
    for filename in filenames:
        if not isinstance(filename, str):
            continue
        if not filename.endswith(".txt"):
            filename += ".txt"

        # The names come from model-generated tool arguments, so refuse any
        # path ("../x.txt", "/abs/x.txt") that resolves outside Papers/.
        filepath = os.path.abspath(os.path.join(papers_root, filename))
        if not filepath.startswith(papers_root + os.sep):
            papers[filename] = f"Paper not found: {filename}"
            continue

        try:
            with open(filepath, "r", encoding="utf-8") as f:
                papers[filename] = f.read()
        except FileNotFoundError:
            papers[filename] = f"Paper not found: {filename}"
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError and embedded NUL bytes.
            papers[filename] = f"Error loading paper: {e}"

    return papers


def extract_conclusion_from_paper(content: str) -> str:
    """Extract the conclusion section from a paper's text.

    The section starts at the first line that mentions a conclusion pattern
    and looks like a heading (all upper-case, ends with ``:``, shorter than
    100 characters, or starts with ``Conclusion``). It ends before the first
    following line that starts with "Acknowledgments"/"Acknowledgements",
    "References" or ``--- Page``.

    Returns:
        The conclusion text, or the last 1000 characters of ``content`` when
        no conclusion heading is found.
    """
    conclusion_patterns = (
        "conclusion and future works",
        "conclusion and future work",
        "conclusions",
        "conclusion",
        "summary and conclusions",
        "discussion and conclusions",
    )
    # Both spellings of "acknowledg(e)ments" are matched by these prefixes.
    end_prefixes = ("acknowledgment", "acknowledgement", "references")

    lines = content.split("\n")
    conclusion_start = -1
    for i, line in enumerate(lines):
        stripped = line.strip()
        line_lower = stripped.lower()
        if not any(pattern in line_lower for pattern in conclusion_patterns):
            continue
        if (line.isupper()
                or stripped.endswith(":")
                or len(stripped) < 100
                or stripped.startswith("Conclusion")):
            conclusion_start = i
            break

    if conclusion_start == -1:
        # Fallback: the tail of the paper is the most likely place for it.
        return content[-1000:]

    conclusion_lines = []
    for line in lines[conclusion_start:]:
        stripped = line.strip()
        if (stripped.lower().startswith(end_prefixes)
                or stripped.startswith("--- Page")):
            break
        conclusion_lines.append(line)
    return "\n".join(conclusion_lines)


def truncate_conversation_history(messages: list, max_tokens: int = 8000) -> list:
    """Keep the system message plus at most the last six other messages.

    Older messages are dropped two at a time from the front. Lists of three
    or fewer messages are returned unchanged.

    Args:
        messages: Chat messages; the first one is kept unconditionally.
        max_tokens: Accepted for interface compatibility but currently
            unused: truncation is by message count, not token count.

    Returns:
        The truncated message list.
    """
    if len(messages) <= 3:
        return messages

    system_message = messages[0]
    conversation_messages = messages[1:]
    while len(conversation_messages) > 6:
        conversation_messages = conversation_messages[2:]
    return [system_message] + conversation_messages


def _find_paper_for_message(message_lower):
    """Return ``(title, content)`` of the first paper whose title contains any
    whitespace-separated term of the lowercased message, or ``None``."""
    terms = message_lower.split()
    for filename, content in PAPER_TEXTS.items():
        title = filename[:-4] if filename.endswith(".txt") else filename
        title_lower = title.lower()
        if any(term in title_lower for term in terms):
            return title, content
    return None


def _build_system_prompt(relevant_papers, specific_paper, conclusion, summary):
    """Assemble the system prompt from the abstracts and the retrieved texts."""
    relevant_text = "\n".join(
        f"Paper: {filename}\nContent: {content[:3000]}..."
        for filename, content, _score in relevant_papers
    )
    return f"""You are an AI chatbot designed to help users explore and analyze AI research papers. You have access to:
1. An abstracts database with summaries of research papers
2. Full paper texts for detailed analysis
3. A tool to fetch additional paper content when needed

ABSTRACTS DATABASE:
{ABSTRACTS_CONTENT}

RELEVANT PAPERS CONTENT:
{relevant_text}

SPECIFIC PAPER CONTENT:
{specific_paper if specific_paper else "None"}

CONCLUSION CONTENT:
{conclusion if conclusion else "None"}

PAPER SUMMARY:
{summary if summary else "None"}

INSTRUCTIONS:
- Use the abstracts for general questions and overview
- Use full paper content when users ask for specific details, conclusions, or complete papers
- Use the fetch_papers tool when you need additional paper content
- Provide accurate, detailed responses based on the actual paper content
- When referencing papers, use their actual titles from the filenames
- Prioritize full paper content over abstracts when available"""


def _run_tool_call(tool_call):
    """Execute one accumulated tool call and return its ``tool`` message.

    Every tool call gets a reply, including unknown tools and failures: the
    API rejects a follow-up request in which a tool call is left unanswered.
    """
    function_name = tool_call["function"]["name"]
    if function_name != "fetch_papers":
        content = f"Error: unknown tool '{function_name}'"
    else:
        try:
            arguments = json.loads(tool_call["function"]["arguments"])
            filenames = arguments.get("filenames", [])
            content = json.dumps(fetch_papers(filenames))
        except Exception as e:
            # Boundary with model-generated input: report the failure back to
            # the model instead of aborting the whole response.
            content = f"Error: {str(e)}"
    return {
        "role": "tool",
        "tool_call_id": tool_call["id"],
        "content": content,
    }


def _describe_api_error(error):
    """Map an exception raised during the API exchange to a user-facing message."""
    text = str(error).lower()
    if "api_key" in text:
        return "Error: Invalid or missing OpenAI API key."
    if "quota" in text:
        return "Error: API quota exceeded."
    # Match the phrase, not the bare substring "rate", which also occurs in
    # unrelated messages (e.g. ones mentioning "temperature" or "generate").
    if "rate limit" in text or "rate_limit" in text:
        return "Error: Rate limit exceeded."
    return f"Error: {str(error)}"


def respond(
    message: str,
    history: list[tuple[str, str]],
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    """Generate a streamed response using OpenAI's models with function calling.

    Args:
        message: The current user message.
        history: Earlier ``(user, assistant)`` message pairs.
        model_name: Display name looked up in ``AVAILABLE_MODELS``; unknown
            names fall back to ``gpt-4o-mini``.
        max_tokens: Completion token limit passed to the API.
        temperature: Sampling temperature passed to the API.
        top_p: Nucleus-sampling value passed to the API.

    Yields:
        The response text accumulated so far (each yield supersedes the
        previous one), or a single error/help message.
    """
    if not client:
        yield "❌ Error: OpenAI API key not configured."
        return

    if not message.strip():
        yield "Please enter a message to start the conversation."
        return

    message_lower = message.lower()

    # Papers ranked by keyword overlap with the user's message.
    relevant_papers_content = get_relevant_papers_content(message)

    specific_paper_content = ""
    conclusion_content = ""
    paper_summary_content = ""

    wants_summary = any(
        keyword in message_lower for keyword in ("summarize", "summary")
    )
    wants_paper = wants_summary or any(
        keyword in message_lower
        for keyword in (
            "full paper", "complete paper", "entire paper",
            "show me the paper", "read the paper",
        )
    )
    wants_conclusion = "conclusion" in message_lower

    if wants_paper or wants_conclusion:
        # Both requests use the same "first title containing any message
        # term" lookup, so it is done once.
        match = _find_paper_for_message(message_lower)
        if match is not None:
            title, content = match
            if wants_summary:
                paper_summary_content = get_paper_summary(title)
            elif wants_paper:
                specific_paper_content = get_full_paper_content(title)
            if wants_conclusion:
                conclusion_content = extract_conclusion_from_paper(content)

    system_prompt = _build_system_prompt(
        relevant_papers_content,
        specific_paper_content,
        conclusion_content,
        paper_summary_content,
    )
    messages = [{"role": "system", "content": system_prompt}]

    for user_msg, assistant_msg in history:
        if user_msg and user_msg.strip():
            messages.append({"role": "user", "content": user_msg.strip()})
        if assistant_msg and assistant_msg.strip():
            messages.append({"role": "assistant", "content": assistant_msg.strip()})

    messages.append({"role": "user", "content": message.strip()})
    messages = truncate_conversation_history(messages)

    try:
        model = AVAILABLE_MODELS.get(model_name, "gpt-4o-mini")

        # First pass: the model may answer directly or request papers.
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            tools=[FETCH_PAPERS_TOOL],
            tool_choice="auto",
            stream=True,
        )

        full_response = ""
        # Streamed tool calls arrive in fragments identified by ``index``;
        # only the first fragment of each call carries its id and name.
        pending_tool_calls: Dict[int, dict] = {}

        for chunk in response:
            # Some chunks carry no choices at all.
            if not chunk.choices:
                continue
            delta = getattr(chunk.choices[0], "delta", None)
            if delta is None:
                continue

            if delta.content is not None:
                full_response += delta.content
                yield full_response

            for tool_call_chunk in delta.tool_calls or []:
                entry = pending_tool_calls.setdefault(
                    tool_call_chunk.index,
                    {
                        "id": "",
                        "type": "function",
                        "function": {"name": "", "arguments": ""},
                    },
                )
                if tool_call_chunk.id:
                    entry["id"] = tool_call_chunk.id
                function = tool_call_chunk.function
                if function:
                    if function.name and not entry["function"]["name"]:
                        entry["function"]["name"] = function.name
                    if function.arguments:
                        entry["function"]["arguments"] += function.arguments

        tool_calls = [pending_tool_calls[i] for i in sorted(pending_tool_calls)]
        if not tool_calls:
            return

        # Record the assistant turn that requested the tools, then one tool
        # reply per call.
        messages.append({
            "role": "assistant",
            "content": full_response if full_response else None,
            "tool_calls": tool_calls,
        })
        for tool_call in tool_calls:
            messages.append(_run_tool_call(tool_call))

        # Second pass: let the model answer using the fetched papers.
        final_response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        )

        final_text = ""
        for chunk in final_response:
            if not chunk.choices:
                continue
            delta = getattr(chunk.choices[0], "delta", None)
            if delta is None or delta.content is None:
                continue
            final_text += delta.content
            if full_response:
                yield full_response + "\n\n" + final_text
            else:
                yield final_text

    except Exception as e:
        # Top-level boundary for the UI: surface the failure as chat text.
        yield _describe_api_error(e)


def chat_fn(message, history, model_name, max_tokens, temperature, top_p):
    """Stream chatbot history updates for one user message.

    Appends ``[message, ""]`` to ``history`` and fills in the assistant side
    as ``respond`` streams. A blank message leaves the history unchanged.

    Yields:
        The updated history list after each streamed update.
    """
    if history is None:
        history = []

    if not message.strip():
        # This is a generator, so a bare ``return history`` would discard the
        # value; yield it so the caller still receives the current history.
        yield history
        return

    history.append([message, ""])
    for response in respond(
        message, history[:-1], model_name, max_tokens, temperature, top_p
    ):
        history[-1][1] = response
        yield history


def clear_history() -> tuple:
    """Return an empty chat history and an empty textbox value."""
    return [], ""


# Create the Gradio interface
with gr.Blocks(
    title="📚 AI Research Paper Chatbot",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    """
) as demo:
    gr.Markdown(
        """
        # 📚 AI Research Paper Chatbot
        Chat with an AI assistant that can intelligently retrieve and analyze research papers.
        """
    )

    with gr.Row():
        # Left column: conversation and input controls.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                bubble_full_width=False,
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    show_label=False,
                    container=False,
                    scale=9,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("Clear", variant="secondary", scale=1)

        # Right column: generation settings and example prompts.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")

            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="GPT-4o",
                label="Model",
                info="Select the AI model to use",
            )

            max_tokens_slider = gr.Slider(
                minimum=1,
                maximum=4096,
                value=1024,
                step=1,
                label="Max Tokens",
                info="Maximum response length",
            )

            temperature_slider = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Creativity level",
            )

            top_p_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.05,
                label="Top-p",
                info="Response diversity",
            )

            gr.Markdown("### 💡 Examples")
            example_btn1 = gr.Button("What papers discuss AI's impact on employment?", size="sm")
            example_btn2 = gr.Button("Show me the full paper about AI companions", size="sm")
            example_btn3 = gr.Button("Compare findings on AI in education", size="sm")

    # Event handlers: Enter and the Send button run the same chat pipeline,
    # then clear the textbox.
    chat_inputs = [
        msg,
        chatbot,
        model_dropdown,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ]

    msg.submit(
        chat_fn,
        chat_inputs,
        [chatbot],
        show_progress=True,
    ).then(
        lambda: "",
        outputs=[msg],
    )

    submit_btn.click(
        chat_fn,
        chat_inputs,
        [chatbot],
        show_progress=True,
    ).then(
        lambda: "",
        outputs=[msg],
    )

    clear_btn.click(clear_history, outputs=[chatbot, msg])

    # Example buttons only fill the textbox; the user still has to send.
    example_btn1.click(lambda: "What papers discuss AI's impact on employment?", outputs=msg)
    example_btn2.click(lambda: "Show me the full paper about AI companions", outputs=msg)
    example_btn3.click(lambda: "Compare findings on AI in education", outputs=msg)

if __name__ == "__main__":
    if not os.getenv("OPENAI_API_KEY"):
        print("⚠️ Warning: OPENAI_API_KEY environment variable not set!")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False,
    )