Spaces:
Runtime error
Runtime error
| # β Ultra-lightweight Gradio chatbot with PPO reward model for Free Colab | |
| # Optimized to prevent RAM crashes with minimal memory footprint | |
| # FIXED: Improved PPO trigger logic to handle continuous training | |
| # NEW: Added voice input capability with speech-to-text | |
| ## FIXED: Removed problematic 'every' parameter from Gradio events | |
| import json | |
| import torch | |
| import faiss | |
| import numpy as np | |
| import re | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from sentence_transformers import SentenceTransformer | |
| import pandas as pd | |
| import os | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, Dataset | |
| from datetime import datetime | |
| import gc | |
| import math | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import LogisticRegression | |
| import pickle | |
| # NEW: Import for speech-to-text functionality (with error handling) | |
| try: | |
| import speech_recognition as sr | |
| SPEECH_AVAILABLE = True | |
| print("β Speech recognition available") | |
| except ImportError: | |
| print("β οΈ Speech recognition not available. Install with: pip install SpeechRecognition") | |
| SPEECH_AVAILABLE = False | |
| try: | |
| import librosa | |
| LIBROSA_AVAILABLE = True | |
| except ImportError: | |
| print("β οΈ Librosa not available. Advanced audio processing disabled.") | |
| LIBROSA_AVAILABLE = False | |
# --- Model and tokenizer -----------------------------------------------------
# Hugging Face identifiers for the fine-tuned chat model and its data files.
MODEL_PATH = "Shilpagotur/UNICEF_chatbot"
DATA_PATH = "Shilpagotur/training_dataset_modified.json"
RLHF_SAVE_PATH = "Shilpagotur/rlhf_models"  # Path to save RLHF models and data

# Prefer GPU + fp16 when available; fall back to CPU + fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=dtype).to(device)
# Causal LMs often ship without a pad token; reuse EOS so padding works.
tokenizer.pad_token = tokenizer.eos_token

# --- RAG knowledge base ------------------------------------------------------
# NOTE(review): DATA_PATH looks like a Hub repo path, not a local file — this
# open() only succeeds if the JSON exists at that relative path; confirm.
with open(DATA_PATH, "r") as f:
    knowledge_base = json.load(f)
retrieval_texts = [item["answer"] for item in knowledge_base]
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
retrieval_embeddings = embedding_model.encode(retrieval_texts, convert_to_numpy=True)

# Exact (brute-force) L2 nearest-neighbour index over the answer embeddings.
dimension = retrieval_embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(retrieval_embeddings)

# Module-level recognizer; note speech_to_text() creates its own per call.
if SPEECH_AVAILABLE:
    recognizer = sr.Recognizer()
# Speech-to-text helper with layered error handling.
def speech_to_text(audio_file):
    """Convert speech audio to text using speech_recognition library"""
    # Guard clauses: feature unavailable, or nothing was recorded.
    if not SPEECH_AVAILABLE:
        return "Speech recognition not available. Please install SpeechRecognition library."
    if audio_file is None:
        return ""
    try:
        print("π€ Processing speech input...")
        local_recognizer = sr.Recognizer()
        # Read the recording, compensating for background noise first.
        with sr.AudioFile(audio_file) as source:
            local_recognizer.adjust_for_ambient_noise(source, duration=0.5)
            captured_audio = local_recognizer.record(source)
        # Transcribe via the free Google Web Speech API.
        try:
            transcript = local_recognizer.recognize_google(captured_audio)
        except sr.UnknownValueError:
            print("β Could not understand the audio")
            return "Sorry, I couldn't understand what you said. Please try again or use text input."
        except sr.RequestError as e:
            print(f"β Error with speech recognition service: {e}")
            return "Speech recognition service is unavailable. Please use text input."
        print(f"ποΈ Recognized speech: '{transcript}'")
        return transcript
    except Exception as e:
        # Any other failure (bad file, codec issue, ...) degrades gracefully.
        print(f"β Error processing audio: {e}")
        return "Error processing audio. Please try again or use text input."
# Merge voice and text inputs into a single query string.
def process_voice_input(audio_file, text_input):
    """Process voice input and convert to text, fallback to text input if needed"""
    if audio_file is not None and SPEECH_AVAILABLE:
        transcript = speech_to_text(audio_file)
        # speech_to_text signals failure via sentinel message prefixes.
        failure_prefixes = ("Sorry", "Speech", "Error")
        if transcript and not transcript.startswith(failure_prefixes):
            return transcript
    # Voice absent or failed: fall back to the typed text (empty if none).
    return text_input if text_input else ""
# RAG retrieval helper.
def retrieve_context(query, top_k=3):
    """Return the top_k most similar knowledge-base answers for the query."""
    encoded_query = embedding_model.encode([query], convert_to_numpy=True)
    _distances, neighbor_ids = faiss_index.search(encoded_query, top_k)
    return [retrieval_texts[idx] for idx in neighbor_ids[0]]
