File size: 6,519 Bytes
b515e8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e659cd
 
 
b515e8c
 
 
 
 
 
2e659cd
 
 
b515e8c
 
 
 
 
 
 
 
 
 
 
 
2e659cd
 
 
 
b515e8c
 
 
 
 
 
 
 
 
2e659cd
 
 
 
 
 
 
 
 
b515e8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import logging
import torch
from typing import Dict, Optional, List, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer

# Configure logging for Hugging Face Spaces.
# NOTE: basicConfig runs at import time and only has an effect if the root
# logger has no handlers yet — a later basicConfig call elsewhere is a no-op.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by the TxAgent class below.
logger = logging.getLogger("TxAgent")

class TxAgent:
    """Wrapper around a causal LLM (plus an optional sentence-transformer
    RAG encoder) exposing a simple ``chat`` interface.

    Construction is cheap: the heavy models are loaded lazily by
    :meth:`init_model`, not in ``__init__``.
    """

    def __init__(self,
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 enable_finish: bool = True,
                 enable_rag: bool = False,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):
        """Store configuration; no models are loaded here.

        Args:
            model_name: Hugging Face id/path of the causal LLM.
            rag_model_name: Hugging Face id/path of the sentence-transformer
                used for retrieval (only loaded when ``enable_rag`` is True).
            tool_files_dict: Optional mapping of tool files; defaults to {}.
            enable_finish / enable_rag / force_finish / enable_checker:
                Feature flags stored for use by callers; only ``enable_rag``
                is consulted in this class (in :meth:`init_model`).
            step_rag_num: Number of RAG steps (stored, not used here).
            seed: Optional RNG seed (stored, not used here).
        """
        # Initialization parameters
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        self.seed = seed

        # Device setup: prefer GPU when available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Models are loaded lazily; None means "not loaded yet".
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        # Base system prompt for chat conversations.
        self.chat_prompt = "You are a helpful assistant for user chat."

        logger.info(f"Initialized TxAgent with model: {model_name}")

    def init_model(self):
        """Load the LLM (and the RAG encoder when ``enable_rag`` is set).

        Raises:
            Exception: re-raises whatever the underlying loaders raise.
        """
        try:
            self.load_llm_model()
            if self.enable_rag:
                self.load_rag_model()
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise

    def load_llm_model(self):
        """Load tokenizer and causal LM defined by ``self.model_name``.

        Uses fp16 on CUDA, fp32 on CPU; ``device_map="auto"`` lets
        transformers/accelerate place the weights.
        """
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load LLM model: {str(e)}")
            raise

    def load_rag_model(self):
        """Load the sentence-transformer encoder used for retrieval."""
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load RAG model: {str(e)}")
            raise

    def chat(self, message: str, history: Optional[List[Dict]] = None,
             temperature: float = 0.7, max_new_tokens: int = 512) -> str:
        """Generate a chat response for ``message``.

        Args:
            message: The user's question.
            history: Optional prior turns as dicts with "role"/"content" keys.
            temperature: Sampling temperature.
            max_new_tokens: Generation budget.

        Returns:
            The decoded (and lightly post-processed) model response.

        Raises:
            RuntimeError: on any failure, including uninitialized models.
        """
        try:
            # Fail fast with a clear message instead of an AttributeError
            # if init_model() was never called.
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("Models not initialized - call init_model() first")

            conversation = []

            # Enhanced system prompt for better clinical responses
            enhanced_prompt = f"{self.chat_prompt} Provide comprehensive, well-structured responses with clear sections. Use markdown formatting for better readability. Always give complete, actionable information."
            conversation.append({"role": "system", "content": enhanced_prompt})

            # Add history if provided
            if history:
                for msg in history:
                    conversation.append({"role": msg["role"], "content": msg["content"]})

            # Add current message with context
            enhanced_message = f"Please provide a comprehensive answer about: {message}. Structure your response with clear sections and use markdown formatting."
            conversation.append({"role": "user", "content": enhanced_message})

            # Tokenize via the model's chat template.
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.device)

            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,  # Prevent repetitive text
                top_p=0.9,  # Nucleus sampling for better quality
                top_k=50    # Top-k sampling
            )

            # inference_mode avoids building an autograd graph; the explicit
            # attention_mask (all ones — single unpadded sequence) silences
            # the transformers pad-token inference warning.
            with torch.inference_mode():
                outputs = self.model.generate(
                    inputs,
                    attention_mask=torch.ones_like(inputs),
                    generation_config=generation_config
                )

            # Decode only the newly generated tokens (skip the prompt).
            response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)

            # Clean and structure the response
            cleaned_response = response.strip()

            # If response is too short, enhance it
            if len(cleaned_response) < 100:
                cleaned_response = f"Based on your question about '{message}', here is a comprehensive answer:\n\n{cleaned_response}\n\nThis information should help you understand the topic better. If you need more specific details, please ask follow-up questions."

            return cleaned_response

        except Exception as e:
            logger.error(f"Chat failed: {str(e)}")
            raise RuntimeError(f"Chat failed: {str(e)}")

    def cleanup(self):
        """Release model references and free CUDA cache.

        Idempotent: attributes are rebound to ``None`` (not ``del``-eted) so
        repeated calls and later ``is None`` checks remain safe.
        """
        try:
            self.model = None
            self.tokenizer = None
            self.rag_model = None
            # empty_cache is only meaningful (and only safe to call
            # unconditionally) when CUDA is present.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {str(e)}")
            raise

    def __del__(self):
        """Best-effort cleanup on garbage collection.

        Never propagates: exceptions escaping __del__ are unsafe during
        interpreter shutdown and only produce "Exception ignored" noise.
        """
        try:
            self.cleanup()
        except Exception:
            pass