import logging
import torch
from typing import Dict, List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer

# Configure logging for Hugging Face Spaces
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TxAgent")

class TxAgent:
    def __init__(self,
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 enable_finish: bool = True,
                 enable_rag: bool = False,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):
        # Initialization parameters
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        self.seed = seed

        # Seed torch for reproducible sampling when a seed is provided
        if seed is not None:
            torch.manual_seed(seed)

        # Device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Models are loaded lazily via init_model()
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        # Prompts
        self.chat_prompt = "You are a helpful assistant for user chat."

        logger.info(f"Initialized TxAgent with model: {model_name}")
    def init_model(self):
        """Initialize all models and components."""
        try:
            self.load_llm_model()
            if self.enable_rag:
                self.load_rag_model()
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise
    def load_llm_model(self):
        """Load the main LLM and its tokenizer."""
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load LLM model: {str(e)}")
            raise
    def load_rag_model(self):
        """Load the RAG embedding model."""
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load RAG model: {str(e)}")
            raise
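
    # Illustrative sketch, not in the original file: the class loads a
    # SentenceTransformer and exposes step_rag_num, but no retrieval step is
    # shown here. Below is one plausible shape for that step -- embed the
    # query and candidate documents, then keep the step_rag_num nearest
    # matches by cosine similarity. The method name and signature are
    # assumptions for illustration only.
    def retrieve_context(self, query: str, documents: List[str]) -> List[str]:
        """Hypothetical helper: rank documents against a query with the RAG model."""
        if self.rag_model is None:
            raise RuntimeError("RAG model not loaded; call init_model() first")
        from sentence_transformers import util
        query_emb = self.rag_model.encode(query, convert_to_tensor=True)
        doc_embs = self.rag_model.encode(documents, convert_to_tensor=True)
        # Cosine-similarity scores of the query against every document
        scores = util.cos_sim(query_emb, doc_embs)[0]
        k = min(self.step_rag_num, len(documents))
        top_idx = torch.topk(scores, k=k).indices.tolist()
        return [documents[i] for i in top_idx]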
    def chat(self, message: str, history: Optional[List[Dict]] = None,
             temperature: float = 0.7, max_new_tokens: int = 512) -> str:
        """Handle chat conversations."""
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded; call init_model() first")
        try:
            # Build the conversation: system prompt, prior turns, current message
            conversation = [{"role": "system", "content": self.chat_prompt}]
            if history:
                for msg in history:
                    conversation.append({"role": msg["role"], "content": msg["content"]})
            conversation.append({"role": "user", "content": message})

            # Tokenize with the model's chat template; return_dict=True also
            # yields the attention mask that generate() expects
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True
            ).to(self.device)

            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    generation_config=generation_config
                )

            # Decode only the newly generated tokens, skipping the prompt
            prompt_len = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            logger.error(f"Chat failed: {str(e)}")
            raise RuntimeError(f"Chat failed: {str(e)}") from e
    def cleanup(self):
        """Release model references and free GPU memory."""
        try:
            self.model = None
            self.tokenizer = None
            self.rag_model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {str(e)}")
            raise
    def __del__(self):
        """Destructor; must never raise during interpreter shutdown."""
        try:
            self.cleanup()
        except Exception:
            pass
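

# Usage sketch (not part of the original file). The model names below are
# placeholders chosen for illustration, not this Space's actual configuration.
if __name__ == "__main__":
    agent = TxAgent(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",                  # placeholder
        rag_model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder
        enable_rag=False
    )
    agent.init_model()
    print(agent.chat("Hello! What can you do?"))
    agent.cleanup()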