def get_age_context(age):
    """Map an age to a system-prompt sentence describing the response style.

    Accepts an int or a numeric string; anything non-numeric falls back to
    age 12 (the pre-teen band), matching get_few_shot_examples().
    """
    # FIX: coerce via str() so int input doesn't raise AttributeError on
    # .isdigit() — consistent with get_few_shot_examples().
    age = int(age) if str(age).isdigit() else 12
    if age <= 7:
        return "You are a young child, so I will explain things simply, clearly, and empathetically."
    elif age <= 12:
        return "You are an older child, so I will explain with more detail while keeping it safe, empathetic, and understandable."
    elif age <= 17:
        return "You are a teenager, so I will provide more in-depth, more technical advice with caution, but in an empathetic way"
    else:
        return "The user is an adult. Use professional language and detailed safety advice."
def get_few_shot_examples(age):
    """Return an age-banded few-shot example block to prepend to the prompt.

    Non-numeric ages default to 12 (the middle band).
    """
    age = int(age) if str(age).isdigit() else 12
    if age <= 7:
        # Young children: one short, very simple example.
        return """
Example 1:
User: They said I can download points for free if I install a background app is it legit.
Assistant: That could be a trick. Don't install anything someone tells you without asking a grown-up. Talk to parent or guardian
"""
    elif age <= 12:
        return """
Example 1:
User: They said I can download points for free if I install a background app is it legit.
Assistant: Be careful! Sometimes people say things like free points to trick kids into installing bad apps. These apps can steal your stuff or break your device. Always ask your parents before downloading anything. If it seems too good to be true, it usually is.
"""
    else:
        # NOTE(review): this example's user question (revenge after being
        # blocked) does not match its answer (app-install scam) — looks like
        # a copy/paste mismatch worth confirming with the dataset author.
        return """
Example 1:
User: Someone blocked me online. Should I get revenge?
Assistant: Offers like free points if you install an app are often scams or phishing attempts. The background app could secretly collect your data, show you unsafe ads, or even install malware without you knowing. Legit apps don't offer rewards in exchange for random downloads. Always check reviews, publisher info, and permissions. It's best to ignore such offers and report them if possible. Stay safe and don't trade your privacy for fake rewards.
"""
def is_followup_question(question):
    """Heuristically detect whether the question continues an earlier exchange."""
    normalized = question.lower().strip()
    # Any one of these signals (continuation words, pronoun back-references,
    # bare acknowledgements, ...) marks the question as a follow-up.
    followup_patterns = (
        r'\b(what about|how about|what if|but what|also|and|additionally)\b',
        r'\b(can you explain|tell me more|elaborate|expand)\b',
        r'\b(why|how|when|where|who)\b.*\?',
        r'\b(that|this|it|they|them)\b',
        r'\b(previous|before|earlier|above)\b',
        r'^(yes|no|okay|ok|sure|right|exactly|correct|wrong)',
    )
    for pattern in followup_patterns:
        if re.search(pattern, normalized):
            return True
    return False
# Ultra-lightweight reward model built on scikit-learn (no PyTorch).
class SklearnRewardModel:
    """Ultra-lightweight reward model using scikit-learn instead of PyTorch"""

    def __init__(self):
        # TF-IDF features: far cheaper than transformer embeddings.
        self.vectorizer = TfidfVectorizer(
            max_features=500,  # very small feature set to cap memory
            stop_words='english',
            ngram_range=(1, 2),  # unigrams and bigrams
            max_df=0.95,
            min_df=2
        )
        # Binary like/dislike classifier; predict_proba gives the reward.
        self.classifier = LogisticRegression(
            random_state=42,
            max_iter=100,
            C=1.0
        )
        self.is_trained = False
        # Training accuracy per fit, newest last.
        self.training_scores = []

    def preprocess_text(self, query, response):
        """Create features from query-response pair"""
        # Truncate both sides to 200 chars to keep the feature text bounded.
        combined_text = f"Q: {query[:200]} A: {response[:200]}"
        return combined_text

    def train(self, training_data):
        """Train the reward model on feedback data.

        training_data: list of dicts with 'query', 'response', 'feedback'.
        Returns a dict with 'success' plus either accuracy info or a reason.
        """
        if len(training_data) < 3:
            return {'success': False, 'reason': 'Need at least 3 samples'}
        try:
            print(f"π Training lightweight reward model with {len(training_data)} samples...")
            texts = []
            labels = []
            for item in training_data:
                text = self.preprocess_text(item['query'], item['response'])
                # Like -> 1, anything else (dislike) -> 0.
                label = 1 if item['feedback'] == "π Like" else 0
                texts.append(text)
                labels.append(label)
            # Logistic regression needs both classes present to fit.
            if len(set(labels)) < 2:
                return {'success': False, 'reason': 'Need both positive and negative examples'}
            # Re-fit the vectorizer on every training call (vocab may shift).
            X = self.vectorizer.fit_transform(texts)
            self.classifier.fit(X, labels)
            self.is_trained = True
            # Training-set accuracy only — no held-out split at this scale.
            train_score = self.classifier.score(X, labels)
            self.training_scores.append(train_score)
            print(f"β Training completed! Accuracy: {train_score:.3f}")
            return {
                'success': True,
                'accuracy': train_score,
                'samples_used': len(training_data)
            }
        except Exception as e:
            print(f"β Training error: {e}")
            return {'success': False, 'error': str(e)}

    def predict_reward(self, query, response):
        """Predict reward (P(like) in [0, 1]) for a query-response pair."""
        if not self.is_trained:
            return 0.5  # neutral score if not trained
        try:
            text = self.preprocess_text(query, response)
            X = self.vectorizer.transform([text])
            # Probability of the positive (like) class.
            prob = self.classifier.predict_proba(X)[0][1]
            return float(prob)
        except Exception as e:
            print(f"Prediction error: {e}")
            return 0.5

    def save_model(self, filepath):
        """Pickle vectorizer + classifier + state to filepath; return success."""
        try:
            model_data = {
                'vectorizer': self.vectorizer,
                'classifier': self.classifier,
                'is_trained': self.is_trained,
                'training_scores': self.training_scores
            }
            with open(filepath, 'wb') as f:
                pickle.dump(model_data, f)
            return True
        except Exception as e:
            print(f"Save error: {e}")
            return False

    def load_model(self, filepath):
        """Load a pickled model if filepath exists; return success.

        NOTE(review): pickle.load on a shared/checked-out file executes
        arbitrary code if the file is untrusted — acceptable only because the
        path is app-controlled.
        """
        try:
            if os.path.exists(filepath):
                with open(filepath, 'rb') as f:
                    model_data = pickle.load(f)
                self.vectorizer = model_data['vectorizer']
                self.classifier = model_data['classifier']
                self.is_trained = model_data['is_trained']
                self.training_scores = model_data.get('training_scores', [])
                print(f"β Reward model loaded! Trained: {self.is_trained}")
                return True
        except Exception as e:
            print(f"Load error: {e}")
        return False
