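"""AI Teacher: a Gradio chat app for Hugging Face Spaces.

Streams answers from a hosted chat model through
`huggingface_hub.InferenceClient` and keeps a small in-process memory of past
query/response pairs that is reused as context via simple keyword matching.
"""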
import logging
import os
import re
from collections import deque
from typing import Iterator, List, Tuple

import gradio as gr
from huggingface_hub import InferenceClient

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Initialize the InferenceClient with an API token
try:
    client = InferenceClient(
        model="meta-llama/Llama-2-7b-chat-hf",
        token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    )
    logger.info("Successfully initialized InferenceClient")
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {e}")
    raise
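# NOTE: the token comes from the HUGGINGFACEHUB_API_TOKEN environment
# variable; on Spaces this is typically stored as a repository secret.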
# Memory storage for learning from past queries
MEMORY = deque(maxlen=100)  # Store up to 100 query-response pairs


def add_to_memory(query: str, response: str) -> None:
    """Add a query-response pair to memory."""
    MEMORY.append({"query": query, "response": response})
    logger.info("Added query-response pair to memory")


def find_relevant_context(query: str, max_contexts: int = 2) -> str:
    """Retrieve relevant past queries and responses via simple keyword matching."""
    query_words = set(re.findall(r"\w+", query.lower()))
    relevant = []
    for mem in MEMORY:
        mem_words = set(re.findall(r"\w+", mem["query"].lower()))
        overlap = len(query_words & mem_words) / max(len(query_words), 1)
        if overlap > 0.3:  # Threshold for relevance
            relevant.append(mem)
            if len(relevant) >= max_contexts:
                break
    if relevant:
        context = "\n".join(
            f"Past Query: {mem['query']}\nPast Response: {mem['response']}"
            for mem in relevant
        )
        return f"Relevant past interactions:\n{context}\n\n"
    return ""

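# Worked example of the heuristic above: for the query "what is gravity" and a
# stored query "explain gravity", the shared-word ratio is 1/3 ≈ 0.33, which
# clears the 0.3 threshold, so that past interaction is returned as context.

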
def respond(
    message: str,
    history: List[Tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    """
    Generate an educational response, using past interactions for context.

    Args:
        message (str): The student's input question or query.
        history (List[Tuple[str, str]]): Chat history as (student, AI teacher) message pairs.
        system_message (str): The system prompt defining the AI teacher's behavior.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Controls randomness in response generation.
        top_p (float): Controls diversity via nucleus sampling.

    Yields:
        str: The AI teacher's response accumulated so far, updated token by token.
    """
    # Validate input parameters
    if not message.strip():
        raise ValueError("Input message cannot be empty")
    if not 1 <= max_tokens <= 2048:
        raise ValueError("max_tokens must be between 1 and 2048")
    if not 0.1 <= temperature <= 2.0:
        raise ValueError("temperature must be between 0.1 and 2.0")
    if not 0.1 <= top_p <= 1.0:
        raise ValueError("top_p must be between 0.1 and 1.0")
    # Retrieve relevant past interactions
    context = find_relevant_context(message)
    # Construct the message history with memory context
    messages = [
        {
            "role": "system",
            "content": (
                system_message
                + "\n\nUse the following past interactions to inform your response if relevant:\n"
                + context
            ),
        }
    ]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
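    # `chat_completion` accepts OpenAI-style {"role": ..., "content": ...}
    # dicts; with stream=True it yields chunks whose `choices[0].delta.content`
    # holds the newly generated text.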
    response = ""
    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        # Iterate with a distinct name so the user's `message` is not
        # shadowed before it is saved to memory below.
        for chunk in stream:
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response
        # Store the query and final response in memory
        add_to_memory(message, response)
    except Exception as e:
        error_msg = f"Error during chat completion: {e}"
        logger.error(error_msg)
        yield error_msg  # Yield the error message to display in Gradio

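# Because `respond` is a generator, gr.ChatInterface streams each partial
# response to the UI instead of waiting for the full completion.

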
def main():
    """Set up and launch the Gradio ChatInterface for the AI Teacher chatbot."""
    default_system_message = (
        "You are an AI Teacher, a knowledgeable and patient educator dedicated to helping students and learners. "
        "Your goal is to explain concepts clearly, provide step-by-step guidance, and encourage critical thinking. "
        "Adapt your explanations to the learner's level, ask follow-up questions to deepen understanding, and provide examples where helpful. "
        "Be supportive, professional, and engaging in all interactions."
    )
    demo = gr.ChatInterface(
        fn=respond,
        # These inputs are passed to `respond` positionally, in order:
        # system_message, max_tokens, temperature, top_p.
        additional_inputs=[
            gr.Textbox(
                value=default_system_message,
                label="AI Teacher Prompt",
                lines=3,
                placeholder="Customize the AI Teacher's teaching style or instructions",
            ),
            gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Maximum Response Length",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Response Creativity",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Response Diversity",
            ),
        ],
        title="AI Teacher: Your Study Companion",
        description=(
            "Welcome to AI Teacher, your personal guide for learning and studying! "
            "Ask questions about any subject, and I'll provide clear explanations, examples, and tips to help you succeed. "
            "Adjust the settings to customize how I respond to your questions."
        ),
        theme="soft",
        css="""
        .gradio-container { max-width: 900px; margin: auto; padding: 20px; }
        .chatbot { border-radius: 12px; background-color: #f9fafb; }
        h1 { color: #2b6cb0; }
        .message { font-size: 16px; }
        """,
    )
    try:
        logger.info("Launching Gradio interface for AI Teacher")
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {e}")
        raise


if __name__ == "__main__":
    main()