tugaa commited on
Commit
cf4231d
·
verified ·
1 Parent(s): 5dd8411

Create mainapp.py

Browse files
Files changed (1) hide show
  1. mainapp.py +487 -0
mainapp.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime
3
+ import uuid
4
+ import time
5
+ import threading
6
+ import traceback
7
+ import logging
8
+ from queue import Queue # Redisに置き換えるので不要になる
9
+ from dotenv import load_dotenv
10
+ import json
11
+
12
+ # --- Configuration ---
13
+ load_dotenv()
14
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
15
+ POSTGRES_DSN = os.getenv("POSTGRES_DSN", "postgresql://user:password@localhost:5432/agentdb")
16
+ REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
17
+ BASE_MODEL_NAME = os.getenv("BASE_MODEL_NAME", "gpt-4o-mini") # Fine-tuning base
18
+ # Fine-tuning するならローカルのOSSモデルが良い場合が多い
19
+ # BASE_MODEL_NAME = "meta-llama/Llama-3-8B-Instruct"
20
+ LEARNING_INTERVAL_HOURS = int(os.getenv("LEARNING_INTERVAL_HOURS", "6"))
21
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # For PyTorch/TRL
22
+
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
+
25
+ # --- Library Imports ---
26
+ # (上記 requirements.txt に対応するライブラリを import)
27
+ # LangChain components (as before)
28
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings # EmbeddingsはHuggingFace製が良いかも
29
+ from langchain.agents import AgentExecutor, create_react_agent, Tool
30
+ # ... other langchain imports
31
+
32
+ # Database (SQLAlchemy example)
33
+ from sqlalchemy import create_engine, Column, Integer, String, Float, Boolean, DateTime, Text, MetaData, Index
34
+ from sqlalchemy.dialects.postgresql import UUID, JSONB # Use BYTEA or pgvector extension for vectors
35
+ # from sqlalchemy.dialects.postgresql import BYTEA # For raw byte vectors
36
+ # from pgvector.sqlalchemy import Vector # If using pgvector extension
37
+ from sqlalchemy.orm import sessionmaker, declarative_base
38
+ import sqlalchemy # Ensure it's imported
39
+
40
+ # Message Queue
41
+ import redis
42
+
43
+ # Vectorization
44
+ from sentence_transformers import SentenceTransformer
45
+
46
+ # Scheduling
47
+ from apscheduler.schedulers.background import BackgroundScheduler
48
+ from apscheduler.triggers.interval import IntervalTrigger
49
+
50
+ # TRL (Placeholders for actual imports and usage)
51
+ import torch
52
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
53
+ from peft import LoraConfig
54
+ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
55
+ from trl.core import LengthSampler
56
+
57
+ # --- Database Setup (SQLAlchemy) ---
58
+ Base = declarative_base()
59
+ engine = create_engine(POSTGRES_DSN)
60
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
61
+
62
+ # Example Experience Table (Needs pgvector extension or BYTEA for vectors)
63
+ class Experience(Base):
64
+ __tablename__ = "experiences"
65
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
66
+ timestamp = Column(DateTime, default=datetime.datetime.utcnow)
67
+ goal = Column(Text)
68
+ task = Column(Text)
69
+ # thought_summary = Column(Text) # Storing full thoughts can be large
70
+ action_info = Column(JSONB) # Store action, input, tool used etc.
71
+ observation_summary = Column(Text) # Summarize or store key parts
72
+ success = Column(Boolean)
73
+ feedback_score = Column(Float, default=0.0) # Numerical feedback
74
+ execution_time = Column(Float)
75
+ # --- Vector Representations ---
76
+ # Option 1: Use pgvector extension (Recommended)
77
+ # task_vector = Column(Vector(384)) # Example dimension for all-MiniLM-L6-v2
78
+ # observation_vector = Column(Vector(384))
79
+ # state_vector = Column(Vector(768)) # Example combined vector
80
+ # __table_args__ = (Index('ix_experiences_state_vector', state_vector, postgresql_using='hnsw', postgresql_with={'m': 16, 'ef_construction': 64}),)
81
+
82
+ # Option 2: Use BYTEA (Requires manual handling of bytes)
83
+ # task_vector_bytes = Column(BYTEA)
84
+ # observation_vector_bytes = Column(BYTEA)
85
+
86
+ # Example Task Table
87
+ class Task(Base):
88
+ __tablename__ = "tasks"
89
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
90
+ goal = Column(Text)
91
+ task_description = Column(Text)
92
+ status = Column(String, default="pending") # pending, processing, completed, failed
93
+ created_at = Column(DateTime, default=datetime.datetime.utcnow)
94
+ updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow)
95
+ result = Column(Text, nullable=True)
96
+
97
+ # Create tables if they don't exist
98
+ Base.metadata.create_all(bind=engine)
99
+
100
+ # --- Message Queue Setup (Redis) ---
101
+ redis_client = redis.from_url(REDIS_URL, decode_responses=True)
102
+ TASK_QUEUE_KEY = "agent_task_queue"
103
+
104
+ # --- Vectorization Model ---
105
+ # Use a sentence transformer model suitable for tasks/observations
106
+ # Consider models optimized for semantic similarity.
107
+ # Run this on CPU or GPU depending on availability/need.
108
+ embedding_model_name = 'all-MiniLM-L6-v2' # Example model
109
+ logging.info(f"Loading sentence transformer model: {embedding_model_name}...")
110
+ # Specify device to control CPU/GPU usage for embeddings
111
+ sentence_model = SentenceTransformer(embedding_model_name, device='cpu') # Use CPU for potentially less conflict with TRL on GPU
112
+ logging.info("Sentence transformer model loaded.")
113
+
114
+ def get_vector(text: str):
115
+ """Generates a vector embedding for the given text."""
116
+ if not text:
117
+ return None
118
+ # Ensure model is on the correct device if moved
119
+ # sentence_model.to('cpu')
120
+ vector = sentence_model.encode(text, convert_to_numpy=True)
121
+ # If using BYTEA: return vector.tobytes()
122
+ # If using pgvector: return vector.tolist() # Or directly numpy array if supported
123
+ return vector.tolist() # For pgvector
124
+
125
+ # --- Experience Management (using DB) ---
126
+ def add_experience_db(task_info: dict, agent_output: dict, success: bool, feedback: float = 0.0, exec_time: float = 0.0):
127
+ """Adds an agent's experience to the PostgreSQL database."""
128
+ db = SessionLocal()
129
+ try:
130
+ # --- Generate Vector Representations ---
131
+ task_vector = get_vector(task_info.get("task"))
132
+ obs_summary = agent_output.get("output", "")[:500] # Limit observation size
133
+ observation_vector = get_vector(obs_summary)
134
+ # Combine vectors or create a more complex state representation
135
+ state_vector = None
136
+ if task_vector and observation_vector:
137
+ # Simple concatenation example (ensure dimensions match DB schema)
138
+ # state_vector = task_vector + observation_vector
139
+ pass # Implement actual state vector logic
140
+
141
+ action_info = {
142
+ "action": agent_output.get("action", "unknown"), # Extract action if available
143
+ "input": agent_output.get("action_input", "unknown"), # Extract input if available
144
+ # Add other relevant details like tool used
145
+ }
146
+
147
+ exp = Experience(
148
+ goal=task_info.get("goal"),
149
+ task=task_info.get("task"),
150
+ action_info=action_info,
151
+ observation_summary=obs_summary,
152
+ success=success,
153
+ feedback_score=feedback,
154
+ execution_time=exec_time,
155
+ # task_vector=task_vector, # Assign vectors (match DB column type)
156
+ # observation_vector=observation_vector,
157
+ # state_vector=state_vector,
158
+ )
159
+ db.add(exp)
160
+ db.commit()
161
+ logging.debug(f"Experience added to DB: Success={success}, Task={task_info.get('task')[:50]}")
162
+ except Exception as e:
163
+ db.rollback()
164
+ logging.error(f"Failed to add experience to DB: {e}", exc_info=True)
165
+ finally:
166
+ db.close()
167
+
168
+ def retrieve_relevant_experiences_db(query: str, k: int = 3) -> list[Experience]:
169
+ """Retrieves relevant experiences using vector similarity search (requires pgvector)."""
170
+ db = SessionLocal()
171
+ try:
172
+ query_vector = get_vector(query)
173
+ if query_vector is None:
174
+ return []
175
+
176
+ # --- Requires pgvector setup ---
177
+ # This query syntax depends on sqlalchemy-pgvector or raw SQL
178
+ # results = db.query(Experience).order_by(Experience.state_vector.l2_distance(query_vector)).limit(k).all()
179
+ # logging.info(f"Retrieved {len(results)} experiences from DB for query: {query[:50]}")
180
+ # return results
181
+
182
+ # --- Placeholder if pgvector is not set up ---
183
+ logging.warning("Vector search in DB requested but not implemented (requires pgvector). Returning empty list.")
184
+ return []
185
+ except Exception as e:
186
+ logging.error(f"Failed to retrieve experiences from DB: {e}", exc_info=True)
187
+ return []
188
+ finally:
189
+ db.close()
190
+
191
+ # --- Tools Definition (same as before) ---
192
+ # ... search, python_repl ...
193
+ tools = [
194
+ Tool(name="Search", func=search.run, description="..."),
195
+ Tool(name="PythonREPL", func=python_repl.run, description="..."),
196
+ ]
197
+
198
+ # --- Agent Setup ---
199
+ # Use the base model for the agent initially. The fine-tuned model will be loaded by the learning worker.
200
+ agent_llm = ChatOpenAI(model=BASE_MODEL_NAME, temperature=0.3, api_key=OPENAI_API_KEY)
201
+ prompt_template = hub.pull("hwchase17/react-chat")
202
+ agent = create_react_agent(agent_llm, tools, prompt_template)
203
+ agent_executor = AgentExecutor(
204
+ agent=agent, tools=tools, verbose=False, handle_parsing_errors=True, max_iterations=10,
205
+ )
206
+
207
+
208
+ # --- Learning Module (TRL Implementation Sketch) ---
209
+ learning_lock = threading.Lock()
210
+ ppo_trainer = None # Global PPO trainer instance (or manage per learning cycle)
211
+ fine_tuned_model_path = "./fine_tuned_model" # Path to save/load fine-tuned adapter/model
212
+
213
+ def calculate_reward(experience_data: dict) -> float:
214
+ """Calculates a reward score based on experience."""
215
+ reward = 0.0
216
+ if experience_data.get("success"):
217
+ reward += 1.0
218
+ else:
219
+ reward -= 1.0 # Penalty for failure
220
+
221
+ # Penalty for long execution time (log scale to moderate impact)
222
+ exec_time = experience_data.get("execution_time", 1.0) # Avoid log(0)
223
+ if exec_time > 1.0:
224
+ reward -= 0.1 * min(max(0, exec_time), 300)**0.5 # Capped sqrt penalty
225
+
226
+ # Incorporate feedback score
227
+ reward += experience_data.get("feedback_score", 0.0) * 0.5 # Scale feedback impact
228
+
229
+ return reward
230
+
231
+ def prepare_ppo_data(experiences: list[Experience]) -> list[dict]:
232
+ """Prepares data in the format expected by TRL's PPOTrainer."""
233
+ ppo_data = []
234
+ for exp in experiences:
235
+ # Construct the 'query' - the input to the LLM for the task
236
+ query_text = f"Goal: {exp.goal}\nTask: {exp.task}"
237
+ # Construct the 'response' - the LLM's actual output (observation)
238
+ response_text = exp.observation_summary
239
+ # Calculate reward
240
+ reward_score = calculate_reward(exp.metadata) # Assuming metadata is attached or retrieved
241
+
242
+ if query_text and response_text:
243
+ ppo_data.append({
244
+ "query": query_text,
245
+ "response": response_text,
246
+ "reward": torch.tensor([reward_score], dtype=torch.float3_tensors) # TRL expects tensor
247
+ })
248
+ return ppo_data
249
+
250
+
251
+ def run_learning_cycle():
252
+ """The main learning process using TRL."""
253
+ global ppo_trainer # Allow modification
254
+ if not torch.cuda.is_available():
255
+ logging.warning("CUDA not available. Skipping fine-tuning cycle.")
256
+ return
257
+
258
+ with learning_lock:
259
+ logging.info(f"[Learning Cycle Triggered] - Device: {DEVICE}")
260
+ start_time = time.time()
261
+
262
+ # 1. Fetch Data from PostgreSQL
263
+ logging.info("Fetching recent experiences from PostgreSQL...")
264
+ db = SessionLocal()
265
+ try:
266
+ # Fetch experiences (e.g., last N or within a time window)
267
+ recent_experiences = db.query(Experience).order_by(Experience.timestamp.desc()).limit(500).all() # Adjust limit
268
+ finally:
269
+ db.close()
270
+
271
+ if not recent_experiences or len(recent_experiences) < 50: # Need sufficient data
272
+ logging.info(f"Not enough new experiences ({len(recent_experiences)}). Skipping fine-tuning.")
273
+ return
274
+ logging.info(f"Fetched {len(recent_experiences)} experiences for learning.")
275
+
276
+ # 2. Prepare Data and Calculate Rewards
277
+ logging.info("Preparing data for PPO...")
278
+ ppo_data = prepare_ppo_data(recent_experiences)
279
+ if not ppo_data:
280
+ logging.warning("No valid data points after preparation. Skipping fine-tuning.")
281
+ return
282
+
283
+ # Convert to TRL dataset format (example, check TRL docs for specifics)
284
+ # This usually involves tokenizing queries and responses
285
+ # query_tensors = [tokenizer.encode(d['query'], return_tensors="pt").squeeze(0) for d in ppo_data]
286
+ # response_tensors = [tokenizer.encode(d['response'], return_tensors="pt").squeeze(0) for d in ppo_data]
287
+ # rewards = [d['reward'] for d in ppo_data]
288
+
289
+ # 3. Setup TRL PPO Trainer (Simplified Example)
290
+ logging.info("Setting up TRL PPOTrainer...")
291
+ try:
292
+ # --- TRL Configuration ---
293
+ ppo_config = PPOConfig(
294
+ model_name=BASE_MODEL_NAME,
295
+ learning_rate=1.41e-5,
296
+ batch_size=16, # Adjust based on GPU memory
297
+ mini_batch_size=4, # Adjust based on GPU memory
298
+ gradient_accumulation_steps=1,
299
+ optimize_cuda_cache=True,
300
+ # early_stopping=True,
301
+ # target_kl=0.1,
302
+ ppo_epochs=4, # Number of epochs per PPO step
303
+ seed=42,
304
+ # Use LoRA for efficient fine-tuning
305
+ use_lora=True,
306
+ )
307
+
308
+ # --- Model Loading (with Quantization and LoRA) ---
309
+ # bnb_config = BitsAndBytesConfig(...) # Optional quantization
310
+ lora_config = LoraConfig(
311
+ r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
312
+ )
313
+ tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
314
+ if getattr(tokenizer, "pad_token", None) is None:
315
+ tokenizer.pad_token = tokenizer.eos_token # Important for padding
316
+
317
+ # Load the base model with ValueHead for PPO and LoRA config
318
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(
319
+ ppo_config.model_name,
320
+ # quantization_config=bnb_config, # Optional
321
+ peft_config=lora_config,
322
+ # load_in_8bit=True, # Or load_in_4bit=True
323
+ torch_dtype=torch.float16, # Use float16/bfloat16 on GPU
324
+ device_map="auto" # Use Accelerate for device mapping
325
+ )
326
+ # Reference model for KL divergence
327
+ ref_model = create_reference_model(model) # Or load separately
328
+
329
+ # --- Initialize Trainer ---
330
+ # Requires tokenized queries, responses, and rewards
331
+ # ppo_trainer = PPOTrainer(
332
+ # config=ppo_config,
333
+ # model=model,
334
+ # ref_model=ref_model,
335
+ # tokenizer=tokenizer,
336
+ # dataset=your_prepared_dataset, # Requires tokenized data
337
+ # data_collator=your_data_collator # Handles padding
338
+ # )
339
+
340
+ # --- PPO Training Loop ---
341
+ logging.info("Starting PPO Training Loop (Simulation - Actual requires dataset)...")
342
+ # for epoch in range(ppo_config.ppo_epochs):
343
+ # for batch in ppo_trainer.dataloader:
344
+ # # Get query tensors, response tensors from batch
345
+ # # Compute log probs, values, etc.
346
+ # # stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
347
+ # # ppo_trainer.log_stats(stats, batch, rewards)
348
+ # # Save model checkpoint periodically?
349
+ time.sleep(10) # Simulate training time
350
+
351
+ # --- Save Fine-tuned Model (LoRA Adapters) ---
352
+ logging.info("Saving fine-tuned LoRA adapters...")
353
+ # ppo_trainer.save_pretrained(fine_tuned_model_path)
354
+ # tokenizer.save_pretrained(fine_tuned_model_path)
355
+ logging.info(f"Fine-tuned adapters saved to {fine_tuned_model_path}")
356
+
357
+ except Exception as e:
358
+ logging.error(f"Error during TRL setup or training: {e}", exc_info=True)
359
+ # Clean up GPU memory if needed
360
+ del model, ref_model, ppo_trainer
361
+ torch.cuda.empty_cache()
362
+
363
+ logging.info(f"Learning cycle finished. Duration: {time.time() - start_time:.2f}s")
364
+
365
+ # --- Task Management (using Redis) ---
366
+ def add_task_mq(task: str, goal: str):
367
+ """Adds a task to the Redis queue."""
368
+ task_id = str(uuid.uuid4())
369
+ task_data = json.dumps({"id": task_id, "task": task, "goal": goal})
370
+ try:
371
+ redis_client.lpush(TASK_QUEUE_KEY, task_data)
372
+ logging.info(f"Task {task_id} added to Redis queue: {task[:50]}...")
373
+ except Exception as e:
374
+ logging.error(f"Failed to add task to Redis: {e}")
375
+
376
+ # --- Agent Worker (modified for Redis and DB) ---
377
+ def agent_worker(worker_id: int):
378
+ """Processes tasks from the Redis queue."""
379
+ logging.info(f"Agent Worker-{worker_id} started.")
380
+ while True: # Run continuously
381
+ try:
382
+ # Blocking pop from Redis list (wait indefinitely)
383
+ _, task_data_json = redis_client.brpop(TASK_QUEUE_KEY)
384
+ task_info = json.loads(task_data_json)
385
+ task_id = task_info["id"]
386
+ task_desc = task_info["task"]
387
+ goal = task_info["goal"]
388
+
389
+ logging.info(f"Worker-{worker_id} processing Task {task_id}: {task_desc[:50]}...")
390
+ start_time = time.time()
391
+ success = False
392
+ final_output = None
393
+ agent_result = {} # Store agent's output details
394
+
395
+ # Update task status in DB (optional)
396
+ # update_task_status(task_id, "processing")
397
+
398
+ # --- Retrieve relevant experiences ---
399
+ # query = f"Goal: {goal}\nTask: {task_desc}"
400
+ # relevant_experiences = retrieve_relevant_experiences_db(query, k=3)
401
+ # experience_context = ... # Format context from DB results
402
+
403
+ # --- Prepare Agent Input ---
404
+ input_messages = [
405
+ SystemMessage(content=f"Your long term goal is: {goal}. Think step-by-step."),
406
+ # Add experience_context here if needed
407
+ HumanMessage(content=f"Current task: {task_desc}")
408
+ ]
409
+
410
+ # --- Execute Agent ---
411
+ try:
412
+ # Ideally, load the latest fine-tuned model for inference here
413
+ # This requires coordination or loading the adapter weights
414
+ agent_result = agent_executor.invoke({"input": input_messages})
415
+ final_output = agent_result.get("output", "No output.")
416
+ # Simple success check (refine this based on tool usage, keywords etc.)
417
+ success = not any(err in final_output.lower() for err in ["error", "fail", "unable"])
418
+ except Exception as e:
419
+ logging.error(f"Worker-{worker_id} Task {task_id} failed during execution: {e}", exc_info=True)
420
+ final_output = f"Agent execution failed: {e}"
421
+ success = False
422
+ agent_result = {"output": final_output, "action": "error"} # Log error state
423
+
424
+ # --- Record Experience ---
425
+ exec_time = time.time() - start_time
426
+ # Add user feedback later if available (e.g., via API)
427
+ feedback_score = 0.0
428
+ add_experience_db(task_info, agent_result, success, feedback_score, exec_time)
429
+
430
+ # Update task status in DB (optional)
431
+ # update_task_status(task_id, "completed" if success else "failed", final_output)
432
+
433
+ logging.info(f"Worker-{worker_id} finished Task {task_id}. Success: {success}. Time: {exec_time:.2f}s")
434
+
435
+ except redis.exceptions.ConnectionError as e:
436
+ logging.error(f"Worker-{worker_id} Redis connection error: {e}. Retrying in 10s...")
437
+ time.sleep(10)
438
+ except Exception as e:
439
+ logging.error(f"Worker-{worker_id} encountered an unexpected error: {e}", exc_info=True)
440
+ time.sleep(5) # Avoid rapid looping on persistent errors
441
+
442
+ # --- Main Execution / Service Startup ---
443
+ if __name__ == "__main__":
444
+ logging.info("Initializing Agent System...")
445
+
446
+ # --- Start Background Learning Scheduler ---
447
+ scheduler = BackgroundScheduler(daemon=True)
448
+ scheduler.add_job(
449
+ run_learning_cycle,
450
+ trigger=IntervalTrigger(hours=LEARNING_INTERVAL_HOURS),
451
+ id="learning_job",
452
+ name="Fine-tuning Learning Cycle",
453
+ replace_existing=True
454
+ )
455
+ scheduler.start()
456
+ logging.info(f"Background learning scheduler started. Interval: {LEARNING_INTERVAL_HOURS} hours.")
457
+
458
+ # --- Start Agent Workers ---
459
+ num_workers = int(os.getenv("NUM_WORKERS", "2"))
460
+ worker_threads = []
461
+ for i in range(num_workers):
462
+ thread = threading.Thread(target=agent_worker, args=(i+1,), daemon=True)
463
+ thread.start()
464
+ worker_threads.append(thread)
465
+ logging.info(f"{num_workers} Agent worker threads started.")
466
+
467
+ # --- Add Initial Tasks (Example) ---
468
+ add_task_mq("Explain the difference between LoRA and full fine-tuning for LLMs.",
469
+ "Understand AI model optimization techniques.")
470
+ add_task_mq("Write a Python script using pandas to read a CSV file named 'data.csv' and print the first 5 rows.",
471
+ "Develop data processing scripts.")
472
+
473
+ logging.info("Agent system is running. Workers processing tasks from Redis.")
474
+ logging.info("Press Ctrl+C to stop.")
475
+
476
+ try:
477
+ # Keep main thread alive
478
+ while True:
479
+ time.sleep(60)
480
+ # Add health checks or monitoring here if needed
481
+ logging.debug("Main thread alive...")
482
+ except KeyboardInterrupt:
483
+ logging.info("Shutdown signal received...")
484
+ scheduler.shutdown()
485
+ # Workers are daemon threads, they will exit when main thread exits.
486
+ # Implement graceful shutdown for workers if needed (e.g., sending sentinel)
487
+ logging.info("Agent system stopped.")