# PPO-style trainer with improved (timestamp-based) trigger logic.
class LightweightRLHFTrainer:
    """Ultra-lightweight RLHF trainer for free Colab with FIXED trigger logic"""

    def __init__(self, model_path, tokenizer, device='cuda'):
        # NOTE(review): `model_path` is accepted but unused — the trainer binds
        # directly to the module-level `model` loaded at import time.
        self.device = device
        self.tokenizer = tokenizer
        # Reference (not a copy) to the already-loaded main model.
        self.main_model = model
        # Lightweight sklearn reward model (like/dislike classifier).
        self.reward_model = SklearnRewardModel()
        # Optimizer over the full model; used only for rare, tiny updates.
        self.main_optimizer = optim.AdamW(
            self.main_model.parameters(),
            lr=5e-6,  # very small learning rate keeps updates conservative
            weight_decay=0.01
        )
        # Feedback / training bookkeeping.
        self.feedback_buffer = []        # feedback dicts; see add_feedback()
        self.training_history = []       # one entry per save_progress() call
        self.ppo_training_triggered = False  # re-entrancy guard for training
        # Timestamp of the last training run, so only *new* feedback counts
        # toward the retraining threshold. ISO strings compare correctly.
        self.last_training_timestamp = "1970-01-01T00:00:00"  # epoch sentinel
        self.total_training_cycles = 0
        self.new_feedback_threshold = 5  # new feedbacks needed to retrain
        # Restore any previously saved reward model and training state.
        reward_model_path = os.path.join(RLHF_SAVE_PATH, "lightweight_reward_model.pkl")
        self.reward_model.load_model(reward_model_path)
        self.load_training_state()

    def load_training_state(self):
        """Load previous training state (timestamp + cycle count) from disk."""
        try:
            state_path = os.path.join(RLHF_SAVE_PATH, "training_state.json")
            if os.path.exists(state_path):
                with open(state_path, 'r') as f:
                    state = json.load(f)
                self.last_training_timestamp = state.get('last_training_timestamp', "1970-01-01T00:00:00")
                self.total_training_cycles = state.get('total_training_cycles', 0)
                print(f"β Loaded training state: {self.total_training_cycles} cycles, last training: {self.last_training_timestamp}")
        except Exception as e:
            print(f"β οΈ Could not load training state: {e}")

    def save_training_state(self):
        """Persist current training state to RLHF_SAVE_PATH as JSON."""
        try:
            os.makedirs(RLHF_SAVE_PATH, exist_ok=True)
            state_path = os.path.join(RLHF_SAVE_PATH, "training_state.json")
            state = {
                'last_training_timestamp': self.last_training_timestamp,
                'total_training_cycles': self.total_training_cycles,
                'saved_at': datetime.now().isoformat()
            }
            with open(state_path, 'w') as f:
                json.dump(state, f, indent=2)
        except Exception as e:
            print(f"β οΈ Could not save training state: {e}")

    def count_new_feedback(self):
        """Count feedback entries received after the last training run."""
        # ISO-8601 timestamps sort lexicographically, so > compares times.
        new_count = 0
        for feedback in self.feedback_buffer:
            feedback_time = feedback.get('timestamp', "1970-01-01T00:00:00")
            if feedback_time > self.last_training_timestamp:
                new_count += 1
        return new_count

    def add_feedback(self, query, response, feedback, age=None):
        """Append one feedback entry; may synchronously run a training cycle."""
        feedback_item = {
            'query': query,
            'response': response,
            'feedback': feedback,
            'age': age,
            'timestamp': datetime.now().isoformat()
        }
        self.feedback_buffer.append(feedback_item)
        buffer_size = len(self.feedback_buffer)
        new_feedback_count = self.count_new_feedback()
        # Trigger conditions:
        # 1. at least 10 total feedbacks, AND
        # 2. at least `new_feedback_threshold` new ones since last training, AND
        # 3. no training cycle currently in flight.
        should_trigger = (
            buffer_size >= 10 and
            new_feedback_count >= self.new_feedback_threshold and
            not self.ppo_training_triggered
        )
        print(f"π Feedback stats: Total={buffer_size}, New={new_feedback_count}, Should trigger={should_trigger}")
        if should_trigger:
            try:
                print(f"\nπ Triggering PPO training cycle #{self.total_training_cycles + 1}...")
                print(f" - Total feedback: {buffer_size}")
                print(f" - New feedback since last training: {new_feedback_count}")
                self.ppo_training_triggered = True
                training_result = self.train_reward_model()
                if training_result.get('success', False):
                    print("β Reward model training successful!")
                    # Conservative, tiny update to the generation model.
                    self.minimal_model_update()
                    # Advance the training watermark so old feedback no
                    # longer counts as "new".
                    self.last_training_timestamp = datetime.now().isoformat()
                    self.total_training_cycles += 1
                    self.save_progress()
                    self.save_training_state()
                else:
                    print("β οΈ Reward model training failed")
                # Release the guard regardless of success.
                self.ppo_training_triggered = False
                print(f"--- Training cycle #{self.total_training_cycles} complete ---\n")
            except Exception as e:
                print(f"β Training error: {e}")
                self.ppo_training_triggered = False

    def train_reward_model(self):
        """Fit the lightweight reward model on the full feedback buffer."""
        return self.reward_model.train(self.feedback_buffer)

    def minimal_model_update(self):
        """Very small supervised update on the main model to avoid OOM.

        Fine-tunes on at most the 2 highest-reward liked examples, with a
        reward-scaled loss and aggressive gradient clipping.
        """
        try:
            print("π Performing minimal model update...")
            # Only consider liked responses.
            positive_samples = [f for f in self.feedback_buffer if f['feedback'] == "π Like"]
            if not positive_samples:
                return
            # Keep only examples the reward model is very confident about.
            scored_samples = []
            for sample in positive_samples:
                reward = self.reward_model.predict_reward(sample['query'], sample['response'])
                if reward > 0.8:  # only very high-confidence examples
                    scored_samples.append((sample, reward))
            # Top 2 by reward to bound memory/compute.
            scored_samples.sort(key=lambda x: x[1], reverse=True)
            top_samples = scored_samples[:2]
            if not top_samples:
                print("No high-quality samples found for update")
                return
            self.main_model.train()
            for sample, reward in top_samples:
                try:
                    # Truncated prompt keeps the sequence short.
                    prompt = f"User: {sample['query'][:100]}\nAssistant: {sample['response'][:100]}"
                    inputs = self.tokenizer(
                        prompt,
                        return_tensors="pt",
                        max_length=128,
                        truncation=True,
                        padding=True
                    ).to(self.device)
                    # Standard causal-LM loss with labels == inputs.
                    outputs = self.main_model(**inputs, labels=inputs["input_ids"])
                    loss = outputs.loss
                    if loss is not None and not torch.isnan(loss):
                        # Scale by reward and a 0.1 factor for a tiny step.
                        scaled_loss = loss * reward * 0.1
                        self.main_optimizer.zero_grad()
                        scaled_loss.backward()
                        # Aggressive clipping bounds the update further.
                        torch.nn.utils.clip_grad_norm_(self.main_model.parameters(), max_norm=0.1)
                        self.main_optimizer.step()
                    # Free tensors immediately to keep peak memory low.
                    del inputs, outputs, loss
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except Exception as e:
                    print(f"Error in sample update: {e}")
                    continue
            print(f"β Updated model with {len(top_samples)} high-quality samples")
        except Exception as e:
            print(f"β Model update error: {e}")
        finally:
            # Cleanup runs on both success and failure paths.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            # Bound the buffer; keep only the most recent 50 entries.
            if len(self.feedback_buffer) > 50:
                self.feedback_buffer = self.feedback_buffer[-50:]

    def predict_reward(self, query, response):
        """Predict reward for a pair using the lightweight reward model."""
        return self.reward_model.predict_reward(query, response)

    def save_progress(self):
        """Persist reward model, feedback buffer, and training history."""
        try:
            os.makedirs(RLHF_SAVE_PATH, exist_ok=True)
            reward_model_path = os.path.join(RLHF_SAVE_PATH, "lightweight_reward_model.pkl")
            self.reward_model.save_model(reward_model_path)
            feedback_path = os.path.join(RLHF_SAVE_PATH, "feedback_buffer.json")
            with open(feedback_path, 'w') as f:
                json.dump(self.feedback_buffer, f, indent=2)
            # Append a history entry and rewrite the whole history file.
            history_path = os.path.join(RLHF_SAVE_PATH, "training_history.json")
            training_entry = {
                'timestamp': datetime.now().isoformat(),
                'feedback_count': len(self.feedback_buffer),
                'reward_scores': self.reward_model.training_scores,
                'training_cycle': self.total_training_cycles
            }
            self.training_history.append(training_entry)
            with open(history_path, 'w') as f:
                json.dump(self.training_history, f, indent=2)
            print("β Progress saved successfully")
        except Exception as e:
            print(f"β Save error: {e}")

    def load_progress(self):
        """Load saved feedback buffer and training history if present."""
        try:
            feedback_path = os.path.join(RLHF_SAVE_PATH, "feedback_buffer.json")
            if os.path.exists(feedback_path):
                with open(feedback_path, 'r') as f:
                    self.feedback_buffer = json.load(f)
                print(f"β Loaded {len(self.feedback_buffer)} feedback entries")
            history_path = os.path.join(RLHF_SAVE_PATH, "training_history.json")
            if os.path.exists(history_path):
                with open(history_path, 'r') as f:
                    self.training_history = json.load(f)
                print(f"β Loaded {len(self.training_history)} training history entries")
        except Exception as e:
            print(f"β Load error: {e}")

    def get_training_stats(self):
        """Build the statistics dict consumed by the UI / format_training_stats."""
        positive_feedback = sum(1 for f in self.feedback_buffer if f['feedback'] == "π Like")
        negative_feedback = len(self.feedback_buffer) - positive_feedback
        new_feedback_count = self.count_new_feedback()
        stats = {
            'total_feedback': len(self.feedback_buffer),
            'positive_feedback': positive_feedback,
            'negative_feedback': negative_feedback,
            'satisfaction_rate': positive_feedback / len(self.feedback_buffer) if self.feedback_buffer else 0,
            'new_feedback_since_training': new_feedback_count,
            # Mirrors the trigger conditions in add_feedback().
            'ready_for_training': len(self.feedback_buffer) >= 10 and new_feedback_count >= self.new_feedback_threshold,
            'reward_model_trained': self.reward_model.is_trained,
            'training_rounds': self.total_training_cycles,
            'ppo_triggered': self.ppo_training_triggered,
            'last_training': self.last_training_timestamp,
            'next_training_needs': max(0, self.new_feedback_threshold - new_feedback_count)
        }
        # Accuracy fields only exist once the reward model has trained.
        if self.reward_model.training_scores:
            stats['latest_accuracy'] = self.reward_model.training_scores[-1]
            stats['average_accuracy'] = np.mean(self.reward_model.training_scores)
        return stats

    def debug_buffer_status(self):
        """Print a console dump of buffer/trigger state for debugging."""
        print(f"π Buffer Debug:")
        print(f" - Total feedback: {len(self.feedback_buffer)}")
        print(f" - New feedback since training: {self.count_new_feedback()}")
        print(f" - Last training: {self.last_training_timestamp}")
        print(f" - PPO triggered: {self.ppo_training_triggered}")
        print(f" - Ready for training: {len(self.feedback_buffer) >= 10 and self.count_new_feedback() >= self.new_feedback_threshold}")
        print(f" - Recent feedback timestamps:")
        for i, fb in enumerate(self.feedback_buffer[-5:]):
            print(f" {i}: {fb.get('timestamp', 'no timestamp')}")
