# NOTE: non-code artifact removed — "Spaces: Sleeping Sleeping" is Hugging Face
# Spaces UI status text captured alongside the source, not part of this module.
| import re | |
| import time | |
| from litellm import completion, completion_cost | |
| from abc import ABC, abstractmethod | |
| from PyCharacterAI import get_client | |
| from PyCharacterAI.exceptions import SessionClosedError | |
| import asyncio | |
| import gc | |
| import torch | |
| from HumanSimulacra.multi_agent_cognitive_mechanism import Top_agent | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import logging | |
class Agent(ABC):
    """Abstract base class for conversational agents that keep a chat transcript.

    Maintains two parallel transcripts:
      * ``history`` — plain ``{"role", "content"}`` messages, suitable for
        chat-completion style APIs.
      * ``history_with_reasoning`` — the same turns, each optionally augmented
        with a ``"reasoning_content"`` field.
    Both start with the system message built from ``description``.
    """

    def __init__(self, model, name, description):
        """Initialize the agent and seed both transcripts.

        Args:
            model: Identifier of the backing LLM; may be ``None`` when the
                subclass manages its own client.
            name: Display name of the agent.
            description: System-prompt text stored as the initial "system"
                message (may be ``None``).
        """
        self.name = name
        self.description = description
        self.model = model
        self.max_retry = 3   # retry budget available to subclasses' API calls
        self.cost = 0.0      # accumulated completion cost

        # Seed both transcripts with the same system message (independent
        # dicts so later mutation of one transcript cannot leak into the other).
        system_message = {
            "role": "system",
            "content": self.description
        }
        self.history = [system_message]
        self.history_with_reasoning = [dict(system_message)]

    def store_chat(self, role, message, reasoning=None):
        """Append one turn to both transcripts.

        Args:
            role: Chat role, e.g. "user" or "assistant".
            message: The turn's text content.
            reasoning: Optional reasoning trace; when truthy it is recorded
                only in ``history_with_reasoning`` as ``reasoning_content``.
        """
        chat = {
            "role": role,
            "content": message
        }
        self.history.append(chat)
        if reasoning:
            self.history_with_reasoning.append({
                "role": role,
                "content": message,
                "reasoning_content": reasoning
            })
        else:
            # Copy so the two transcripts never share a mutable dict.
            self.history_with_reasoning.append(chat.copy())

    @abstractmethod
    def chat(self, prompt):
        """Send ``prompt`` to the agent and return its reply.

        Subclasses must implement this. (The module already imports
        ``abstractmethod``; the original left this as a silent no-op, which
        would return ``None`` if a subclass forgot to override it.)
        """
        raise NotImplementedError
