import logging
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer

# Configure logging for Hugging Face Spaces
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TxAgent")


class TxAgent:
    def __init__(self,
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 enable_finish: bool = True,
                 enable_rag: bool = False,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):
        # Initialization parameters
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        self.seed = seed
        if seed is not None:
            torch.manual_seed(seed)  # make sampling reproducible when a seed is given

        # Device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Models (loaded lazily via init_model)
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        # Prompts
        self.chat_prompt = "You are a helpful assistant for user chat."

        logger.info(f"Initialized TxAgent with model: {model_name}")

    def init_model(self):
        """Initialize all models and components."""
        try:
            self.load_llm_model()
            if self.enable_rag:
                self.load_rag_model()
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Model initialization failed: {e}")
            raise

    def load_llm_model(self):
        """Load the main LLM and its tokenizer."""
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load LLM model: {e}")
            raise

    def load_rag_model(self):
        """Load the sentence-embedding model used for RAG."""
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load RAG model: {e}")
            raise

    def chat(self,
             message: str,
             history: Optional[List[Dict]] = None,
             temperature: float = 0.7,
             max_new_tokens: int = 512) -> str:
        """Handle a single chat turn, optionally seeded with prior history."""
        try:
            conversation = []

            # Enhanced system prompt for better clinical responses
            enhanced_prompt = (
                f"{self.chat_prompt} Provide comprehensive, well-structured "
                "responses with clear sections. Use markdown formatting for "
                "better readability. Always give complete, actionable information."
            )
            conversation.append({"role": "system", "content": enhanced_prompt})

            # Add prior turns if provided
            if history:
                for msg in history:
                    conversation.append({"role": msg["role"], "content": msg["content"]})

            # Add the current message with extra structuring instructions
            enhanced_message = (
                f"Please provide a comprehensive answer about: {message}. "
                "Structure your response with clear sections and use markdown formatting."
            )
            conversation.append({"role": "user", "content": enhanced_message})

            # Tokenize the conversation with the model's chat template
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.device)

            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,  # Discourage repetitive text
                top_p=0.9,               # Nucleus sampling for better quality
                top_k=50                 # Top-k sampling
            )

            attention_mask = torch.ones_like(inputs)  # single un-padded sequence
            outputs = self.model.generate(
                inputs,
                attention_mask=attention_mask,
                generation_config=generation_config
            )

            # Decode only the newly generated tokens
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True
            )
            cleaned_response = response.strip()

            # If the response is very short, wrap it with framing text
            if len(cleaned_response) < 100:
                cleaned_response = (
                    f"Based on your question about '{message}', here is a "
                    f"comprehensive answer:\n\n{cleaned_response}\n\n"
                    "This information should help you understand the topic better. "
                    "If you need more specific details, please ask follow-up questions."
                )

            return cleaned_response
        except Exception as e:
            logger.error(f"Chat failed: {e}")
            raise RuntimeError(f"Chat failed: {e}") from e

    def cleanup(self):
        """Release model references and clear the CUDA cache."""
        try:
            if hasattr(self, 'model'):
                del self.model
            if hasattr(self, 'rag_model'):
                del self.rag_model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {e}")
            raise

    def __del__(self):
        """Destructor: best-effort cleanup; never raise from __del__."""
        try:
            self.cleanup()
        except Exception:
            pass
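# A minimal usage sketch, not part of the original module. The model ids below
# are assumptions for illustration: substitute any chat-tuned causal LM that
# ships a chat template, and any SentenceTransformer id if enable_rag=True.
if __name__ == "__main__":
    agent = TxAgent(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",  # assumed example model id
        rag_model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed example id
        enable_rag=False,
        seed=42,
    )
    agent.init_model()
    try:
        question = "What are common drug interactions with warfarin?"
        reply = agent.chat(question)
        print(reply)

        # Follow-up turns pass prior messages via `history` as role/content dicts
        reply2 = agent.chat(
            "How should dosing be monitored?",
            history=[
                {"role": "user", "content": question},
                {"role": "assistant", "content": reply},
            ],
        )
        print(reply2)
    finally:
        agent.cleanup()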