# src/txagent.py
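"""TxAgent: loads a causal LLM (and, optionally, a SentenceTransformer used
for retrieval) and exposes a chat() interface for the CPS API."""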
import os
import logging
import torch
from typing import Dict, Optional, List
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer
# Configure logging for Hugging Face Spaces
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TxAgent")
class TxAgent:
def __init__(self,
model_name: str,
rag_model_name: str,
tool_files_dict: Optional[Dict] = None,
enable_finish: bool = True,
enable_rag: bool = False,
force_finish: bool = True,
enable_checker: bool = True,
step_rag_num: int = 4,
seed: Optional[int] = None):
# Initialization parameters
self.model_name = model_name
self.rag_model_name = rag_model_name
self.tool_files_dict = tool_files_dict or {}
self.enable_finish = enable_finish
self.enable_rag = enable_rag
self.force_finish = force_finish
self.enable_checker = enable_checker
self.step_rag_num = step_rag_num
        self.seed = seed
        # Apply the seed so that sampled generations are reproducible
        if seed is not None:
            torch.manual_seed(seed)
# Device setup
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Models
self.model = None
self.tokenizer = None
self.rag_model = None
# Prompts
self.chat_prompt = "You are a helpful assistant for user chat."
logger.info(f"Initialized TxAgent with model: {model_name}")
def init_model(self):
"""Initialize all models and components"""
try:
self.load_llm_model()
if self.enable_rag:
self.load_rag_model()
logger.info("Models initialized successfully")
except Exception as e:
logger.error(f"Model initialization failed: {str(e)}")
raise
def load_llm_model(self):
"""Load the main LLM model"""
try:
logger.info(f"Loading LLM model: {self.model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
device_map="auto",
trust_remote_code=True
)
logger.info(f"LLM model loaded on {self.device}")
except Exception as e:
logger.error(f"Failed to load LLM model: {str(e)}")
raise
def load_rag_model(self):
"""Load the RAG model"""
try:
logger.info(f"Loading RAG model: {self.rag_model_name}")
self.rag_model = SentenceTransformer(
self.rag_model_name,
device=str(self.device)
)
logger.info("RAG model loaded successfully")
except Exception as e:
logger.error(f"Failed to load RAG model: {str(e)}")
raise
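    # A minimal retrieval sketch (not part of the original file): it shows one
    # way the SentenceTransformer loaded above could rank candidate documents
    # against a query. The method name `retrieve_context` and its signature
    # are assumptions, not an API defined elsewhere in this repo.
    def retrieve_context(self, query: str, documents: List[str]) -> List[str]:
        """Return the top `step_rag_num` documents most similar to the query."""
        if self.rag_model is None:
            raise RuntimeError("RAG model not loaded; init with enable_rag=True.")
        from sentence_transformers import util  # cosine-similarity helper
        query_emb = self.rag_model.encode(query, convert_to_tensor=True)
        doc_embs = self.rag_model.encode(documents, convert_to_tensor=True)
        scores = util.cos_sim(query_emb, doc_embs)[0]
        k = min(self.step_rag_num, len(documents))
        top_idx = torch.topk(scores, k=k).indices
        return [documents[int(i)] for i in top_idx]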
    def chat(self, message: str, history: Optional[List[Dict]] = None,
             temperature: float = 0.7, max_new_tokens: int = 512) -> str:
        """Handle a chat turn.

        `history` is an ordered list of {"role": ..., "content": ...} dicts.
        """
        try:
            # Guard against use before init_model() has loaded the models
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("Models not initialized; call init_model() before chat().")
            conversation = []
# Enhanced system prompt for better clinical responses
enhanced_prompt = f"{self.chat_prompt} Provide comprehensive, well-structured responses with clear sections. Use markdown formatting for better readability. Always give complete, actionable information."
conversation.append({"role": "system", "content": enhanced_prompt})
# Add history if provided
if history:
for msg in history:
conversation.append({"role": msg["role"], "content": msg["content"]})
# Add current message with context
enhanced_message = f"Please provide a comprehensive answer about: {message}. Structure your response with clear sections and use markdown formatting."
conversation.append({"role": "user", "content": enhanced_message})
# Generate response
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True  # also returns attention_mask for generate()
            ).to(self.device)
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
repetition_penalty=1.1, # Prevent repetitive text
top_p=0.9, # Nucleus sampling for better quality
top_k=50 # Top-k sampling
)
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config
            )
            # Decode only the newly generated tokens, skipping the prompt
            prompt_len = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# Clean and structure the response
cleaned_response = response.strip()
# If response is too short, enhance it
if len(cleaned_response) < 100:
cleaned_response = f"Based on your question about '{message}', here is a comprehensive answer:\n\n{cleaned_response}\n\nThis information should help you understand the topic better. If you need more specific details, please ask follow-up questions."
return cleaned_response
except Exception as e:
logger.error(f"Chat failed: {str(e)}")
raise RuntimeError(f"Chat failed: {str(e)}")
    def cleanup(self):
        """Release model references and free the CUDA cache."""
        try:
            self.model = None
            self.rag_model = None
            self.tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {str(e)}")
            raise

    def __del__(self):
        """Destructor: attempt cleanup, but never raise during interpreter teardown."""
        try:
            self.cleanup()
        except Exception:
            pass
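# Minimal usage sketch. The model identifiers below are illustrative
# assumptions, not values pinned by this file; substitute whatever the
# CPS API actually configures.
if __name__ == "__main__":
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",        # assumed example
        rag_model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed example
        enable_rag=False,
        seed=42,
    )
    agent.init_model()
    try:
        reply = agent.chat("What are the main drug interactions of warfarin?")
        print(reply)
    finally:
        agent.cleanup()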