import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import wikipedia
import PyPDF2
import threading
import os
from datetime import datetime
# Load the model and tokenizer
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
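# Optional tweak (not required for a 0.6B model): on a GPU, loading in half
# precision roughly halves memory use, e.g.
#
#   model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)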
# System prompt for safe responses
SYSTEM_PROMPT = """You are a helpful, harmless, and honest AI assistant.
- Provide accurate and factual information
- Be respectful and avoid harmful, unethical, or offensive content
- Admit when you don't know something
- Stay on topic and provide clear, concise answers
"""
# Global variables for RAG
rag_content = ""
rag_filename = ""
# Function to generate a response from the model
def generate_response(prompt, max_new_tokens=512):
    full_prompt = SYSTEM_PROMPT + "\n\n" + prompt
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=2048).to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens so the prompt never leaks into the response
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return response
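# Usage sketch (kept in a comment so nothing runs at import time; the exact
# output wording is illustrative, not guaranteed):
#
#   answer = generate_response("What is the capital of France?")
#   # e.g. "The capital of France is Paris."
#
# Sampling is enabled (do_sample=True, temperature=0.7), so repeated calls
# return different wordings.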
# Function to generate search queries
def generate_search_queries(user_query, num_queries=5):
    prompt = f"""Generate {num_queries} different search queries to find comprehensive information about: "{user_query}"
The queries should cover different aspects and perspectives. List only the queries, one per line, without numbering.
Queries:"""
    response = generate_response(prompt, max_new_tokens=256)
    # Parse the generated queries (one per line, ignoring very short fragments)
    queries = [q.strip() for q in response.split('\n') if q.strip() and len(q.strip()) > 5]
    # If the model didn't generate enough queries, pad with simple variations,
    # skipping any that duplicate what the model already produced
    if len(queries) < num_queries:
        fallbacks = [
            user_query,
            f"{user_query} latest news",
            f"{user_query} {datetime.now().year}",
            f"recent {user_query}",
            f"{user_query} updates",
        ]
        for fb in fallbacks:
            if fb not in queries:
                queries.append(fb)
    return queries[:num_queries]
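# Illustrative (hypothetical) output for user_query="quantum computing";
# the real queries depend on what the model samples:
#
#   ['history of quantum computing',
#    'quantum computing applications',
#    'quantum computing error correction',
#    'quantum computing latest news',
#    'recent quantum computing']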
# Enhanced Wikipedia search with multiple queries
def enhanced_wiki_search(user_query, queries=None):
    search_results = []
    # Reuse precomputed queries if the caller already generated them,
    # so the model isn't asked twice
    if queries is None:
        queries = generate_search_queries(user_query, num_queries=5)
    print(f"🔍 Generated search queries: {queries}")
    for query in queries:
        try:
            # Try to get a Wikipedia summary directly
            summary = wikipedia.summary(query, sentences=3, auto_suggest=True)
            search_results.append({
                'query': query,
                'source': 'Wikipedia',
                'content': summary
            })
        except wikipedia.exceptions.DisambiguationError as e:
            # On a disambiguation page, fall back to the first option
            try:
                summary = wikipedia.summary(e.options[0], sentences=3)
                search_results.append({
                    'query': query,
                    'source': 'Wikipedia',
                    'content': summary
                })
            except Exception:
                pass
        except wikipedia.exceptions.PageError:
            # No exact page: fall back to a full-text search
            try:
                search_list = wikipedia.search(query, results=3)
                if search_list:
                    summary = wikipedia.summary(search_list[0], sentences=3)
                    search_results.append({
                        'query': query,
                        'source': 'Wikipedia',
                        'content': summary
                    })
            except Exception:
                pass
        except Exception as e:
            print(f"Error with query '{query}': {str(e)}")
            continue
    return search_results
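# Each entry in search_results is a flat dict, e.g. (illustrative):
#
#   {'query': 'apollo 11 landing', 'source': 'Wikipedia', 'content': '...'}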
# Function to aggregate and synthesize search results
def aggregate_search_results(search_results, user_query):
    if not search_results:
        return "No search results found. Please try a different query."
    # Combine all search results into a single context block
    combined_info = "\n\n".join([
        f"Source: {result['source']}\nQuery: {result['query']}\nInformation: {result['content']}"
        for result in search_results
    ])
    # Generate a comprehensive response
    prompt = f"""Based on the following search results, provide a comprehensive and well-structured answer to the user's question: "{user_query}"

Search Results:
{combined_info}

Instructions:
- Synthesize information from all sources
- Provide accurate and up-to-date information
- If sources conflict, mention the disagreement
- Structure your response clearly
- Include relevant details and context

Comprehensive Answer:"""
    response = generate_response(prompt, max_new_tokens=1024)
    return response
# Function to extract text from a PDF or TXT file.
# With gr.File(type="filepath"), Gradio passes the path as a plain string.
def extract_text(file_path):
    global rag_filename
    try:
        if file_path.endswith(".pdf"):
            rag_filename = os.path.basename(file_path)
            pdf_reader = PyPDF2.PdfReader(file_path)
            text = ""
            for page in pdf_reader.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n"
            return text if text else "Could not extract text from PDF."
        elif file_path.endswith(".txt"):
            rag_filename = os.path.basename(file_path)
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        else:
            return "Unsupported file type. Please upload PDF or TXT files."
    except Exception as e:
        return f"Error reading file: {str(e)}"
# Main chat function with history (a generator, so Gradio can stream updates;
# early exits must yield before returning, or the UI would receive nothing)
def chat(message, history, mode, file=None):
    global rag_content, rag_filename
    if not message.strip():
        yield history, ""
        return
    # Handle file upload for RAG
    if file:
        extracted = extract_text(file)
        if extracted.startswith("Error") or extracted.startswith("Unsupported") or extracted.startswith("Could not"):
            history.append((message, f"❌ {extracted}"))
            yield history, ""
            return
        rag_content = extracted
        history.append((None, f"✓ File uploaded: {rag_filename} ({len(rag_content)} characters)"))
    # Generate a response based on the selected mode
    if mode == "Web search":
        # Show a searching indicator while work is in progress
        history.append((message, "🔍 Searching and analyzing information..."))
        yield history, ""
        # Generate the search queries once and reuse them for the search
        search_queries = generate_search_queries(message, num_queries=5)
        search_results = enhanced_wiki_search(message, queries=search_queries)
        # Aggregate the results and generate the response
        if search_results:
            response = "📊 **Search Queries Generated:**\n"
            response += "\n".join([f"- {q}" for q in search_queries])
            response += f"\n\n✅ **Found {len(search_results)} relevant sources**\n\n"
            # Generate a comprehensive answer
            final_answer = aggregate_search_results(search_results, message)
            response += "📝 **Comprehensive Answer:**\n" + final_answer
        else:
            response = "❌ Could not find relevant information. Please try rephrasing your query."
        # Replace the placeholder message with the final response
        history[-1] = (message, response)
    elif mode == "Think":
        think_prompt = f"Think step by step about the following question: {message}\n\nProvide your reasoning process:"
        thoughts = generate_response(think_prompt, max_new_tokens=512)
        final_prompt = f"Based on this reasoning:\n{thoughts}\n\nNow provide a final answer to: {message}"
        final_response = generate_response(final_prompt, max_new_tokens=512)
        response = f"🤔 **Thinking Process:**\n{thoughts}\n\n💡 **Final Answer:**\n{final_response}"
        history.append((message, response))
    elif mode == "No think":
        prompt = f"Answer the following question directly and concisely:\n{message}\n\nAnswer:"
        response = generate_response(prompt, max_new_tokens=512)
        history.append((message, response))
    elif mode == "RAG":
        if not rag_content:
            history.append((message, "⚠ Please upload a PDF or TXT file first for RAG mode."))
            yield history, ""
            return
        chunk_size = 1500
        prompt = f"Document content:\n{rag_content[:chunk_size]}\n\nUser question: {message}\n\nAnswer based strictly on the document content above:"
        response = generate_response(prompt, max_new_tokens=768)
        history.append((message, response))
    else:
        response = "Invalid mode selected."
        history.append((message, response))
    yield history, ""
# Function for parallel chat: runs up to four questions on background threads
def parallel_chat(q1, q2, q3, q4, mode, file=None):
    global rag_content
    # Handle the file upload for RAG before the threads start
    if file and mode == "RAG":
        extracted = extract_text(file)
        if not (extracted.startswith("Error") or extracted.startswith("Unsupported")):
            rag_content = extracted
    responses = [None, None, None, None]
    questions = [q1, q2, q3, q4]
    def process(i):
        if questions[i] and questions[i].strip():
            temp_history = []
            result = None
            # Drain the chat generator and keep its final state
            for result, _ in chat(questions[i], temp_history, mode):
                pass
            if result:
                responses[i] = result[-1][1]
    threads = []
    for i in range(4):
        if questions[i] and questions[i].strip():
            t = threading.Thread(target=process, args=(i,))
            t.start()
            threads.append(t)
    for t in threads:
        t.join()
    return (responses[0] or "No question provided",
            responses[1] or "No question provided",
            responses[2] or "No question provided",
            responses[3] or "No question provided")
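# Note: all four threads call model.generate on the same model instance. That
# is usually tolerated for inference, but if it misbehaves on your hardware,
# one option is to serialize generation with a module-level lock (a sketch,
# not wired in):
#
#   _generate_lock = threading.Lock()
#   # inside generate_response:
#   with _generate_lock:
#       outputs = model.generate(**inputs, ...)
#
# This trades parallel throughput for predictable single-request behavior.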
# Custom CSS for better UI
custom_css = """
#chatbot {
    height: 600px;
    overflow-y: auto;
}
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 8px;
}
"""
# Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 AI Chatbot with Advanced Web Search")
    gr.Markdown("Choose your preferred mode and start chatting! Web search now generates multiple queries for comprehensive results.")
    with gr.Tab("💬 Chat"):
        with gr.Row():
            mode = gr.Dropdown(
                choices=["No think", "Think", "Web search", "RAG"],
                label="Select Mode",
                value="No think",
                info="Web search uses 5 different queries for comprehensive results"
            )
            file = gr.File(
                label="Upload File (PDF/TXT)",
                file_types=[".pdf", ".txt"],
                type="filepath"
            )
        chatbot = gr.Chatbot(
            label="Conversation",
            elem_id="chatbot",
            height=500,
            show_label=True,
            bubble_full_width=False
        )
        with gr.Row():
            input_text = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2,
                scale=4
            )
            send_btn = gr.Button("Send 📤", scale=1, variant="primary")
        clear_btn = gr.Button("Clear History 🗑")
        # Chat functionality
        send_btn.click(
            chat,
            inputs=[input_text, chatbot, mode, file],
            outputs=[chatbot, input_text]
        )
        input_text.submit(
            chat,
            inputs=[input_text, chatbot, mode, file],
            outputs=[chatbot, input_text]
        )
        clear_btn.click(lambda: [], None, chatbot)
| with gr.Tab("⚡ Parallel Chat"): | |
| gr.Markdown("### Ask up to 4 questions simultaneously!") | |
| mode_parallel = gr.Dropdown( | |
| choices=["No think", "Think", "Web search", "RAG"], | |
| label="Select Mode", | |
| value="No think" | |
| ) | |
| file_parallel = gr.File( | |
| label="Upload File for RAG (PDF/TXT)", | |
| file_types=[".pdf", ".txt"], | |
| type="filepath" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| q1 = gr.Textbox(label="Question 1", lines=2) | |
| q2 = gr.Textbox(label="Question 2", lines=2) | |
| with gr.Column(): | |
| q3 = gr.Textbox(label="Question 3", lines=2) | |
| q4 = gr.Textbox(label="Question 4", lines=2) | |
| btn_parallel = gr.Button("Submit All Questions 🚀", variant="primary") | |
| with gr.Row(): | |
| with gr.Column(): | |
| r1 = gr.Textbox(label="Response 1", lines=8, max_lines=20) | |
| r2 = gr.Textbox(label="Response 2", lines=8, max_lines=20) | |
| with gr.Column(): | |
| r3 = gr.Textbox(label="Response 3", lines=8, max_lines=20) | |
| r4 = gr.Textbox(label="Response 4", lines=8, max_lines=20) | |
| btn_parallel.click( | |
| parallel_chat, | |
| inputs=[q1, q2, q3, q4, mode_parallel, file_parallel], | |
| outputs=[r1, r2, r3, r4] | |
| ) | |
| with gr.Tab("ℹ About"): | |
| gr.Markdown(""" | |
| ## Features: | |
| - *No think*: Direct, concise answers | |
| - *Think*: Step-by-step reasoning process | |
| - *Web search: 🔥 **NEW!* Generates 5 different search queries and aggregates results for comprehensive answers | |
| - *RAG*: Answer questions based on uploaded documents | |
| ## Enhanced Web Search: | |
| - Automatically generates 5 diverse search queries | |
| - Searches multiple sources simultaneously | |
| - Aggregates and synthesizes information | |
| - Provides comprehensive, up-to-date answers | |
| ## Tips: | |
| - Upload PDF or TXT files for RAG mode | |
| - Use parallel chat for comparing different questions | |
| - Clear history to start fresh conversations | |
| - Web search is best for current events and factual queries | |
| ## Safety: | |
| This chatbot includes a system prompt for safe, helpful, and honest responses. | |
| """) | |
demo.launch(share=True)