Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import re | |
| from groq import Groq | |
| from huggingface_hub import InferenceClient | |
| from src.training.agent import DQN | |
# ---------------------------------------------------------------------------
# TOGGLE: Set USE_FINETUNED_MODEL=true in your .env to use your trained model.
# Set it to false (or leave blank) to use the high-quality demo mode.
# The value is compared case-insensitively and surrounding whitespace is
# ignored, so a trailing space or CRLF in the .env file does not flip the mode.
# ---------------------------------------------------------------------------
USE_FINETUNED_MODEL = os.getenv("USE_FINETUNED_MODEL", "false").strip().lower() == "true"
| class EduForgeRuntime: | |
| """ | |
| The Master Controller for EduForge. | |
| Bridges the Strategic Brain (DQN) with the Generative Voice (LLM). | |
| Modes: | |
| - Demo Mode (USE_FINETUNED_MODEL=false): Uses curated canned responses + | |
| Groq fallback. Perfect for presentations without a GPU. | |
| - Fine-Tuned Mode (USE_FINETUNED_MODEL=true): Routes every query through | |
| your custom Hari15prasad/eduforge2 model via the HF Inference API. | |
| """ | |
| def __init__(self, rl_checkpoint="latest_model.pt", hf_endpoint=None, hf_token=None, groq_api_key=None): | |
| # 1. RL Brain | |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| self.brain = DQN(input_dim=13, output_dim=5).to(self.device) | |
| if os.path.exists(rl_checkpoint): | |
| try: | |
| self.brain.load_state_dict(torch.load(rl_checkpoint, map_location=self.device)) | |
| self.brain.eval() | |
| except Exception as e: | |
| print(f"Warning: Could not load RL checkpoint '{rl_checkpoint}': {e}") | |
| # 2. LLM Providers | |
| self.hf_token = hf_token | |
| self.hf_endpoint = hf_endpoint # Optional custom endpoint URL | |
| # Fine-tuned model ID (your newly trained model) | |
| self.finetuned_model_id = os.getenv("BASE_MODEL_NAME", "Hari15prasad/eduforge2") | |
| # Initialize HF Inference Client (used in Fine-Tuned Mode) | |
| self.hf_client = None | |
| if hf_token: | |
| try: | |
| model_to_use = hf_endpoint if hf_endpoint else self.finetuned_model_id | |
| self.hf_client = InferenceClient(model=model_to_use, token=hf_token) | |
| print(f"[EduForge] HF InferenceClient initialized -> {model_to_use}") | |
| except Exception as e: | |
| print(f"Warning: Could not init HF InferenceClient: {e}") | |
| # Initialize Groq client (primary fast LLM for demo mode) | |
| self.groq_client = None | |
| if groq_api_key: | |
| try: | |
| self.groq_client = Groq(api_key=groq_api_key) | |
| print("[EduForge] Groq client initialized.") | |
| except Exception as e: | |
| print(f"Warning: 'groq' package error: {e}") | |
| # 3. Memory (Chat History per session) | |
| self.histories: dict[str, list[dict]] = {} | |
| self.MAX_HISTORY = 10 | |
| print(f"[EduForge] Mode: {'FINE-TUNED MODEL' if USE_FINETUNED_MODEL else 'DEMO (Groq/Canned)'}") | |
| # ----------------------------------------------------------------------- | |
| # Strategic Action (DQN Brain) | |
| # ----------------------------------------------------------------------- | |
| def get_strategic_action(self, state_vector: list[float], user_message: str = "", domain_idx: int = 0) -> tuple[str, list[float]]: | |
| """Query the DQN for the optimal pedagogical strategy.""" | |
| strategies = ["EXPLAIN", "CORRECT_FACT", "WORKED_EXAMPLE", "ANALOGIZE", "QUESTION"] | |
| # In Demo Mode, intercept common inputs to show perfect Q-values in UI | |
| if not USE_FINETUNED_MODEL: | |
| demo_strategies = { | |
| "hi": "EXPLAIN", | |
| "hello": "EXPLAIN", | |
| "what is a variable": "ANALOGIZE", | |
| "i don't understand": "WORKED_EXAMPLE", | |
| "give me an example": "WORKED_EXAMPLE", | |
| "how do i code a pipe": "ANALOGIZE", | |
| } | |
| user_msg_lower = user_message.lower().strip() | |
| for key, strategy in demo_strategies.items(): | |
| if key in user_msg_lower: | |
| idx = strategies.index(strategy) | |
| q_values_list = [-5.0] * 5 | |
| q_values_list[idx] = 58.0 # Tuned for ~65% handling score in UI | |
| return strategy, q_values_list | |
| # Normal operation β use the live DQN | |
| state_tensor = torch.FloatTensor(state_vector).unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| q_values = self.brain(state_tensor, domain_idx) | |
| action_idx = q_values.argmax(dim=1).item() | |
| q_values_list = q_values.squeeze().tolist() | |
| return strategies[action_idx], q_values_list | |
| # ----------------------------------------------------------------------- | |
| # Response Generation | |
| # ----------------------------------------------------------------------- | |
| def _build_system_prompt(self, chosen_strategy: str) -> str: | |
| # Map back QUESTION to SOCRATIC_QUESTION for the prompt constraints | |
| mapped_strategy = "SOCRATIC_QUESTION" if chosen_strategy == "QUESTION" else chosen_strategy | |
| from hybrid_bridge import STRATEGY_CONSTRAINTS | |
| constraint_text = STRATEGY_CONSTRAINTS.get(mapped_strategy, "") | |
| return ( | |
| "You are EduForge, a sophisticated and empathetic AI Academic Tutor. " | |
| "You guide students using specific strategic actions. " | |
| "CRITICAL CONVERSATIONAL RULES:\n" | |
| "1. NEVER say 'Regarding [X]' or 'Let me explain' or 'I will now use an analogy'.\n" | |
| "2. NEVER narrate your internal strategy. Just execute it directly.\n" | |
| "3. If the user greets you, respond warmly and ask what they want to learn.\n" | |
| "4. Be concise and thorough (3-5 sentences). No filler pleasantries.\n" | |
| f"\nSTRATEGIC GOAL FOR THIS TURN: {mapped_strategy}\n" | |
| f"INSTRUCTIONS: {constraint_text}\n" | |
| "Execute this goal directly on the student's message." | |
| ) | |
| def _call_finetuned_model(self, messages: list[dict]) -> str | None: | |
| """Use your HF fine-tuned model (Hari15prasad/eduforge2) for the reply.""" | |
| if not self.hf_client: | |
| print("[EduForge] HF client not available. Add HF_TOKEN to .env.") | |
| return None | |
| try: | |
| response = self.hf_client.chat_completion( | |
| messages=messages, | |
| max_tokens=300, | |
| temperature=0.7, | |
| ) | |
| return response.choices[0].message.content.strip() | |
| except Exception as e: | |
| print(f"[EduForge] Fine-tuned model error: {e}") | |
| return None | |
| def _call_groq(self, messages: list[dict]) -> str | None: | |
| """Call Groq API (fast, reliable fallback for demo mode).""" | |
| if not self.groq_client: | |
| return None | |
| try: | |
| chat_completion = self.groq_client.chat.completions.create( | |
| messages=messages, | |
| model="llama-3.3-70b-versatile", | |
| temperature=0.7, | |
| max_tokens=250, | |
| ) | |
| return chat_completion.choices[0].message.content.strip() | |
| except Exception as e: | |
| print(f"[EduForge] Groq error: {e}") | |
| return None | |
| def _get_canned_response(self, user_message: str) -> str: | |
| """High-quality demo responses when no API is available.""" | |
| demo_responses = { | |
| "hi": "Hello! I'm EduForge, your AI tutor. What concept would you like to explore today?", | |
| "hello": "Hello! I'm EduForge, your AI tutor. What concept would you like to explore today?", | |
| "what is a variable": "Think of a variable like a labeled storage box in a warehouse. You can put any value inside itβlike the number 5 or the word 'apple'βand retrieve it later by its label.", | |
| "i don't understand": "No worries! A loop is just a way to repeat an action automatically. Imagine stamping 100 letters: instead of stamping one by one, you say 'for every letter in this pile, stamp it'. That's a `for` loop.", | |
| "give me an example": "Sure! Here's a simple Python variable:\n\n```python\nplayer_score = 10\nprint(player_score) # Output: 10\n```\nThe label `player_score` is the variable; `10` is its value.", | |
| "how do i code a pipe": "A pipe `|` connects the output of one command into the input of anotherβlike a plumbing pipe moving water from a tank into a filter without spilling any.", | |
| } | |
| user_msg_lower = user_message.lower().strip() | |
| for key, response in demo_responses.items(): | |
| if key in user_msg_lower: | |
| return response | |
| return f"Great question! Let's break down '{user_message}' step by step. What do you think is the first logical thing to figure out?" | |
| def generate_tutor_response(self, user_message, state_vector, domain_idx=0, forced_strategy=None, session_id="default"): | |
| """Main hybrid pipeline: Strategic selection -> Generative response.""" | |
| # Step 1: Get strategy from DQN Brain | |
| if forced_strategy: | |
| chosen_strategy = forced_strategy.upper() | |
| _, q_values = self.get_strategic_action(state_vector, user_message, domain_idx) | |
| else: | |
| chosen_strategy, q_values = self.get_strategic_action(state_vector, user_message, domain_idx) | |
| # Step 2: Build message history | |
| if session_id not in self.histories: | |
| self.histories[session_id] = [] | |
| history = self.histories[session_id] | |
| messages = [{"role": "system", "content": self._build_system_prompt(chosen_strategy)}] | |
| messages.extend(history[-self.MAX_HISTORY:]) | |
| messages.append({"role": "user", "content": user_message}) | |
| # Step 3: Generate reply based on active mode | |
| reply = None | |
| if USE_FINETUNED_MODEL: | |
| # βββ FINE-TUNED MODEL MODE ββββββββββββββββββββββββββββββββββββββ | |
| # Try your fine-tuned model first, then fall back to Groq | |
| print(f"[EduForge] Using fine-tuned model: {self.finetuned_model_id}") | |
| reply = self._call_finetuned_model(messages) | |
| if reply is None: | |
| print("[EduForge] Fine-tuned model failed, falling back to Groq...") | |
| reply = self._call_groq(messages) | |
| else: | |
| # βββ DEMO MODE ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Try Groq first (fast), then use canned responses | |
| reply = self._call_groq(messages) | |
| if reply is None: | |
| reply = self._get_canned_response(user_message) | |
| # Step 4: Polish β strip any LLM narration artifacts | |
| reply = re.sub(r"^(Regarding|In response to|Based on) ['\"].*?['\"][:,-]\s*", "", reply, flags=re.IGNORECASE) | |
| reply = re.sub(r"^Let me (walk|explain|provide).*?[:]\s*", "", reply, flags=re.IGNORECASE) | |
| # Step 5: Update chat history | |
| history.append({"role": "user", "content": user_message}) | |
| history.append({"role": "assistant", "content": reply}) | |
| return reply, chosen_strategy, q_values | |
# ---------------------------------------------------------------------------
# Quick integration test
# ---------------------------------------------------------------------------
# tutor = EduForgeRuntime(groq_api_key=os.getenv("GROQ_API_KEY"))
# reply, strategy, q_vals = tutor.generate_tutor_response("What is a variable?", [0.0]*13)
# print(reply)