import gradio as gr
from huggingface_hub import InferenceClient
import json
import os
import shutil
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
import joblib
import logging
# ---------------------------
# Logging Configuration
# ---------------------------
logging.basicConfig(
filename='app.log',
filemode='a',
format='%(asctime)s - %(levelname)s - %(message)s',
level=logging.INFO
)
logger = logging.getLogger(__name__)
# ---------------------------
# Initialize the HuggingFace API Client
# ---------------------------
# Use a chat model hosted on the Hugging Face Inference API. Note that OpenAI
# models such as 'gpt-3.5-turbo' are not served by this API; swap in any hosted
# chat model you have access to.
try:
    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
logger.info("HuggingFace InferenceClient initialized successfully.")
except Exception as e:
logger.error(f"Failed to initialize HuggingFace InferenceClient: {e}")
raise
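# If the chosen model is gated or rate-limited, InferenceClient also accepts an
# API token; a sketch assuming the token is supplied via the HF_TOKEN
# environment variable:
#   client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN"))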
# ---------------------------
# Persistent Memory and Knowledge Base Setup
# ---------------------------
memory_file = "chat_memory.json"
knowledge_base_dir = "knowledge_base"
model_file = "chat_model.pkl"
# Ensure directories exist
os.makedirs(knowledge_base_dir, exist_ok=True)
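# Each knowledge-base CSV must provide 'text' and 'label' columns (enforced in
# train_model_on_files below). A minimal example file, with hypothetical rows:
#
#   text,label
#   "How do I reset my password?",account_help
#   "What are your opening hours?",general_info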
# ---------------------------
# Memory Management Functions
# ---------------------------
def load_memory():
"""Load conversation memory from a JSON file."""
try:
if os.path.exists(memory_file):
with open(memory_file, "r") as f:
memory = json.load(f)
logger.info("Conversation memory loaded successfully.")
return memory
logger.info("No existing conversation memory found. Starting fresh.")
return []
except Exception as e:
logger.error(f"Error loading memory: {e}")
return []
def save_memory(memory):
"""Save conversation memory to a JSON file."""
try:
with open(memory_file, "w") as f:
json.dump(memory, f, indent=2)
logger.info("Conversation memory saved successfully.")
except Exception as e:
logger.error(f"Error saving memory: {e}")
def update_memory(message, response):
"""Append user message and assistant response to memory."""
try:
memory = load_memory()
memory.append({"role": "user", "content": message})
memory.append({"role": "assistant", "content": response})
# Optionally limit memory size
if len(memory) > 1000:
memory = memory[-1000:]
save_memory(memory)
except Exception as e:
logger.error(f"Error updating memory: {e}")
# ---------------------------
# ML Model Management Functions
# ---------------------------
def load_or_initialize_model():
"""Load the ML model from a file or initialize a new one."""
try:
if os.path.exists(model_file):
model = joblib.load(model_file)
logger.info("ML model loaded successfully.")
return model
model = Pipeline([
("vectorizer", CountVectorizer()),
("classifier", RandomForestClassifier(n_estimators=100, random_state=42))
])
logger.info("Initialized new ML model pipeline.")
return model
except Exception as e:
logger.error(f"Error loading or initializing model: {e}")
raise
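# The pipeline above goes from raw strings to labels in one call: CountVectorizer
# turns each text into a bag-of-words count vector, and RandomForestClassifier
# predicts a label from that vector. A minimal sketch on toy data (the example
# strings and labels are illustrative, not from the knowledge base):
#
#   demo = load_or_initialize_model()
#   demo.fit(["hi there", "bye now"], ["greeting", "farewell"])
#   demo.predict(["hi friend"])  # -> likely array(['greeting'], ...)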
def train_model_on_files():
"""Train the ML model based on CSV files in the knowledge base."""
try:
model = load_or_initialize_model()
texts, labels = [], []
# Load data from the knowledge base
for file_name in os.listdir(knowledge_base_dir):
file_path = os.path.join(knowledge_base_dir, file_name)
if file_path.endswith(".csv"):
try:
df = pd.read_csv(file_path)
if "text" in df.columns and "label" in df.columns:
texts.extend(df["text"].astype(str).tolist())
labels.extend(df["label"].astype(str).tolist())
logger.info(f"Loaded data from '{file_name}'.")
                    else:
                        # Skip files missing the required schema instead of aborting the run.
                        logger.warning(f"Skipping '{file_name}': missing 'text' or 'label' columns.")
                except Exception as e:
                    # Skip unreadable files and keep training on the rest.
                    logger.error(f"Error reading '{file_name}': {e}")
if texts and labels:
try:
model.fit(texts, labels)
joblib.dump(model, model_file)
logger.info("ML model trained and saved successfully.")
return f"Model trained on {len(texts)} samples from {len(os.listdir(knowledge_base_dir))} files."
except Exception as e:
logger.error(f"Error during model training: {e}")
return f"Error during model training: {str(e)}"
logger.warning("No valid training data found in the knowledge base.")
return "No valid training data found in the knowledge base."
except Exception as e:
logger.error(f"Unexpected error in training model: {e}")
return f"Unexpected error: {str(e)}"
# ---------------------------
# Chat Response Function
# ---------------------------
def respond(message, history, system_message, max_tokens, temperature, top_p):
"""
Generate a response to the user's message using the ML model or GPT model.
Parameters:
- message (str): User's input message.
- history (list): Conversation history.
- system_message (str): System prompt.
- max_tokens (int): Maximum number of tokens for GPT response.
- temperature (float): Sampling temperature for GPT.
- top_p (float): Nucleus sampling parameter for GPT.
Returns:
- response (str): Generated response.
"""
try:
# Attempt to get a prediction from the ML model
model = load_or_initialize_model()
pred_label = model.predict([message])[0]
response = f"Predicted response: {pred_label}"
update_memory(message, response)
logger.info("Response generated using ML model.")
return response
    except Exception as e:
        # Typically a NotFittedError when the ML model has not been trained yet.
        logger.info(f"ML model could not generate a response ({e}). Falling back to the hosted chat model.")
        # Generate a response with the hosted chat model.
try:
messages = [{"role": "system", "content": system_message}]
for turn in history:
messages.append({"role": turn["role"], "content": turn["content"]})
messages.append({"role": "user", "content": message})
response = ""
for message_part in client.chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
                # Streaming chunks are objects, not dicts; read the delta text via attributes.
                token = message_part.choices[0].delta.content or ""
                response += token
update_memory(message, response)
logger.info("Response generated using GPT model.")
return response
except Exception as e:
logger.error(f"Error generating response with GPT: {e}")
response = f"Error generating response: {str(e)}"
update_memory(message, response)
return response
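# For reference, a non-streaming variant of the fallback call (a sketch, using
# the same `messages` list built above):
#
#   result = client.chat_completion(messages, max_tokens=max_tokens,
#                                   temperature=temperature, top_p=top_p)
#   response = result.choices[0].message.content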
# ---------------------------
# Gradio Interface
# ---------------------------
def create_gradio_interface():
"""Create and configure the Gradio interface."""
with gr.Blocks() as demo:
gr.Markdown("# 🧠 Advanced AI Chatbot with Knowledge Base and Model Training")
# Chat Tab
with gr.Tab("πŸ’¬ Chat"):
chatbot = gr.Chatbot(label="AI Chatbot", type="messages")
with gr.Row():
with gr.Column(scale=5):
user_input = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
lines=1
)
with gr.Column(scale=1, min_width=100):
send_button = gr.Button("Send", variant="primary")
with gr.Row():
system_message = gr.Textbox(
value="You are an advanced AI Chatbot.",
label="System Message",
visible=False
)
max_tokens = gr.Slider(
minimum=100, maximum=2048, value=512, step=100, label="Max Tokens"
)
temperature = gr.Slider(
minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (Nucleus Sampling)",
)
            def handle_message(message, history, system_message, max_tokens, temperature, top_p):
                response = respond(message, history, system_message, max_tokens, temperature, top_p)
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": response})
                # Return the updated history and clear the input box.
                return history, ""
            send_button.click(
                handle_message,
                inputs=[user_input, chatbot, system_message, max_tokens, temperature, top_p],
                outputs=[chatbot, user_input],
            )
            user_input.submit(
                handle_message,
                inputs=[user_input, chatbot, system_message, max_tokens, temperature, top_p],
                outputs=[chatbot, user_input],
            )
# Knowledge Base Tab
with gr.Tab("πŸ“š Knowledge Base"):
gr.Markdown("### Manage Knowledge Base")
file_upload = gr.File(
label="Upload CSV File",
file_types=[".csv"],
file_count="single" # Allows only single file upload
)
upload_output = gr.Textbox(label="Upload Result", interactive=False)
train_button = gr.Button("πŸ”„ Train Model on Knowledge Base")
train_output = gr.Textbox(label="Training Result", interactive=False)
def upload_file(file):
if not file:
return "No file uploaded."
try:
# Determine file path and name
if isinstance(file, dict):
file_path = file.get('path', '')
file_name = file.get('name', '')
else:
file_path = file
file_name = os.path.basename(file_path)
# Validate file extension
if not file_name.endswith(".csv"):
logger.warning(f"Invalid file type attempted: {file_name}")
return "Invalid file type. Please upload a CSV file."
# Save file to knowledge base directory
destination_path = os.path.join(knowledge_base_dir, file_name)
shutil.copy(file_path, destination_path)
logger.info(f"File '{file_name}' uploaded successfully.")
return f"File '{file_name}' uploaded successfully."
except Exception as e:
logger.error(f"Error uploading file: {e}")
return f"Error uploading file: {str(e)}"
file_upload.change(upload_file, inputs=file_upload, outputs=upload_output)
train_button.click(train_model_on_files, inputs=None, outputs=train_output)
# Memory Tab
with gr.Tab("🧠 Memory"):
gr.Markdown("### View and Manage Conversation Memory")
memory_display = gr.JSON(label="Conversation Memory")
with gr.Row():
refresh_memory = gr.Button("πŸ”„ Refresh Memory")
clear_memory = gr.Button("πŸ—‘οΈ Clear Memory")
export_memory = gr.Button("πŸ“€ Export Memory")
            export_output = gr.File(label="Download Memory")
def display_memory():
return load_memory()
def clear_memory_func():
try:
save_memory([])
logger.info("Conversation memory cleared.")
return []
                except Exception as e:
                    logger.error(f"Error clearing memory: {e}")
                    # Return an empty list so the gr.JSON component always receives valid JSON.
                    return []
            def export_memory_func():
                if os.path.exists(memory_file):
                    return memory_file  # Gradio will handle the download
                # Return None so the gr.File component stays empty instead of erroring on a non-path string.
                return None
refresh_memory.click(display_memory, inputs=None, outputs=memory_display)
clear_memory.click(clear_memory_func, inputs=None, outputs=memory_display)
export_memory.click(export_memory_func, inputs=None, outputs=export_output)
# Download Model Tab
with gr.Tab("πŸ’Ύ Download Model"):
gr.Markdown("### Download the Trained Model")
download_button = gr.Button("πŸ“₯ Download Model")
model_download_output = gr.File(label="Downloadable Model")
            def download_model():
                if os.path.exists(model_file):
                    return model_file  # Gradio will handle the file download
                # Return None so the gr.File component stays empty instead of erroring on a non-path string.
                return None
download_button.click(download_model, inputs=None, outputs=model_download_output)
# Settings Tab
with gr.Tab("βš™οΈ Settings"):
gr.Markdown("### Application Settings")
            gr.Textbox(
                value="",
                label="Settings Placeholder",
                placeholder="Add settings here...",
                interactive=True  # Let users type into the placeholder field.
            )
return demo
# ---------------------------
# Main Execution
# ---------------------------
if __name__ == "__main__":
try:
interface = create_gradio_interface()
logger.info("Launching Gradio interface.")
interface.launch()
except Exception as e:
logger.critical(f"Application failed to start: {e}")
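# Note: interface.launch() serves on http://127.0.0.1:7860 by default. To make
# the app reachable from other machines, Gradio supports, for example:
#   interface.launch(server_name="0.0.0.0", share=True)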