# --- Initialize the lightweight RLHF system at import time -------------------
print("π Initializing ultra-lightweight RLHF system with voice support...")
rlhf_trainer = LightweightRLHFTrainer(MODEL_PATH, tokenizer, device)
# Restore any feedback/history saved by a previous session.
rlhf_trainer.load_progress()
# Response generation (handles both voice and text input).
def generate_response(audio_input, text_input, age, chat_history):
    """Generate an age-tailored safety reply with RAG context.

    Returns a 7-tuple matching the Gradio outputs: (chat_history,
    cleared text box, cleared second field, feedback-widget update,
    last query, last response, age as str).
    """
    # Voice input takes priority; falls back to the typed text.
    query = process_voice_input(audio_input, text_input)
    if not query.strip() or not age:
        return chat_history + [("System: Please provide both age and a question (via voice or text).", "")], "", "", gr.update(visible=False, value=None), "", "", ""
    age_context = get_age_context(str(age))
    is_followup = is_followup_question(query)
    # Find the most recent user turn so follow-ups can reuse its topic.
    previous_user_question = ""
    for turn in reversed(chat_history):
        if "User" in turn[0]:
            match = re.search(r'User\(Age \d+\):\s*(.*)', turn[0])
            if match:
                previous_user_question = match.group(1).strip()
                break
    # For follow-ups, retrieve with the combined previous + current question.
    search_query = previous_user_question + " " + query if is_followup else query
    retrieved_docs = retrieve_context(search_query, top_k=3)  # Reduced for efficiency
    retrieved_knowledge = "\n".join(retrieved_docs)
    few_shot_examples = get_few_shot_examples(age)
    # Reconstruct the prior user/assistant exchange for the prompt.
    # NOTE(review): the [-4]/[-3] offsets assume each exchange appends two
    # history tuples (user then assistant) — confirm against the appends below.
    last_user = ""
    last_assistant = ""
    if is_followup and len(chat_history) >= 2:
        if len(chat_history) >= 4 and "User" in chat_history[-4][0] and "Assistant" in chat_history[-3][0]:
            match_user = re.search(r'User\(Age \d+\):\s*(.*)', chat_history[-4][0])
            match_assistant = re.search(r'Assistant:\s*(.*)', chat_history[-3][0])
            if match_user and match_assistant:
                last_user = match_user.group(1).strip()
                last_assistant = match_assistant.group(1).strip()
        elif len(chat_history) >= 2 and "User" in chat_history[-2][0]:
            match_user = re.search(r'User\(Age \d+\):\s*(.*)', chat_history[-2][0])
            if match_user:
                last_user = match_user.group(1).strip()
    # Assemble: few-shot examples + system instructions + optional prior
    # exchange + retrieved guideline + the current question.
    prompt = (
        few_shot_examples + "\n"
        "You are a safe and supportive assistant trained to teach online safety.\n"
        f"Based on the user's age: {age_context}\n"
        "Adjust your explanation style, vocabulary, and level of detail accordingly.\n"
        "Keep the conversation age-appropriate, empathetic, and accurate.\n"
        "Use previous conversation only for context. Do NOT repeat earlier responses.\n"
        + (f"\nPrior exchange:\nUser: {last_user}\nAssistant: {last_assistant}\n" if is_followup and last_user else "")
        + f"\nRelevant safety guideline:\n{retrieved_knowledge}\n"
        f"\nCurrent question from a user aged {age}:\n{query}\nAssistant:"
    )
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,  # Reduced for efficiency
            do_sample=False,
            # NOTE(review): temperature/top_k/top_p have no effect when
            # do_sample=False (greedy decoding) — either enable sampling or
            # drop these arguments.
            temperature=0.6,
            top_k=40,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Everything after the final "Assistant:" marker is the new reply.
    response = decoded.split("Assistant:")[-1].strip()
    # Clean up GPU memory after generation.
    del inputs, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Tag the history entry with the input method used.
    input_method = "π€ Voice" if audio_input is not None and SPEECH_AVAILABLE else "π¬ Text"
    chat_history.append((f"π§ User(Age {age}) [{input_method}]: {query}", ""))
    chat_history.append((f"π€ Assistant: {response}", ""))
    # Clear both inputs and make the feedback widget visible.
    return chat_history, "", "", gr.update(visible=True, value=None), query, response, str(age)
