# AfVaCL / mainapp.py
import os
import datetime
import uuid
import time
import threading
import traceback
import logging
# from queue import Queue  # No longer needed: replaced by Redis
import torch  # Imported here because DEVICE below is evaluated at import time
from dotenv import load_dotenv
import json
# --- Configuration ---
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
POSTGRES_DSN = os.getenv("POSTGRES_DSN", "postgresql://user:password@localhost:5432/agentdb")
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
BASE_MODEL_NAME = os.getenv("BASE_MODEL_NAME", "gpt-4o-mini") # Fine-tuning base
# For fine-tuning, a local open-source model is often the better choice
# BASE_MODEL_NAME = "meta-llama/Llama-3-8B-Instruct"
LEARNING_INTERVAL_HOURS = int(os.getenv("LEARNING_INTERVAL_HOURS", "6"))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # For PyTorch/TRL
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Library Imports ---
# (Import the libraries listed in the accompanying requirements.txt)
# LangChain components (as before)
from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # HuggingFace embeddings may be a better fit
from langchain.agents import AgentExecutor, create_react_agent, Tool
from langchain import hub  # Used below to pull the ReAct prompt template
# Database (SQLAlchemy example)
from sqlalchemy import create_engine, Column, Integer, String, Float, Boolean, DateTime, Text, MetaData, Index
from sqlalchemy.dialects.postgresql import UUID, JSONB # Use BYTEA or pgvector extension for vectors
# from sqlalchemy.dialects.postgresql import BYTEA # For raw byte vectors
# from pgvector.sqlalchemy import Vector # If using pgvector extension
from sqlalchemy.orm import sessionmaker, declarative_base
import sqlalchemy  # Module-level import; also used for sqlalchemy.text helpers below
# Message Queue
import redis
# Vectorization
from sentence_transformers import SentenceTransformer
# Scheduling
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.interval import IntervalTrigger
# TRL (Placeholders for actual imports and usage)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import LengthSampler
# --- Database Setup (SQLAlchemy) ---
Base = declarative_base()
engine = create_engine(POSTGRES_DSN)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Example Experience Table (Needs pgvector extension or BYTEA for vectors)
class Experience(Base):
__tablename__ = "experiences"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
timestamp = Column(DateTime, default=datetime.datetime.utcnow)
goal = Column(Text)
task = Column(Text)
# thought_summary = Column(Text) # Storing full thoughts can be large
action_info = Column(JSONB) # Store action, input, tool used etc.
observation_summary = Column(Text) # Summarize or store key parts
success = Column(Boolean)
feedback_score = Column(Float, default=0.0) # Numerical feedback
execution_time = Column(Float)
# --- Vector Representations ---
# Option 1: Use pgvector extension (Recommended)
# task_vector = Column(Vector(384)) # Example dimension for all-MiniLM-L6-v2
# observation_vector = Column(Vector(384))
# state_vector = Column(Vector(768)) # Example combined vector
# __table_args__ = (Index('ix_experiences_state_vector', state_vector, postgresql_using='hnsw', postgresql_with={'m': 16, 'ef_construction': 64}),)
# Option 2: Use BYTEA (Requires manual handling of bytes)
# task_vector_bytes = Column(BYTEA)
# observation_vector_bytes = Column(BYTEA)
# Example Task Table
class Task(Base):
__tablename__ = "tasks"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
goal = Column(Text)
task_description = Column(Text)
status = Column(String, default="pending") # pending, processing, completed, failed
created_at = Column(DateTime, default=datetime.datetime.utcnow)
updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow)
result = Column(Text, nullable=True)
# Create tables if they don't exist
Base.metadata.create_all(bind=engine)
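# --- Hedged sketch: enabling pgvector ---
# The Vector columns sketched in Experience above (Option 1) require the
# pgvector extension. This helper is an illustration only and is not called at
# import time; it assumes the connected role may create extensions.
def ensure_pgvector_extension():
    """Create the pgvector extension if it does not already exist."""
    with engine.connect() as conn:
        conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector"))
        conn.commit()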
# --- Message Queue Setup (Redis) ---
redis_client = redis.from_url(REDIS_URL, decode_responses=True)
TASK_QUEUE_KEY = "agent_task_queue"
# --- Vectorization Model ---
# Use a sentence transformer model suitable for tasks/observations
# Consider models optimized for semantic similarity.
# Run this on CPU or GPU depending on availability/need.
embedding_model_name = 'all-MiniLM-L6-v2' # Example model
logging.info(f"Loading sentence transformer model: {embedding_model_name}...")
# Specify device to control CPU/GPU usage for embeddings
sentence_model = SentenceTransformer(embedding_model_name, device='cpu')  # CPU avoids contention with TRL training on the GPU
logging.info("Sentence transformer model loaded.")
def get_vector(text: str):
"""Generates a vector embedding for the given text."""
if not text:
return None
# Ensure model is on the correct device if moved
# sentence_model.to('cpu')
vector = sentence_model.encode(text, convert_to_numpy=True)
# If using BYTEA: return vector.tobytes()
# If using pgvector: return vector.tolist() # Or directly numpy array if supported
return vector.tolist() # For pgvector
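# --- Hedged sketch: BYTEA round-trip for Option 2 above ---
# Only needed if vectors are stored as raw bytes instead of pgvector columns.
# numpy is assumed available (it is a dependency of sentence-transformers);
# float32 storage is an assumption and must match on both read and write.
import numpy as np

def vector_to_bytes(vector: list) -> bytes:
    """Serialize a float vector for a BYTEA column."""
    return np.asarray(vector, dtype=np.float32).tobytes()

def bytes_to_vector(raw: bytes) -> list:
    """Deserialize a BYTEA payload back into a list of floats."""
    return np.frombuffer(raw, dtype=np.float32).tolist()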
# --- Experience Management (using DB) ---
def add_experience_db(task_info: dict, agent_output: dict, success: bool, feedback: float = 0.0, exec_time: float = 0.0):
"""Adds an agent's experience to the PostgreSQL database."""
db = SessionLocal()
try:
# --- Generate Vector Representations ---
task_vector = get_vector(task_info.get("task"))
obs_summary = agent_output.get("output", "")[:500] # Limit observation size
observation_vector = get_vector(obs_summary)
# Combine vectors or create a more complex state representation
state_vector = None
if task_vector and observation_vector:
# Simple concatenation example (ensure dimensions match DB schema)
# state_vector = task_vector + observation_vector
pass # Implement actual state vector logic
action_info = {
"action": agent_output.get("action", "unknown"), # Extract action if available
"input": agent_output.get("action_input", "unknown"), # Extract input if available
# Add other relevant details like tool used
}
exp = Experience(
goal=task_info.get("goal"),
task=task_info.get("task"),
action_info=action_info,
observation_summary=obs_summary,
success=success,
feedback_score=feedback,
execution_time=exec_time,
# task_vector=task_vector, # Assign vectors (match DB column type)
# observation_vector=observation_vector,
# state_vector=state_vector,
)
db.add(exp)
db.commit()
logging.debug(f"Experience added to DB: Success={success}, Task={task_info.get('task')[:50]}")
except Exception as e:
db.rollback()
logging.error(f"Failed to add experience to DB: {e}", exc_info=True)
finally:
db.close()
def retrieve_relevant_experiences_db(query: str, k: int = 3) -> list[Experience]:
"""Retrieves relevant experiences using vector similarity search (requires pgvector)."""
db = SessionLocal()
try:
query_vector = get_vector(query)
if query_vector is None:
return []
# --- Requires pgvector setup ---
# This query syntax depends on sqlalchemy-pgvector or raw SQL
# results = db.query(Experience).order_by(Experience.state_vector.l2_distance(query_vector)).limit(k).all()
# logging.info(f"Retrieved {len(results)} experiences from DB for query: {query[:50]}")
# return results
# --- Placeholder if pgvector is not set up ---
logging.warning("Vector search in DB requested but not implemented (requires pgvector). Returning empty list.")
return []
except Exception as e:
logging.error(f"Failed to retrieve experiences from DB: {e}", exc_info=True)
return []
finally:
db.close()
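# --- Hedged sketch: pgvector similarity search via raw SQL ---
# One way the commented-out query above could look with pgvector installed and
# a `state_vector vector(384)` column in place (neither is set up in this file).
# pgvector's `<->` operator is L2 distance; the vector is passed in its text
# literal form, e.g. '[0.1, 0.2, ...]'.
def pgvector_similarity_search(db, query_vector: list, k: int = 3):
    """Return the k nearest experience rows by L2 distance (requires pgvector)."""
    sql = sqlalchemy.text(
        "SELECT id, task, observation_summary "
        "FROM experiences "
        "ORDER BY state_vector <-> CAST(:qv AS vector) "
        "LIMIT :k"
    )
    return db.execute(sql, {"qv": str(query_vector), "k": k}).fetchall()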
# --- Tools Definition ---
# Example tools: DuckDuckGo search and a Python REPL, via the
# langchain-community and langchain-experimental packages.
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_experimental.utilities import PythonREPL
search = DuckDuckGoSearchRun()
python_repl = PythonREPL()
tools = [
    Tool(name="Search", func=search.run, description="Search the web. Input: a search query."),
    Tool(name="PythonREPL", func=python_repl.run, description="Run Python code. Input: valid Python code."),
]
# --- Agent Setup ---
# Use the base model for the agent initially. The fine-tuned model will be loaded by the learning worker.
agent_llm = ChatOpenAI(model=BASE_MODEL_NAME, temperature=0.3, api_key=OPENAI_API_KEY)
prompt_template = hub.pull("hwchase17/react-chat")
agent = create_react_agent(agent_llm, tools, prompt_template)
agent_executor = AgentExecutor(
agent=agent, tools=tools, verbose=False, handle_parsing_errors=True, max_iterations=10,
)
# --- Learning Module (TRL Implementation Sketch) ---
learning_lock = threading.Lock()
ppo_trainer = None # Global PPO trainer instance (or manage per learning cycle)
fine_tuned_model_path = "./fine_tuned_model" # Path to save/load fine-tuned adapter/model
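def load_fine_tuned_model():
    """Hedged sketch: attach saved LoRA adapters to a local base model.
    Assumes BASE_MODEL_NAME is a local HF checkpoint (not an OpenAI API model)
    and that a learning cycle has already written adapters to
    fine_tuned_model_path. Not wired into the agent, which uses ChatOpenAI."""
    from peft import PeftModel  # peft is already a dependency (see imports above)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
    )
    return PeftModel.from_pretrained(base, fine_tuned_model_path)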
def calculate_reward(experience_data: dict) -> float:
"""Calculates a reward score based on experience."""
reward = 0.0
if experience_data.get("success"):
reward += 1.0
else:
reward -= 1.0 # Penalty for failure
    # Penalty for long execution time (square-root scale to moderate impact,
    # capped at exec_time = 300s)
    exec_time = experience_data.get("execution_time", 1.0)
    if exec_time > 1.0:
        reward -= 0.1 * min(max(0, exec_time), 300) ** 0.5
# Incorporate feedback score
reward += experience_data.get("feedback_score", 0.0) * 0.5 # Scale feedback impact
return reward
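# Worked example of the reward arithmetic above: a successful task (+1.0) that
# ran for 100s (-0.1 * sqrt(100) = -1.0) with feedback 0.5 (+0.5 * 0.5 = +0.25)
# scores 0.25 overall; the sqrt penalty saturates at exec_time = 300s.
assert abs(calculate_reward({"success": True, "execution_time": 100.0, "feedback_score": 0.5}) - 0.25) < 1e-9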
def prepare_ppo_data(experiences: list[Experience]) -> list[dict]:
"""Prepares data in the format expected by TRL's PPOTrainer."""
ppo_data = []
for exp in experiences:
# Construct the 'query' - the input to the LLM for the task
query_text = f"Goal: {exp.goal}\nTask: {exp.task}"
# Construct the 'response' - the LLM's actual output (observation)
response_text = exp.observation_summary
        # Calculate reward from the experience's stored fields
        reward_score = calculate_reward({
            "success": exp.success,
            "execution_time": exp.execution_time,
            "feedback_score": exp.feedback_score,
        })
        if query_text and response_text:
            ppo_data.append({
                "query": query_text,
                "response": response_text,
                "reward": torch.tensor([reward_score], dtype=torch.float32),  # TRL expects a tensor
            })
return ppo_data
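# --- Hedged sketch: tokenizing PPO data ---
# PPOTrainer.step expects lists of query tensors, response tensors, and reward
# tensors; this helper shows one way to produce them from prepare_ppo_data's
# output. The tokenizer argument is assumed to be the HF tokenizer loaded in
# run_learning_cycle below; max_len is an illustrative truncation limit.
def tokenize_ppo_data(ppo_data: list, tokenizer, max_len: int = 512):
    """Convert prepared PPO dicts into tensor lists for PPOTrainer.step."""
    query_tensors, response_tensors, rewards = [], [], []
    for d in ppo_data:
        q = tokenizer.encode(d["query"], return_tensors="pt",
                             truncation=True, max_length=max_len).squeeze(0)
        r = tokenizer.encode(d["response"], return_tensors="pt",
                             truncation=True, max_length=max_len).squeeze(0)
        query_tensors.append(q)
        response_tensors.append(r)
        rewards.append(d["reward"])
    return query_tensors, response_tensors, rewards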
def run_learning_cycle():
"""The main learning process using TRL."""
global ppo_trainer # Allow modification
if not torch.cuda.is_available():
logging.warning("CUDA not available. Skipping fine-tuning cycle.")
return
with learning_lock:
logging.info(f"[Learning Cycle Triggered] - Device: {DEVICE}")
start_time = time.time()
# 1. Fetch Data from PostgreSQL
logging.info("Fetching recent experiences from PostgreSQL...")
db = SessionLocal()
try:
# Fetch experiences (e.g., last N or within a time window)
recent_experiences = db.query(Experience).order_by(Experience.timestamp.desc()).limit(500).all() # Adjust limit
finally:
db.close()
if not recent_experiences or len(recent_experiences) < 50: # Need sufficient data
logging.info(f"Not enough new experiences ({len(recent_experiences)}). Skipping fine-tuning.")
return
logging.info(f"Fetched {len(recent_experiences)} experiences for learning.")
# 2. Prepare Data and Calculate Rewards
logging.info("Preparing data for PPO...")
ppo_data = prepare_ppo_data(recent_experiences)
if not ppo_data:
logging.warning("No valid data points after preparation. Skipping fine-tuning.")
return
# Convert to TRL dataset format (example, check TRL docs for specifics)
# This usually involves tokenizing queries and responses
# query_tensors = [tokenizer.encode(d['query'], return_tensors="pt").squeeze(0) for d in ppo_data]
# response_tensors = [tokenizer.encode(d['response'], return_tensors="pt").squeeze(0) for d in ppo_data]
# rewards = [d['reward'] for d in ppo_data]
# 3. Setup TRL PPO Trainer (Simplified Example)
logging.info("Setting up TRL PPOTrainer...")
try:
# --- TRL Configuration ---
ppo_config = PPOConfig(
                model_name=BASE_MODEL_NAME,  # Must be a local/HF checkpoint for TRL, not an OpenAI API model
learning_rate=1.41e-5,
batch_size=16, # Adjust based on GPU memory
mini_batch_size=4, # Adjust based on GPU memory
gradient_accumulation_steps=1,
optimize_cuda_cache=True,
# early_stopping=True,
# target_kl=0.1,
ppo_epochs=4, # Number of epochs per PPO step
seed=42,
                # LoRA is applied via `peft_config` when the model is loaded
                # below; PPOConfig itself has no LoRA option
)
# --- Model Loading (with Quantization and LoRA) ---
# bnb_config = BitsAndBytesConfig(...) # Optional quantization
lora_config = LoraConfig(
r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token # Important for padding
# Load the base model with ValueHead for PPO and LoRA config
model = AutoModelForCausalLMWithValueHead.from_pretrained(
ppo_config.model_name,
# quantization_config=bnb_config, # Optional
peft_config=lora_config,
# load_in_8bit=True, # Or load_in_4bit=True
torch_dtype=torch.float16, # Use float16/bfloat16 on GPU
device_map="auto" # Use Accelerate for device mapping
)
# Reference model for KL divergence
ref_model = create_reference_model(model) # Or load separately
# --- Initialize Trainer ---
# Requires tokenized queries, responses, and rewards
# ppo_trainer = PPOTrainer(
# config=ppo_config,
# model=model,
# ref_model=ref_model,
# tokenizer=tokenizer,
# dataset=your_prepared_dataset, # Requires tokenized data
# data_collator=your_data_collator # Handles padding
# )
# --- PPO Training Loop ---
logging.info("Starting PPO Training Loop (Simulation - Actual requires dataset)...")
# for epoch in range(ppo_config.ppo_epochs):
# for batch in ppo_trainer.dataloader:
# # Get query tensors, response tensors from batch
# # Compute log probs, values, etc.
# # stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
# # ppo_trainer.log_stats(stats, batch, rewards)
# # Save model checkpoint periodically?
time.sleep(10) # Simulate training time
            # --- Save Fine-tuned Model (LoRA Adapters) ---
            logging.info("Saving fine-tuned LoRA adapters (placeholder)...")
            # ppo_trainer.save_pretrained(fine_tuned_model_path)
            # tokenizer.save_pretrained(fine_tuned_model_path)
            logging.info(f"Fine-tuned adapters would be saved to {fine_tuned_model_path}")
        except Exception as e:
            logging.error(f"Error during TRL setup or training: {e}", exc_info=True)
        finally:
            # Drop model references (safe even if setup failed before they
            # were bound) and release cached GPU memory
            model = ref_model = None
            torch.cuda.empty_cache()
logging.info(f"Learning cycle finished. Duration: {time.time() - start_time:.2f}s")
# --- Task Management (using Redis) ---
def add_task_mq(task: str, goal: str):
"""Adds a task to the Redis queue."""
task_id = str(uuid.uuid4())
task_data = json.dumps({"id": task_id, "task": task, "goal": goal})
try:
redis_client.lpush(TASK_QUEUE_KEY, task_data)
logging.info(f"Task {task_id} added to Redis queue: {task[:50]}...")
except Exception as e:
logging.error(f"Failed to add task to Redis: {e}")
# --- Agent Worker (modified for Redis and DB) ---
def agent_worker(worker_id: int):
"""Processes tasks from the Redis queue."""
logging.info(f"Agent Worker-{worker_id} started.")
while True: # Run continuously
try:
# Blocking pop from Redis list (wait indefinitely)
_, task_data_json = redis_client.brpop(TASK_QUEUE_KEY)
task_info = json.loads(task_data_json)
task_id = task_info["id"]
task_desc = task_info["task"]
goal = task_info["goal"]
logging.info(f"Worker-{worker_id} processing Task {task_id}: {task_desc[:50]}...")
start_time = time.time()
success = False
final_output = None
agent_result = {} # Store agent's output details
# Update task status in DB (optional)
# update_task_status(task_id, "processing")
# --- Retrieve relevant experiences ---
# query = f"Goal: {goal}\nTask: {task_desc}"
# relevant_experiences = retrieve_relevant_experiences_db(query, k=3)
# experience_context = ... # Format context from DB results
            # --- Prepare Agent Input ---
            # The react-chat prompt expects string `input` and `chat_history` variables
            agent_input = (
                f"Your long term goal is: {goal}. Think step-by-step.\n"
                # Insert experience_context here if the retrieval above is enabled
                f"Current task: {task_desc}"
            )
            # --- Execute Agent ---
            try:
                # Ideally, load the latest fine-tuned model for inference here
                # (requires coordination or loading the adapter weights)
                agent_result = agent_executor.invoke({"input": agent_input, "chat_history": ""})
final_output = agent_result.get("output", "No output.")
# Simple success check (refine this based on tool usage, keywords etc.)
success = not any(err in final_output.lower() for err in ["error", "fail", "unable"])
except Exception as e:
logging.error(f"Worker-{worker_id} Task {task_id} failed during execution: {e}", exc_info=True)
final_output = f"Agent execution failed: {e}"
success = False
agent_result = {"output": final_output, "action": "error"} # Log error state
# --- Record Experience ---
exec_time = time.time() - start_time
# Add user feedback later if available (e.g., via API)
feedback_score = 0.0
add_experience_db(task_info, agent_result, success, feedback_score, exec_time)
# Update task status in DB (optional)
# update_task_status(task_id, "completed" if success else "failed", final_output)
logging.info(f"Worker-{worker_id} finished Task {task_id}. Success: {success}. Time: {exec_time:.2f}s")
except redis.exceptions.ConnectionError as e:
logging.error(f"Worker-{worker_id} Redis connection error: {e}. Retrying in 10s...")
time.sleep(10)
except Exception as e:
logging.error(f"Worker-{worker_id} encountered an unexpected error: {e}", exc_info=True)
time.sleep(5) # Avoid rapid looping on persistent errors
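# --- Hedged sketch: graceful worker shutdown via sentinel tasks ---
# agent_worker above loops forever; one option is to push one sentinel task per
# worker at shutdown and have each worker exit when it pops one. The matching
# check inside agent_worker is NOT implemented above; this shows only the
# producer side.
def push_shutdown_sentinels(num_workers: int):
    """Push one sentinel per worker so each can exit its processing loop."""
    for _ in range(num_workers):
        redis_client.lpush(
            TASK_QUEUE_KEY,
            json.dumps({"id": "sentinel", "task": "__STOP__", "goal": ""}),
        )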
# --- Main Execution / Service Startup ---
if __name__ == "__main__":
logging.info("Initializing Agent System...")
# --- Start Background Learning Scheduler ---
scheduler = BackgroundScheduler(daemon=True)
scheduler.add_job(
run_learning_cycle,
trigger=IntervalTrigger(hours=LEARNING_INTERVAL_HOURS),
id="learning_job",
name="Fine-tuning Learning Cycle",
replace_existing=True
)
scheduler.start()
logging.info(f"Background learning scheduler started. Interval: {LEARNING_INTERVAL_HOURS} hours.")
# --- Start Agent Workers ---
num_workers = int(os.getenv("NUM_WORKERS", "2"))
worker_threads = []
for i in range(num_workers):
thread = threading.Thread(target=agent_worker, args=(i+1,), daemon=True)
thread.start()
worker_threads.append(thread)
logging.info(f"{num_workers} Agent worker threads started.")
# --- Add Initial Tasks (Example) ---
add_task_mq("Explain the difference between LoRA and full fine-tuning for LLMs.",
"Understand AI model optimization techniques.")
add_task_mq("Write a Python script using pandas to read a CSV file named 'data.csv' and print the first 5 rows.",
"Develop data processing scripts.")
logging.info("Agent system is running. Workers processing tasks from Redis.")
logging.info("Press Ctrl+C to stop.")
try:
# Keep main thread alive
while True:
time.sleep(60)
# Add health checks or monitoring here if needed
logging.debug("Main thread alive...")
except KeyboardInterrupt:
logging.info("Shutdown signal received...")
scheduler.shutdown()
# Workers are daemon threads, they will exit when main thread exits.
    # Implement graceful shutdown for workers if needed (e.g., via the push_shutdown_sentinels sketch above)
logging.info("Agent system stopped.")