class PersonaAgent(Agent):
    """Persona-playing agent that dispatches ``chat()`` to one of four back-ends.

    ``baseline_name`` (required kwarg) selects the back-end:
      * ``"characterai"``     – remote Character.AI session via an async client
      * ``"human_simulacra"`` – HumanSimulacra ``Top_agent`` (keeps its own history)
      * ``"opencharacter"``   – local HuggingFace causal LM loaded onto CUDA
      * ``"human_interview"`` – a human typing responses on stdin
    """

    def __init__(self, **kwargs):
        # NOTE: argument validation uses assert, so these checks vanish
        # under `python -O`.
        assert kwargs.get('baseline_name') in ["characterai", "human_simulacra", "opencharacter", "human_interview"], "Invalid baseline name"
        self.type = kwargs.get('baseline_name')
        if self.type == "characterai":
            assert 'character_id' in kwargs, "Character AI requires character_id parameter"
            assert 'user_id' in kwargs, "Character AI requires user_id parameter"
            super().__init__(model=None, name=kwargs.get('name'), description=None)
            self.char_id = kwargs['character_id']
            self.user_id = kwargs['user_id']
            try:
                # Build the async client and open a chat session;
                # self.chat_id is assigned inside _setup_client_and_chat.
                asyncio.run(self._setup_client_and_chat(kwargs['user_id'], kwargs['character_id']))
                assert self.chat_id is not None, "Chat ID must be set after setup"
            except SessionClosedError as e:
                logging.error(f"Session closed error: {e}")
                self.client_or_model = None
                self.chat_id = None
            # NOTE(review): base __init__ already ran above — this second call
            # re-resets the transcript/cost state; confirm it is intentional.
            super().__init__(model=None, name=kwargs.get('name'), description=None)  # no model for characterai, it uses its own client
        elif self.type == "human_simulacra":
            assert 'name' in kwargs, "Human Simulacra requires name parameter"
            super().__init__(model=kwargs.get('model', "gpt-4.1-mini-2025-04-14"), name=kwargs.get('name'), description=None)  # default model
            self.client_or_model = Top_agent(character_name=kwargs['name'])  ### has its own chat history
        elif self.type == "opencharacter":  # OpenCharacter
            # NOTE(review): this branch never calls super().__init__(); only
            # self.history and self.name are set — confirm the base-class
            # attributes (cost, max_retry, history_with_reasoning) are unused here.
            assert 'model_path' in kwargs, "OpenCharacter requires (path-like, either huggingface repo OR local path) model parameter"
            assert 'persona' in kwargs, "OpenCharacter requires persona parameter"
            assert 'profile' in kwargs, "OpenCharacter requires profile parameter"
            # Requires a CUDA device (.to("cuda")); 4-bit loading is opt-in.
            self.client_or_model = AutoModelForCausalLM.from_pretrained(
                kwargs['model_path'],
                load_in_4bit=kwargs.get('load_in_4bit', False),  # default is False
                # device_map={"":0}
            ).to("cuda").eval()
            self.tokenizer = AutoTokenizer.from_pretrained(kwargs['model_path'])
            # Seed the transcript with a system prompt assembled from the
            # persona and profile strings.
            self.history = [{
                "role": "system",
                "content": ("You are an AI character with the following Persona.\n\n"
                            f"# Persona\n{kwargs['persona']}\n\n"
                            f'# Character Profile\n{kwargs["profile"]}\n\n'
                            "Please stay in character, be helpful and harmless."
                            ),
            }]
            # Fall back to parsing a "Name: ..." line out of the profile when
            # no explicit name was given (raises AttributeError if absent).
            self.name = kwargs.get('name', None) or re.search(r'^Name:\s*(.+)$', kwargs['profile'], flags=re.MULTILINE).group(1).strip()
        else:  # human_interview
            # NOTE(review): no super().__init__() here either — the agent only
            # carries a name; replies come from stdin in chat().
            self.name = kwargs.get('name', None)

    def chat(self, message: str):
        """Send *message* to the selected back-end and return its text reply."""
        if self.type == "characterai":
            # Add a small delay before sending message
            time.sleep(0.5)  # 500ms delay
            response = asyncio.run(self.client_or_model.chat.send_message(character_id=self.char_id, chat_id=self.chat_id, text=message))
            return response.get_primary_candidate().text
        elif self.type == "human_simulacra":
            # Top_agent maintains its own conversation history internally.
            return self.client_or_model.send_message(message)
        elif self.type == "opencharacter":  # OpenCharacter
            self.history.append({
                "role": "user",
                "content": message
            })
            # Re-tokenize and prune the transcript until the prompt fits the
            # model's context window (max_position_embeddings).
            while True:
                input_ids = self.tokenizer.apply_chat_template(
                    self.history,
                    tokenize=True,
                    return_tensors="pt",
                    add_generation_prompt=True,
                ).to(self.client_or_model.device)
                if input_ids.shape[1] <= self.client_or_model.config.max_position_embeddings:
                    break
                # drop oldest assistant-user pair but keep system prompt
                if len(self.history) > 3:
                    self.history = [self.history[0]] + self.history[3:]
                else:
                    # still too long even after pruning – fallback
                    # NOTE(review): if system prompt + last two turns still
                    # exceed the window, this loop never terminates — confirm
                    # upstream message sizes are bounded.
                    self.history = [self.history[0]] + self.history[-2:]
            with torch.no_grad():
                output_ids = self.client_or_model.generate(
                    input_ids,
                    max_new_tokens=1024,  # following the original config
                    do_sample=True,
                    temperature=0.9,
                    top_p=0.9,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            # Decode only the newly generated tokens (prompt portion stripped).
            response = self.tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
            self.history.append({
                "role": "assistant",
                "content": response
            })
            return response
        else:  # human_interview
            response = input(f"Your Response: ")
            return response

    async def _setup_client_and_chat(self, user_id: str, char_id: str):
        """Create the Character.AI client and a new chat; sets ``self.chat_id``.

        On any failure the error is logged and the session is closed; no
        exception is re-raised, so the caller must check ``self.chat_id``.
        """
        try:
            self.client_or_model = await get_client(user_id)
            # Fetch account info and create the chat concurrently.
            me_task = asyncio.create_task(self.client_or_model.account.fetch_me())
            chat_task = asyncio.create_task(self.client_or_model.chat.create_chat(char_id))
            me, (chat, greeting_message) = await asyncio.gather(me_task, chat_task)
            self.chat_id = chat.chat_id
        except Exception as e:
            logging.error(f"Failed to set up the cai client: {e}")
            await self.client_or_model.close_session()

    def clear_model(self):
        """Release the local GPU model and reclaim CUDA memory (opencharacter only)."""
        if self.type == "opencharacter":  # no action needed for other types
            del self.client_or_model
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    async def close(self):
        """Close the Character.AI session, if one is open (characterai only)."""
        if self.type == "characterai":
            if hasattr(self, 'client_or_model') and self.client_or_model is not None:
                await self.client_or_model.close_session()
                self.client_or_model = None
                self.chat_id = None