# Feedback handler wired into the RLHF trainer.
def save_feedback_with_rlhf(query, response, feedback, age):
    """Record one like/dislike, possibly triggering a training cycle.

    Returns (status message, stats dict, gr.update clearing the feedback
    radio) — the shapes the Gradio event handler expects.
    """
    if not query or not response or not feedback:
        return "Feedback requires a question, response, and selection.", rlhf_trainer.get_training_stats(), gr.update(value=None)
    try:
        # add_feedback() may itself run a full PPO training cycle inline.
        rlhf_trainer.add_feedback(query, response, feedback, age)
        stats = rlhf_trainer.get_training_stats()
        emoji = "π" if feedback == "π Like" else "π"
        status_msg = f"{emoji} Feedback saved! Total: {stats['total_feedback']}"
        if stats['ready_for_training']:
            status_msg += f" π Ready for training!"
        elif stats['next_training_needs'] > 0:
            status_msg += f" (Need {stats['next_training_needs']} more for next training)"
        print(f"β Feedback processed: {feedback}")
        # Console-only diagnostics.
        rlhf_trainer.debug_buffer_status()
        return status_msg, stats, gr.update(value=None)
    except Exception as e:
        error_msg = f"β Error saving feedback: {str(e)}"
        print(error_msg)
        return error_msg, rlhf_trainer.get_training_stats(), gr.update(value=None)
