# ✅ Ultra-lightweight Gradio chatbot with PPO reward model for free Colab
# Optimized to prevent RAM crashes with a minimal memory footprint
# FIXED: Improved PPO trigger logic to handle continuous training
# NEW: Added voice input capability with speech-to-text
# FIXED: Removed problematic 'every' parameter from Gradio events
import json
import torch
import faiss
import numpy as np
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import pandas as pd
import os
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datetime import datetime
import gc
import math
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import pickle
# NEW: Import for speech-to-text functionality (with error handling)
try:
import speech_recognition as sr
SPEECH_AVAILABLE = True
print("βœ… Speech recognition available")
except ImportError:
print("⚠️ Speech recognition not available. Install with: pip install SpeechRecognition")
SPEECH_AVAILABLE = False
try:
import librosa
LIBROSA_AVAILABLE = True
except ImportError:
print("⚠️ Librosa not available. Advanced audio processing disabled.")
LIBROSA_AVAILABLE = False
# ✅ Load model and tokenizer
MODEL_PATH = "Shilpagotur/UNICEF_chatbot"
DATA_PATH = "Shilpagotur/training_dataset_modified.json"
RLHF_SAVE_PATH = "Shilpagotur/rlhf_models" # Path to save RLHF models and data
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=dtype).to(device)
tokenizer.pad_token = tokenizer.eos_token
# ✅ Load RAG data
with open(DATA_PATH, "r") as f:
knowledge_base = json.load(f)
retrieval_texts = [item["answer"] for item in knowledge_base]
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
retrieval_embeddings = embedding_model.encode(retrieval_texts, convert_to_numpy=True)
# ✅ Build FAISS index
dimension = retrieval_embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(retrieval_embeddings)
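# Optional sanity check for the FAISS index (an illustrative sketch; the query
# text and the DEMO_INDEX_CHECK flag are made up for this example, not part of
# the app). Set DEMO_INDEX_CHECK=1 in the environment to run it at startup.
if os.environ.get("DEMO_INDEX_CHECK"):
    _q = embedding_model.encode(["Is a free-points app safe to install?"], convert_to_numpy=True)
    _D, _I = faiss_index.search(_q, 3)
    print("Index check:", [retrieval_texts[i][:60] for i in _I[0]])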
# NEW: Initialize speech recognition (if available)
if SPEECH_AVAILABLE:
recognizer = sr.Recognizer()
# NEW: Speech-to-text function with error handling
def speech_to_text(audio_file):
"""Convert speech audio to text using speech_recognition library"""
if not SPEECH_AVAILABLE:
return "Speech recognition not available. Please install SpeechRecognition library."
if audio_file is None:
return ""
try:
print("🎀 Processing speech input...")
# Create recognizer instance
r = sr.Recognizer()
# Load audio file
with sr.AudioFile(audio_file) as source:
# Adjust for ambient noise and record
r.adjust_for_ambient_noise(source, duration=0.5)
audio_data = r.record(source)
# Recognize speech using Google Web Speech API (free tier)
try:
text = r.recognize_google(audio_data)
print(f"πŸŽ™οΈ Recognized speech: '{text}'")
return text
except sr.UnknownValueError:
print("❌ Could not understand the audio")
return "Sorry, I couldn't understand what you said. Please try again or use text input."
except sr.RequestError as e:
print(f"❌ Error with speech recognition service: {e}")
return "Speech recognition service is unavailable. Please use text input."
except Exception as e:
print(f"❌ Error processing audio: {e}")
return "Error processing audio. Please try again or use text input."
# NEW: Combined input processing function
def process_voice_input(audio_file, text_input):
"""Process voice input and convert to text, fallback to text input if needed"""
if audio_file is not None and SPEECH_AVAILABLE:
# If voice input is provided, convert it to text
voice_text = speech_to_text(audio_file)
if voice_text and not voice_text.startswith("Sorry") and not voice_text.startswith("Speech") and not voice_text.startswith("Error"):
return voice_text
# Fallback to text input if voice failed or not provided
return text_input if text_input else ""
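# Examples (illustrative): process_voice_input(None, "hi") -> "hi";
# process_voice_input("clip.wav", "hi") -> the transcription when recognition
# succeeds, otherwise it falls back to "hi" (the startswith checks filter out
# the error-sentinel strings returned by speech_to_text).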
# ✅ Helper functions (unchanged)
def retrieve_context(query, top_k=3):
query_embedding = embedding_model.encode([query], convert_to_numpy=True)
D, I = faiss_index.search(query_embedding, top_k)
return [retrieval_texts[i] for i in I[0]]
def get_age_context(age):
    # FIXED: gr.Number yields a float, so str(age).isdigit() was always False for
    # inputs like "12.0" and every user fell back to the default age
    try:
        age = int(float(age))
    except (TypeError, ValueError):
        age = 12
if age <= 7:
return "You are a young child, so I will explain things simply, clearly, and empathetically."
elif age <= 12:
return "You are an older child, so I will explain with more detail while keeping it safe, empathetic, and understandable."
elif age <= 17:
return "You are a teenager, so I will provide more in-depth, more technical advice with caution, but in an empathetic way"
else:
return "The user is an adult. Use professional language and detailed safety advice."
def get_few_shot_examples(age):
    # FIXED: same float-safe parsing as get_age_context (gr.Number yields a float)
    try:
        age = int(float(age))
    except (TypeError, ValueError):
        age = 12
if age <= 7:
return """
Example 1:
User: They said I can download points for free if I install a background app. Is it legit?
Assistant: That could be a trick. Don't install anything someone tells you to without asking a grown-up. Talk to a parent or guardian.
"""
elif age <= 12:
return """
Example 1:
User: They said I can download points for free if I install a background app. Is it legit?
Assistant: Be careful! Sometimes people say things like free points to trick kids into installing bad apps. These apps can steal your stuff or break your device. Always ask your parents before downloading anything. If it seems too good to be true, it usually is.
"""
else:
return """
Example 1:
User: They said I can download points for free if I install a background app. Is it legit?
Assistant: Offers like free points if you install an app are often scams or phishing attempts. The background app could secretly collect your data, show you unsafe ads, or even install malware without you knowing. Legit apps don't offer rewards in exchange for random downloads. Always check reviews, publisher info, and permissions. It's best to ignore such offers and report them if possible. Stay safe and don't trade your privacy for fake rewards.
"""
def is_followup_question(question):
patterns = [
r'\b(what about|how about|what if|but what|also|and|additionally)\b',
r'\b(can you explain|tell me more|elaborate|expand)\b',
r'\b(why|how|when|where|who)\b.*\?',
r'\b(that|this|it|they|them)\b',
r'\b(previous|before|earlier|above)\b',
r'^(yes|no|okay|ok|sure|right|exactly|correct|wrong)'
]
return any(re.search(p, question.lower().strip()) for p in patterns)
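# Examples (illustrative): is_followup_question("what about my tablet?") -> True
# (matches "what about"); is_followup_question("Is my password safe?") -> False
# (no follow-up pattern fires).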
# ✅ Ultra-Lightweight Reward Model using Scikit-learn (NO PYTORCH!)
class SklearnRewardModel:
"""Ultra-lightweight reward model using scikit-learn instead of PyTorch"""
def __init__(self):
# Use TF-IDF for text features (much lighter than transformers)
self.vectorizer = TfidfVectorizer(
max_features=500, # Very small feature set
stop_words='english',
ngram_range=(1, 2), # Unigrams and bigrams
max_df=0.95,
min_df=2
)
# Use Logistic Regression (very lightweight)
self.classifier = LogisticRegression(
random_state=42,
max_iter=100,
C=1.0
)
self.is_trained = False
self.training_scores = []
def preprocess_text(self, query, response):
"""Create features from query-response pair"""
# Combine query and response
combined_text = f"Q: {query[:200]} A: {response[:200]}"
return combined_text
def train(self, training_data):
"""Train the reward model on feedback data"""
if len(training_data) < 3:
return {'success': False, 'reason': 'Need at least 3 samples'}
try:
print(f"πŸš€ Training lightweight reward model with {len(training_data)} samples...")
# Prepare training data
texts = []
labels = []
for item in training_data:
text = self.preprocess_text(item['query'], item['response'])
                label = 1 if item['feedback'] == "😊 Like" else 0
texts.append(text)
labels.append(label)
# Check if we have both positive and negative examples
if len(set(labels)) < 2:
return {'success': False, 'reason': 'Need both positive and negative examples'}
# Vectorize text
X = self.vectorizer.fit_transform(texts)
# Train classifier
self.classifier.fit(X, labels)
self.is_trained = True
# Calculate training accuracy
train_score = self.classifier.score(X, labels)
self.training_scores.append(train_score)
print(f"βœ… Training completed! Accuracy: {train_score:.3f}")
return {
'success': True,
'accuracy': train_score,
'samples_used': len(training_data)
}
except Exception as e:
print(f"❌ Training error: {e}")
return {'success': False, 'error': str(e)}
def predict_reward(self, query, response):
"""Predict reward for query-response pair"""
if not self.is_trained:
return 0.5 # Neutral score if not trained
try:
text = self.preprocess_text(query, response)
X = self.vectorizer.transform([text])
# Get probability of positive class
prob = self.classifier.predict_proba(X)[0][1]
return float(prob)
except Exception as e:
print(f"Prediction error: {e}")
return 0.5
def save_model(self, filepath):
"""Save the trained model"""
try:
model_data = {
'vectorizer': self.vectorizer,
'classifier': self.classifier,
'is_trained': self.is_trained,
'training_scores': self.training_scores
}
with open(filepath, 'wb') as f:
pickle.dump(model_data, f)
return True
except Exception as e:
print(f"Save error: {e}")
return False
def load_model(self, filepath):
"""Load a trained model"""
try:
if os.path.exists(filepath):
with open(filepath, 'rb') as f:
model_data = pickle.load(f)
self.vectorizer = model_data['vectorizer']
self.classifier = model_data['classifier']
self.is_trained = model_data['is_trained']
self.training_scores = model_data.get('training_scores', [])
print(f"βœ… Reward model loaded! Trained: {self.is_trained}")
return True
except Exception as e:
print(f"Load error: {e}")
return False
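# Minimal usage sketch for the reward model (illustrative; the sample texts and
# the DEMO_REWARD_CHECK flag are made up for this example). train() needs at
# least 3 samples covering both labels, and min_df=2 means each surviving term
# must appear in at least two samples:
if os.environ.get("DEMO_REWARD_CHECK"):
    _rm = SklearnRewardModel()
    print(_rm.train([
        {'query': 'free points app', 'response': 'Ask a parent before installing apps.', 'feedback': '😊 Like'},
        {'query': 'free points app', 'response': 'Sure, install the app right away.', 'feedback': '👎 Dislike'},
        {'query': 'a stranger messaged me', 'response': 'Ask a parent before replying to strangers.', 'feedback': '😊 Like'},
    ]))
    print(_rm.predict_reward('free points app', 'Ask a parent before installing apps.'))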
# ✅ FIXED: Improved PPO-style Trainer with Better Trigger Logic
class LightweightRLHFTrainer:
"""Ultra-lightweight RLHF trainer for free Colab with FIXED trigger logic"""
def __init__(self, model_path, tokenizer, device='cuda'):
self.device = device
self.tokenizer = tokenizer
# Keep reference to main model (but don't modify it heavily)
self.main_model = model
# Initialize lightweight reward model
self.reward_model = SklearnRewardModel()
# Minimal optimizer for main model (only used sparingly)
self.main_optimizer = optim.AdamW(
self.main_model.parameters(),
lr=5e-6, # Very small learning rate
weight_decay=0.01
)
# Training data storage
self.feedback_buffer = []
self.training_history = []
self.ppo_training_triggered = False
        # 🔧 NEW: Track training state
self.last_training_timestamp = "1970-01-01T00:00:00" # Start with epoch
self.total_training_cycles = 0
self.new_feedback_threshold = 5 # Need 5 new feedbacks for retraining
# Load any existing model
reward_model_path = os.path.join(RLHF_SAVE_PATH, "lightweight_reward_model.pkl")
self.reward_model.load_model(reward_model_path)
# Load existing training state
self.load_training_state()
def load_training_state(self):
"""Load previous training state"""
try:
state_path = os.path.join(RLHF_SAVE_PATH, "training_state.json")
if os.path.exists(state_path):
with open(state_path, 'r') as f:
state = json.load(f)
self.last_training_timestamp = state.get('last_training_timestamp', "1970-01-01T00:00:00")
self.total_training_cycles = state.get('total_training_cycles', 0)
print(f"βœ… Loaded training state: {self.total_training_cycles} cycles, last training: {self.last_training_timestamp}")
except Exception as e:
print(f"⚠️ Could not load training state: {e}")
def save_training_state(self):
"""Save current training state"""
try:
os.makedirs(RLHF_SAVE_PATH, exist_ok=True)
state_path = os.path.join(RLHF_SAVE_PATH, "training_state.json")
state = {
'last_training_timestamp': self.last_training_timestamp,
'total_training_cycles': self.total_training_cycles,
'saved_at': datetime.now().isoformat()
}
with open(state_path, 'w') as f:
json.dump(state, f, indent=2)
except Exception as e:
print(f"⚠️ Could not save training state: {e}")
def count_new_feedback(self):
"""Count feedback received since last training"""
new_count = 0
for feedback in self.feedback_buffer:
feedback_time = feedback.get('timestamp', "1970-01-01T00:00:00")
if feedback_time > self.last_training_timestamp:
new_count += 1
return new_count
def add_feedback(self, query, response, feedback, age=None):
"""Add feedback to buffer with IMPROVED trigger logic"""
feedback_item = {
'query': query,
'response': response,
'feedback': feedback,
'age': age,
'timestamp': datetime.now().isoformat()
}
self.feedback_buffer.append(feedback_item)
        # 🔧 IMPROVED: Better trigger logic
buffer_size = len(self.feedback_buffer)
new_feedback_count = self.count_new_feedback()
# Trigger conditions:
# 1. Have at least 10 total feedbacks AND
# 2. Have at least 5 new feedbacks since last training AND
# 3. Not currently training
should_trigger = (
buffer_size >= 10 and
new_feedback_count >= self.new_feedback_threshold and
not self.ppo_training_triggered
)
print(f"πŸ“Š Feedback stats: Total={buffer_size}, New={new_feedback_count}, Should trigger={should_trigger}")
if should_trigger:
try:
print(f"\nπŸš€ Triggering PPO training cycle #{self.total_training_cycles + 1}...")
print(f" - Total feedback: {buffer_size}")
print(f" - New feedback since last training: {new_feedback_count}")
self.ppo_training_triggered = True
# Train reward model
training_result = self.train_reward_model()
if training_result.get('success', False):
print("βœ… Reward model training successful!")
# Minimal main model update (very conservative)
self.minimal_model_update()
# Update training state
self.last_training_timestamp = datetime.now().isoformat()
self.total_training_cycles += 1
# Save progress
self.save_progress()
self.save_training_state()
else:
print("⚠️ Reward model training failed")
self.ppo_training_triggered = False
print(f"--- Training cycle #{self.total_training_cycles} complete ---\n")
except Exception as e:
print(f"❌ Training error: {e}")
self.ppo_training_triggered = False
def train_reward_model(self):
"""Train the lightweight reward model"""
return self.reward_model.train(self.feedback_buffer)
def minimal_model_update(self):
"""Very minimal update to main model to avoid OOM"""
try:
print("πŸ”„ Performing minimal model update...")
# Only process the best examples (highest predicted rewards)
            positive_samples = [f for f in self.feedback_buffer if f['feedback'] == "😊 Like"]
if not positive_samples:
return
# Score and select only the top examples
scored_samples = []
for sample in positive_samples:
reward = self.reward_model.predict_reward(sample['query'], sample['response'])
if reward > 0.8: # Only very high-confidence examples
scored_samples.append((sample, reward))
# Sort by reward and take only top 2 to avoid memory issues
scored_samples.sort(key=lambda x: x[1], reverse=True)
top_samples = scored_samples[:2]
if not top_samples:
print("No high-quality samples found for update")
return
# Very minimal training step
self.main_model.train()
for sample, reward in top_samples:
try:
# Create minimal prompt
prompt = f"User: {sample['query'][:100]}\nAssistant: {sample['response'][:100]}"
# Tokenize with small length
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=128,
truncation=True,
padding=True
).to(self.device)
# Single forward pass
outputs = self.main_model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
if loss is not None and not torch.isnan(loss):
# Scale by reward and make update very small
scaled_loss = loss * reward * 0.1 # Very conservative scaling
self.main_optimizer.zero_grad()
scaled_loss.backward()
# Very conservative gradient clipping
torch.nn.utils.clip_grad_norm_(self.main_model.parameters(), max_norm=0.1)
self.main_optimizer.step()
# Immediate cleanup
del inputs, outputs, loss
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception as e:
print(f"Error in sample update: {e}")
continue
print(f"βœ… Updated model with {len(top_samples)} high-quality samples")
except Exception as e:
print(f"❌ Model update error: {e}")
finally:
# Cleanup
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Keep more feedback but limit to prevent memory issues
if len(self.feedback_buffer) > 50:
self.feedback_buffer = self.feedback_buffer[-50:] # Keep last 50
def predict_reward(self, query, response):
"""Predict reward using lightweight model"""
return self.reward_model.predict_reward(query, response)
def save_progress(self):
"""Save training progress"""
try:
os.makedirs(RLHF_SAVE_PATH, exist_ok=True)
# Save reward model
reward_model_path = os.path.join(RLHF_SAVE_PATH, "lightweight_reward_model.pkl")
self.reward_model.save_model(reward_model_path)
# Save feedback buffer
feedback_path = os.path.join(RLHF_SAVE_PATH, "feedback_buffer.json")
with open(feedback_path, 'w') as f:
json.dump(self.feedback_buffer, f, indent=2)
# Save training history
history_path = os.path.join(RLHF_SAVE_PATH, "training_history.json")
training_entry = {
'timestamp': datetime.now().isoformat(),
'feedback_count': len(self.feedback_buffer),
'reward_scores': self.reward_model.training_scores,
'training_cycle': self.total_training_cycles
}
self.training_history.append(training_entry)
with open(history_path, 'w') as f:
json.dump(self.training_history, f, indent=2)
print("βœ… Progress saved successfully")
except Exception as e:
print(f"❌ Save error: {e}")
def load_progress(self):
"""Load saved progress"""
try:
# Load feedback buffer
feedback_path = os.path.join(RLHF_SAVE_PATH, "feedback_buffer.json")
if os.path.exists(feedback_path):
with open(feedback_path, 'r') as f:
self.feedback_buffer = json.load(f)
print(f"βœ… Loaded {len(self.feedback_buffer)} feedback entries")
# Load training history
history_path = os.path.join(RLHF_SAVE_PATH, "training_history.json")
if os.path.exists(history_path):
with open(history_path, 'r') as f:
self.training_history = json.load(f)
print(f"βœ… Loaded {len(self.training_history)} training history entries")
except Exception as e:
print(f"❌ Load error: {e}")
def get_training_stats(self):
"""Get comprehensive training statistics with FIXED logic"""
        positive_feedback = sum(1 for f in self.feedback_buffer if f['feedback'] == "😊 Like")
negative_feedback = len(self.feedback_buffer) - positive_feedback
new_feedback_count = self.count_new_feedback()
stats = {
'total_feedback': len(self.feedback_buffer),
'positive_feedback': positive_feedback,
'negative_feedback': negative_feedback,
'satisfaction_rate': positive_feedback / len(self.feedback_buffer) if self.feedback_buffer else 0,
'new_feedback_since_training': new_feedback_count,
'ready_for_training': len(self.feedback_buffer) >= 10 and new_feedback_count >= self.new_feedback_threshold,
'reward_model_trained': self.reward_model.is_trained,
'training_rounds': self.total_training_cycles,
'ppo_triggered': self.ppo_training_triggered,
'last_training': self.last_training_timestamp,
'next_training_needs': max(0, self.new_feedback_threshold - new_feedback_count)
}
# Add reward model accuracy if available
if self.reward_model.training_scores:
stats['latest_accuracy'] = self.reward_model.training_scores[-1]
stats['average_accuracy'] = np.mean(self.reward_model.training_scores)
return stats
def debug_buffer_status(self):
"""Debug function to check buffer status"""
print(f"πŸ“Š Buffer Debug:")
print(f" - Total feedback: {len(self.feedback_buffer)}")
print(f" - New feedback since training: {self.count_new_feedback()}")
print(f" - Last training: {self.last_training_timestamp}")
print(f" - PPO triggered: {self.ppo_training_triggered}")
print(f" - Ready for training: {len(self.feedback_buffer) >= 10 and self.count_new_feedback() >= self.new_feedback_threshold}")
print(f" - Recent feedback timestamps:")
for i, fb in enumerate(self.feedback_buffer[-5:]):
print(f" {i}: {fb.get('timestamp', 'no timestamp')}")
# ✅ Initialize lightweight system
print("🚀 Initializing ultra-lightweight RLHF system with voice support...")
rlhf_trainer = LightweightRLHFTrainer(MODEL_PATH, tokenizer, device)
rlhf_trainer.load_progress()
# ✅ Response generation (MODIFIED to handle voice input)
def generate_response(audio_input, text_input, age, chat_history):
# NEW: Process voice input first, fallback to text
query = process_voice_input(audio_input, text_input)
    if not query.strip() or not age:
        return chat_history + [("System: Please provide both age and a question (via voice or text).", "")], "", None, gr.update(visible=False, value=None), "", "", ""
age_context = get_age_context(str(age))
is_followup = is_followup_question(query)
# Get previous user input for follow-up context
previous_user_question = ""
for turn in reversed(chat_history):
if "User" in turn[0]:
            # FIXED: user turns are stored as "User(Age N) [🎤 Voice]: ..." so the
            # optional "[...]" tag must be allowed before the colon
            match = re.search(r'User\(Age \d+\)(?:\s*\[[^\]]*\])?:\s*(.*)', turn[0])
if match:
previous_user_question = match.group(1).strip()
break
search_query = previous_user_question + " " + query if is_followup else query
retrieved_docs = retrieve_context(search_query, top_k=3) # Reduced for efficiency
retrieved_knowledge = "\n".join(retrieved_docs)
few_shot_examples = get_few_shot_examples(age)
# Construct prior exchange for prompt
last_user = ""
last_assistant = ""
    if is_followup and len(chat_history) >= 2:
        # FIXED: each exchange appends two tuples (user, assistant), so the most
        # recent exchange lives at chat_history[-2]/[-1], not [-4]/[-3]; also
        # allow the "[🎤 Voice]"/"[💬 Text]" tag in the stored user turn
        user_turn_re = r'User\(Age \d+\)(?:\s*\[[^\]]*\])?:\s*(.*)'
        if "User" in chat_history[-2][0] and "Assistant" in chat_history[-1][0]:
            match_user = re.search(user_turn_re, chat_history[-2][0])
            match_assistant = re.search(r'Assistant:\s*(.*)', chat_history[-1][0])
            if match_user and match_assistant:
                last_user = match_user.group(1).strip()
                last_assistant = match_assistant.group(1).strip()
        elif "User" in chat_history[-2][0]:
            match_user = re.search(user_turn_re, chat_history[-2][0])
            if match_user:
                last_user = match_user.group(1).strip()
prompt = (
few_shot_examples + "\n"
"You are a safe and supportive assistant trained to teach online safety.\n"
f"Based on the user's age: {age_context}\n"
"Adjust your explanation style, vocabulary, and level of detail accordingly.\n"
"Keep the conversation age-appropriate, empathetic, and accurate.\n"
"Use previous conversation only for context. Do NOT repeat earlier responses.\n"
+ (f"\nPrior exchange:\nUser: {last_user}\nAssistant: {last_assistant}\n" if is_followup and last_user else "")
+ f"\nRelevant safety guideline:\n{retrieved_knowledge}\n"
f"\nCurrent question from a user aged {age}:\n{query}\nAssistant:"
)
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,  # Reduced for efficiency
            # FIXED: do_sample=False means greedy decoding, so temperature/top_k/top_p
            # would be ignored (and trigger warnings); they are omitted here
            do_sample=False,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = decoded.split("Assistant:")[-1].strip()
# Clean up GPU memory after generation
del inputs, outputs
if torch.cuda.is_available():
torch.cuda.empty_cache()
# NEW: Show input method in chat history
    input_method = "🎤 Voice" if audio_input is not None and SPEECH_AVAILABLE else "💬 Text"
    chat_history.append((f"🧑 User(Age {age}) [{input_method}]: {query}", ""))
    chat_history.append((f"🤖 Assistant: {response}", ""))
    # NEW: Clear both inputs after the turn (None resets the audio widget)
    return chat_history, "", None, gr.update(visible=True, value=None), query, response, str(age)
# ✅ Enhanced feedback function with debug
def save_feedback_with_rlhf(query, response, feedback, age):
"""FIXED: Lightweight feedback processing with better tracking"""
if not query or not response or not feedback:
return "Feedback requires a question, response, and selection.", rlhf_trainer.get_training_stats(), gr.update(value=None)
try:
# Add feedback to system
rlhf_trainer.add_feedback(query, response, feedback, age)
# Get updated stats
stats = rlhf_trainer.get_training_stats()
# Create detailed status message
        emoji = "👍" if feedback == "😊 Like" else "👎"
        status_msg = f"{emoji} Feedback saved! Total: {stats['total_feedback']}"
        if stats['ready_for_training']:
            status_msg += " 🚀 Ready for training!"
        elif stats['next_training_needs'] > 0:
            status_msg += f" (Need {stats['next_training_needs']} more for next training)"
        print(f"✅ Feedback processed: {feedback}")
rlhf_trainer.debug_buffer_status()
return status_msg, stats, gr.update(value=None)
except Exception as e:
error_msg = f"❌ Error saving feedback: {str(e)}"
print(error_msg)
return error_msg, rlhf_trainer.get_training_stats(), gr.update(value=None)
# ✅ Training statistics formatter
def format_training_stats(stats):
"""Format training statistics for display"""
if not stats:
return "No training statistics available."
    lines = [
        "📊 **Training Statistics**",
        f"• Total Feedback: {stats.get('total_feedback', 0)}",
        f"• Positive: {stats.get('positive_feedback', 0)} | Negative: {stats.get('negative_feedback', 0)}",
        f"• Satisfaction Rate: {stats.get('satisfaction_rate', 0):.1%}",
        f"• Training Rounds: {stats.get('training_rounds', 0)}",
        f"• New Feedback Since Training: {stats.get('new_feedback_since_training', 0)}",
    ]
    if stats.get('reward_model_trained', False):
        lines.append("• Reward Model: ✅ Trained")
        if 'latest_accuracy' in stats:
            lines.append(f"• Latest Accuracy: {stats['latest_accuracy']:.3f}")
        if 'average_accuracy' in stats:
            lines.append(f"• Average Accuracy: {stats['average_accuracy']:.3f}")
    else:
        lines.append("• Reward Model: ❌ Not trained yet")
    if stats.get('ready_for_training', False):
        lines.append("🚀 **Ready for next training cycle!**")
    else:
        needs = stats.get('next_training_needs', 0)
        if needs > 0:
            lines.append(f"⏳ Need {needs} more feedback for training")
    return "\n".join(lines)
# ✅ Clear chat history function
def clear_chat():
"""Clear the chat history and reset states"""
return [], "", "", gr.update(visible=False, value=None), "", "", ""
# ✅ Manual training trigger function
def trigger_manual_training():
"""Manually trigger training if conditions are met"""
try:
stats = rlhf_trainer.get_training_stats()
if stats['total_feedback'] < 5:
return "❌ Need at least 5 feedback entries to train", stats
if rlhf_trainer.ppo_training_triggered:
return "⏳ Training already in progress", stats
# Force training by temporarily lowering thresholds
original_threshold = rlhf_trainer.new_feedback_threshold
rlhf_trainer.new_feedback_threshold = 1
print("🎯 Manual training triggered...")
training_result = rlhf_trainer.train_reward_model()
if training_result.get('success', False):
rlhf_trainer.minimal_model_update()
rlhf_trainer.last_training_timestamp = datetime.now().isoformat()
rlhf_trainer.total_training_cycles += 1
rlhf_trainer.save_progress()
rlhf_trainer.save_training_state()
message = f"βœ… Manual training completed! Accuracy: {training_result.get('accuracy', 0):.3f}"
else:
message = f"❌ Training failed: {training_result.get('error', 'Unknown error')}"
# Restore original threshold
rlhf_trainer.new_feedback_threshold = original_threshold
return message, rlhf_trainer.get_training_stats()
except Exception as e:
return f"❌ Manual training error: {str(e)}", rlhf_trainer.get_training_stats()
# ✅ Export training data function
def export_training_data():
"""Export current training data for analysis"""
try:
if not rlhf_trainer.feedback_buffer:
return "No training data to export."
# Create export data
export_data = {
'feedback_buffer': rlhf_trainer.feedback_buffer,
'training_history': rlhf_trainer.training_history,
'stats': rlhf_trainer.get_training_stats(),
'exported_at': datetime.now().isoformat()
}
# Save to file
export_path = os.path.join(RLHF_SAVE_PATH, f"export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
os.makedirs(RLHF_SAVE_PATH, exist_ok=True)
with open(export_path, 'w') as f:
json.dump(export_data, f, indent=2)
return f"βœ… Training data exported to: {export_path}"
except Exception as e:
return f"❌ Export error: {str(e)}"
# ✅ Refresh stats function (replaces the problematic 'every' parameter)
def refresh_stats():
"""Manually refresh statistics"""
return format_training_stats(rlhf_trainer.get_training_stats())
# ✅ Create Gradio Interface with Enhanced Features (FIXED)
print("🎨 Creating enhanced Gradio interface...")
with gr.Blocks(
title="πŸ›‘οΈ Child Safety Assistant with Voice & RLHF",
theme=gr.themes.Soft(),
css="""
.main-container { max-width: 1200px; margin: 0 auto; }
.chat-container { height: 500px; }
.feedback-container { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 10px; padding: 20px; margin: 10px 0; }
.stats-container { background: #f8f9fa; border-radius: 8px; padding: 15px; margin: 10px 0; }
.voice-input { border: 2px dashed #007bff; border-radius: 10px; padding: 15px; }
"""
) as demo:
    # ✅ State variables for tracking
current_query = gr.State("")
current_response = gr.State("")
current_age = gr.State("")
    # ✅ Header
gr.HTML(f"""
<div style="text-align: center; margin-bottom: 20px;">
<h1>πŸ›‘οΈ Child Safety Assistant</h1>
<p><strong>AI-Powered Online Safety Education with Voice Input & RLHF Learning</strong></p>
<p>{"🎀 Voice input available" if SPEECH_AVAILABLE else "⚠️ Voice input disabled (install SpeechRecognition)"} | 🧠 Learns from your feedback</p>
</div>
""")
with gr.Row():
with gr.Column(scale=2):
            # ✅ Main Chat Interface
            gr.HTML("<h3>💬 Chat with Safety Assistant</h3>")
# Age input
            age_input = gr.Number(
                label="👶 Your Age",
value=12,
minimum=5,
maximum=25,
info="Helps customize responses for your age level"
)
# Voice input section (only show if available)
if SPEECH_AVAILABLE:
with gr.Group():
gr.HTML("<h4>🎀 Voice Input (Optional)</h4>")
audio_input = gr.Audio(
label="Speak your question",
sources=["microphone"],
type="filepath"
)
else:
audio_input = gr.Audio(visible=False)
gr.HTML("<p>⚠️ Voice input disabled. Install SpeechRecognition to enable.</p>")
# Text input as fallback
            text_input = gr.Textbox(
                label="💬 Type your question" + (" (or use voice above)" if SPEECH_AVAILABLE else ""),
placeholder="Ask about online safety, cyberbullying, privacy, etc...",
lines=2
)
# Chat display
chatbot = gr.Chatbot(
label="Conversation",
height=400,
                avatar_images=(None, "🤖")  # NOTE: Gradio expects an image path/URL here; an emoji may not render as an avatar
)
# Control buttons
with gr.Row():
submit_btn = gr.Button("πŸš€ Ask Question", variant="primary", size="lg")
clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary")
with gr.Column(scale=1):
            # ✅ Feedback & Training Panel
            gr.HTML("<h3>🎯 Feedback & Learning</h3>")
# Current conversation feedback
with gr.Group():
feedback_radio = gr.Radio(
                    choices=["😊 Like", "👎 Dislike"],
label="Rate the last response",
info="Your feedback helps improve the AI",
visible=False
)
feedback_btn = gr.Button("πŸ’Ύ Submit Feedback", variant="secondary", visible=False)
feedback_status = gr.Textbox(
label="Feedback Status",
interactive=False,
max_lines=2
)
# Training statistics
gr.HTML("<h4>πŸ“Š Training Statistics</h4>")
stats_display = gr.Markdown("Loading statistics...")
# Refresh button (replaces automatic refresh)
refresh_btn = gr.Button("πŸ”„ Refresh Stats", variant="secondary", size="sm")
# Advanced controls
with gr.Accordion("πŸ”§ Advanced Controls", open=False):
manual_train_btn = gr.Button("πŸš€ Manual Training", variant="secondary")
export_btn = gr.Button("πŸ“€ Export Data", variant="secondary")
training_status = gr.Textbox(
label="Training Status",
interactive=False,
max_lines=3
)
    # ✅ Event Handlers
# Main conversation flow
submit_btn.click(
generate_response,
inputs=[audio_input, text_input, age_input, chatbot],
outputs=[chatbot, text_input, audio_input, feedback_radio, current_query, current_response, current_age],
show_progress=True
).then(
lambda: (gr.update(visible=True), gr.update(visible=True)), # Show feedback components
outputs=[feedback_radio, feedback_btn]
)
# Clear chat
clear_btn.click(
clear_chat,
outputs=[chatbot, text_input, audio_input, feedback_radio, current_query, current_response, current_age]
).then(
lambda: (gr.update(visible=False), gr.update(visible=False)), # Hide feedback components
outputs=[feedback_radio, feedback_btn]
)
# Feedback submission
feedback_btn.click(
save_feedback_with_rlhf,
inputs=[current_query, current_response, feedback_radio, current_age],
outputs=[feedback_status, gr.State(), feedback_radio],
show_progress=True
).then(
refresh_stats,
outputs=[stats_display]
).then(
lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(value=None)), # Hide and reset feedback
outputs=[feedback_radio, feedback_btn, feedback_radio]
)
# Manual training
manual_train_btn.click(
trigger_manual_training,
outputs=[training_status, gr.State()],
show_progress=True
).then(
refresh_stats,
outputs=[stats_display]
)
# Export data
export_btn.click(
export_training_data,
outputs=[training_status]
)
# Manual refresh stats
refresh_btn.click(
refresh_stats,
outputs=[stats_display]
)
# Initialize stats display on load
demo.load(
refresh_stats,
outputs=[stats_display]
)
# ✅ Launch Configuration
if __name__ == "__main__":
print("πŸš€ Launching Child Safety Chatbot with Voice & RLHF...")
print("πŸ“‹ Features enabled:")
print(f" {'βœ…' if SPEECH_AVAILABLE else '❌'} Voice input with speech-to-text")
print(" βœ… Ultra-lightweight RLHF with PPO-style training")
print(" βœ… Scikit-learn based reward model")
print(" βœ… Memory-optimized for free Colab")
print(" βœ… Automatic model improvement from feedback")
print(" βœ… Manual training controls")
print(" βœ… Training data export")
print(" βœ… Real-time statistics tracking")
print(" βœ… FIXED: No problematic 'every' parameter")
# Show current training stats
current_stats = rlhf_trainer.get_training_stats()
print(f"\nπŸ“Š Current training stats:")
print(f" - Total feedback: {current_stats['total_feedback']}")
print(f" - Training rounds: {current_stats['training_rounds']}")
print(f" - Ready for training: {current_stats['ready_for_training']}")
print(f" - New feedback needed: {current_stats['next_training_needs']}")
# Launch with appropriate settings
demo.launch(
server_name="0.0.0.0", # For Colab
server_port=7860,
share=True, # Creates public link
debug=False,
show_error=True,
quiet=False
)