# Render the trainer's statistics dict for display in the UI.
def format_training_stats(stats):
    """Format training statistics as a markdown status string."""
    if not stats:
        return "No training statistics available."
    report = [
        "π **Training Statistics**",
        f"β’ Total Feedback: {stats.get('total_feedback', 0)}",
        f"β’ Positive: {stats.get('positive_feedback', 0)} | Negative: {stats.get('negative_feedback', 0)}",
        f"β’ Satisfaction Rate: {stats.get('satisfaction_rate', 0):.1%}",
        f"β’ Training Rounds: {stats.get('training_rounds', 0)}",
        f"β’ New Feedback Since Training: {stats.get('new_feedback_since_training', 0)}",
    ]
    if stats.get('reward_model_trained', False):
        report.append("β’ Reward Model: β Trained")
        # Accuracy lines appear only once available, latest before average.
        for key, label in (('latest_accuracy', 'Latest'), ('average_accuracy', 'Average')):
            if key in stats:
                report.append(f"β’ {label} Accuracy: {stats[key]:.3f}")
    else:
        report.append("β’ Reward Model: β Not trained yet")
    if stats.get('ready_for_training', False):
        report.append("π **Ready for next training cycle!**")
    else:
        remaining = stats.get('next_training_needs', 0)
        if remaining > 0:
            report.append(f"β³ Need {remaining} more feedback for training")
    return "\n".join(report)
# Reset handler for the chat interface.
def clear_chat():
    """Clear the chat history and reset all tracked state values."""
    hidden_feedback = gr.update(visible=False, value=None)
    return [], "", "", hidden_feedback, "", "", ""
# Manual training trigger (UI button handler).
def trigger_manual_training():
    """Manually trigger a training cycle if basic conditions are met.

    Returns (status message, fresh stats dict). Temporarily lowers the
    new-feedback threshold to 1 so a manual run always proceeds.
    """
    try:
        stats = rlhf_trainer.get_training_stats()
        if stats['total_feedback'] < 5:
            return "β Need at least 5 feedback entries to train", stats
        if rlhf_trainer.ppo_training_triggered:
            return "β³ Training already in progress", stats
        # Force training by temporarily lowering the threshold.
        original_threshold = rlhf_trainer.new_feedback_threshold
        rlhf_trainer.new_feedback_threshold = 1
        try:
            print("π― Manual training triggered...")
            training_result = rlhf_trainer.train_reward_model()
            if training_result.get('success', False):
                rlhf_trainer.minimal_model_update()
                rlhf_trainer.last_training_timestamp = datetime.now().isoformat()
                rlhf_trainer.total_training_cycles += 1
                rlhf_trainer.save_progress()
                rlhf_trainer.save_training_state()
                message = f"β Manual training completed! Accuracy: {training_result.get('accuracy', 0):.3f}"
            else:
                message = f"β Training failed: {training_result.get('error', 'Unknown error')}"
        finally:
            # FIX: restore the threshold even if training raises — previously
            # an exception left new_feedback_threshold stuck at 1, making
            # automatic retraining fire on every single new feedback.
            rlhf_trainer.new_feedback_threshold = original_threshold
        return message, rlhf_trainer.get_training_stats()
    except Exception as e:
        return f"β Manual training error: {str(e)}", rlhf_trainer.get_training_stats()
# Export handler: dump current training data for offline analysis.
def export_training_data():
    """Write feedback buffer, history, and stats to a timestamped JSON file.

    Returns a human-readable status string (path on success, error text on
    failure).
    """
    try:
        if not rlhf_trainer.feedback_buffer:
            return "No training data to export."
        export_data = {
            'feedback_buffer': rlhf_trainer.feedback_buffer,
            'training_history': rlhf_trainer.training_history,
            'stats': rlhf_trainer.get_training_stats(),
            'exported_at': datetime.now().isoformat()
        }
        # Timestamped filename so repeated exports never overwrite.
        export_path = os.path.join(RLHF_SAVE_PATH, f"export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        os.makedirs(RLHF_SAVE_PATH, exist_ok=True)
        with open(export_path, 'w') as f:
            json.dump(export_data, f, indent=2)
        return f"β Training data exported to: {export_path}"
    except Exception as e:
        return f"β Export error: {str(e)}"
# Manual stats refresh (replaces the removed 'every' polling parameter).
def refresh_stats():
    """Recompute and format the trainer's current statistics."""
    latest = rlhf_trainer.get_training_stats()
    return format_training_stats(latest)
# β Create Gradio Interface with Enhanced Features (FIXED)
# Builds the whole UI inside one gr.Blocks context. Components are created
# first, then event wiring happens at the bottom, so every component a handler
# references already exists when it is wired.
print("π¨ Creating enhanced Gradio interface...")
with gr.Blocks(
    title="π‘οΈ Child Safety Assistant with Voice & RLHF",
    theme=gr.themes.Soft(),
    css="""
    .main-container { max-width: 1200px; margin: 0 auto; }
    .chat-container { height: 500px; }
    .feedback-container { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 10px; padding: 20px; margin: 10px 0; }
    .stats-container { background: #f8f9fa; border-radius: 8px; padding: 15px; margin: 10px 0; }
    .voice-input { border: 2px dashed #007bff; border-radius: 10px; padding: 15px; }
    """
) as demo:
    # β State variables for tracking
    # Hidden per-session state holding the last query/response pair and the
    # user's age, so the feedback handler can attribute a rating to the right
    # exchange even after the visible inputs have been cleared.
    current_query = gr.State("")
    current_response = gr.State("")
    current_age = gr.State("")
    # β Header
    # f-string: the voice-availability line is chosen at build time from the
    # module-level SPEECH_AVAILABLE flag set during imports.
    gr.HTML(f"""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>π‘οΈ Child Safety Assistant</h1>
        <p><strong>AI-Powered Online Safety Education with Voice Input & RLHF Learning</strong></p>
        <p>{"π€ Voice input available" if SPEECH_AVAILABLE else "β οΈ Voice input disabled (install SpeechRecognition)"} | π§ Learns from your feedback</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=2):
            # β Main Chat Interface
            gr.HTML("<h3>π¬ Chat with Safety Assistant</h3>")
            # Age input
            age_input = gr.Number(
                label="πΆ Your Age",
                value=12,
                minimum=5,
                maximum=25,
                info="Helps customize responses for your age level"
            )
            # Voice input section (only show if available)
            # This is a Python-level branch evaluated at build time, so only
            # one of the two audio components is ever created; either way the
            # name audio_input exists for the event wiring below.
            if SPEECH_AVAILABLE:
                with gr.Group():
                    gr.HTML("<h4>π€ Voice Input (Optional)</h4>")
                    audio_input = gr.Audio(
                        label="Speak your question",
                        sources=["microphone"],
                        type="filepath"
                    )
            else:
                # Invisible placeholder keeps the handler signatures valid
                # when speech recognition is unavailable.
                audio_input = gr.Audio(visible=False)
                gr.HTML("<p>β οΈ Voice input disabled. Install SpeechRecognition to enable.</p>")
            # Text input as fallback
            text_input = gr.Textbox(
                label="π¬ Type your question" + (" (or use voice above)" if SPEECH_AVAILABLE else ""),
                placeholder="Ask about online safety, cyberbullying, privacy, etc...",
                lines=2
            )
            # Chat display
            chatbot = gr.Chatbot(
                label="Conversation",
                height=400,
                avatar_images=(None, "π€")
            )
            # Control buttons
            with gr.Row():
                submit_btn = gr.Button("π Ask Question", variant="primary", size="lg")
                clear_btn = gr.Button("ποΈ Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            # β Feedback & Training Panel
            gr.HTML("<h3>π― Feedback & Learning</h3>")
            # Current conversation feedback
            # Both widgets start hidden; the submit chain below reveals them
            # once there is a response to rate.
            with gr.Group():
                feedback_radio = gr.Radio(
                    choices=["π Like", "π Dislike"],
                    label="Rate the last response",
                    info="Your feedback helps improve the AI",
                    visible=False
                )
                feedback_btn = gr.Button("πΎ Submit Feedback", variant="secondary", visible=False)
                feedback_status = gr.Textbox(
                    label="Feedback Status",
                    interactive=False,
                    max_lines=2
                )
            # Training statistics
            gr.HTML("<h4>π Training Statistics</h4>")
            stats_display = gr.Markdown("Loading statistics...")
            # Refresh button (replaces automatic refresh)
            refresh_btn = gr.Button("π Refresh Stats", variant="secondary", size="sm")
            # Advanced controls
            with gr.Accordion("π§ Advanced Controls", open=False):
                manual_train_btn = gr.Button("π Manual Training", variant="secondary")
                export_btn = gr.Button("π€ Export Data", variant="secondary")
                training_status = gr.Textbox(
                    label="Training Status",
                    interactive=False,
                    max_lines=3
                )
    # β Event Handlers
    # Main conversation flow
    # generate_response consumes the audio file (if any) and/or typed text,
    # then the chained .then() reveals the feedback widgets for the new answer.
    submit_btn.click(
        generate_response,
        inputs=[audio_input, text_input, age_input, chatbot],
        outputs=[chatbot, text_input, audio_input, feedback_radio, current_query, current_response, current_age],
        show_progress=True
    ).then(
        lambda: (gr.update(visible=True), gr.update(visible=True)),  # Show feedback components
        outputs=[feedback_radio, feedback_btn]
    )
    # Clear chat
    # Resets the conversation and hides the feedback widgets again.
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, text_input, audio_input, feedback_radio, current_query, current_response, current_age]
    ).then(
        lambda: (gr.update(visible=False), gr.update(visible=False)),  # Hide feedback components
        outputs=[feedback_radio, feedback_btn]
    )
    # Feedback submission
    # Chain: persist feedback -> refresh the stats panel -> hide/reset widgets.
    # NOTE(review): feedback_radio appears twice in the final outputs list
    # (once to hide, once to reset the value); some Gradio versions reject
    # duplicate outputs in a single event — verify against the pinned version.
    feedback_btn.click(
        save_feedback_with_rlhf,
        inputs=[current_query, current_response, feedback_radio, current_age],
        outputs=[feedback_status, gr.State(), feedback_radio],
        show_progress=True
    ).then(
        refresh_stats,
        outputs=[stats_display]
    ).then(
        lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(value=None)),  # Hide and reset feedback
        outputs=[feedback_radio, feedback_btn, feedback_radio]
    )
    # Manual training
    manual_train_btn.click(
        trigger_manual_training,
        outputs=[training_status, gr.State()],
        show_progress=True
    ).then(
        refresh_stats,
        outputs=[stats_display]
    )
    # Export data
    export_btn.click(
        export_training_data,
        outputs=[training_status]
    )
    # Manual refresh stats
    refresh_btn.click(
        refresh_stats,
        outputs=[stats_display]
    )
    # Initialize stats display on load
    # Replaces the removed 'every=' polling with a one-shot load refresh.
    demo.load(
        refresh_stats,
        outputs=[stats_display]
    )
# β Launch Configuration
if __name__ == "__main__":
    print("π Launching Child Safety Chatbot with Voice & RLHF...")
    print("π Features enabled:")

    # Voice support depends on the optional SpeechRecognition import.
    voice_mark = 'β ' if SPEECH_AVAILABLE else 'β'
    print(f"   {voice_mark} Voice input with speech-to-text")

    # Data-driven listing of the always-on features.
    enabled_features = (
        "Ultra-lightweight RLHF with PPO-style training",
        "Scikit-learn based reward model",
        "Memory-optimized for free Colab",
        "Automatic model improvement from feedback",
        "Manual training controls",
        "Training data export",
        "Real-time statistics tracking",
        "FIXED: No problematic 'every' parameter",
    )
    for feature in enabled_features:
        print(f"   β {feature}")

    # Show current training stats
    stats_snapshot = rlhf_trainer.get_training_stats()
    print("\nπ Current training stats:")
    stat_rows = (
        ("Total feedback", 'total_feedback'),
        ("Training rounds", 'training_rounds'),
        ("Ready for training", 'ready_for_training'),
        ("New feedback needed", 'next_training_needs'),
    )
    for label, key in stat_rows:
        print(f"   - {label}: {stats_snapshot[key]}")

    # Launch with appropriate settings
    demo.launch(
        server_name="0.0.0.0",  # For Colab
        server_port=7860,
        share=True,  # Creates public link
        debug=False,
        show_error=True,
        quiet=False,
    )