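"""Llama-3.2 model manager for the deepforest-agent pipeline.

Loads Llama-3.2-3B-Instruct on demand for blocking or streaming text generation
and aggressively releases GPU memory after every inference call.
"""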
import gc
from typing import Tuple, Dict, Any, Optional, List, Generator
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer
from threading import Thread
from deepforest_agent.conf.config import Config
class Llama32ModelManager:
"""
Manages Llama-3.2-3B-Instruct model instances for text generation tasks.
Attributes:
model_id (str): HuggingFace model identifier
load_count (int): Number of times model has been loaded
"""
def __init__(self, model_id: str = Config.AGENT_MODELS["ecology_analysis"]):
"""
Initialize the Llama-3.2-3B model manager.
Args:
model_id (str, optional): HuggingFace model identifier.
                Defaults to Config.AGENT_MODELS["ecology_analysis"].
"""
self.model_id = model_id
self.load_count = 0
def generate_response(
self,
messages: List[Dict[str, str]],
max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
tools: Optional[List[Dict[str, Any]]] = None
) -> str:
"""
Generate text response using Llama-3.2-3B-Instruct.
Args:
messages: List of message dictionaries with 'role' and 'content'
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
tools (Optional[List[Dict[str, Any]]]): List of tools (not used for Llama)
Returns:
str: Generated response text
Raises:
Exception: If generation fails due to model issues, memory, or other errors
"""
print(f"Loading Llama-3.2-3B for inference #{self.load_count + 1}")
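        # A fresh model instance is loaded for every call and released in the
        # finally block below, trading load latency for a low idle GPU footprint.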
model, tokenizer = self._load_model()
self.load_count += 1
try:
# Llama uses standard chat template without xml_tools
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
            generated_ids = model.generate(
                model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
generated_ids = [
output_ids[len(input_ids):]
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
except Exception as e:
print(f"Error during Llama-3.2-3B text generation: {e}")
raise e
finally:
            print("Releasing Llama-3.2-3B GPU memory after inference")
if 'model' in locals():
if hasattr(model, 'cpu'):
model.cpu()
del model
if 'tokenizer' in locals():
del tokenizer
if 'model_inputs' in locals():
del model_inputs
if 'generated_ids' in locals():
del generated_ids
# Multiple garbage collection passes
for _ in range(3):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()
                try:
                    # Private torch API; disables memory-history recording if it
                    # was enabled. Not available in every torch version.
                    torch.cuda.memory._record_memory_history(enabled=None)
                except Exception:
                    pass
print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
def generate_response_streaming(
self,
messages: List[Dict[str, str]],
max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
) -> Generator[Dict[str, Any], None, None]:
"""
Generate text response with streaming (token by token).
Args:
messages: List of message dictionaries with 'role' and 'content'
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
Yields:
Dict[str, Any]: Dictionary containing:
- token: The generated token/text chunk
- is_complete: Whether generation is finished
Raises:
Exception: If generation fails due to model issues, memory, or other errors
"""
print(f"Loading Llama-3.2-3B for streaming inference #{self.load_count + 1}")
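        # Same load/release-per-call pattern as generate_response; generation runs
        # in a background thread while TextIteratorStreamer yields decoded chunks.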
model, tokenizer = self._load_model()
self.load_count += 1
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(
tokenizer,
timeout=60.0,
skip_prompt=True,
skip_special_tokens=True
)
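            # timeout=60.0: iteration over the streamer will raise if no new text
            # arrives within 60 seconds (e.g. if generation stalls).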
            generation_kwargs = {
                "input_ids": model_inputs.input_ids,
                "attention_mask": model_inputs.attention_mask,
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "pad_token_id": tokenizer.eos_token_id,
                "streamer": streamer
            }
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
yield {"token": new_text, "is_complete": False}
thread.join()
yield {"token": "", "is_complete": True}
except Exception as e:
print(f"Error during Llama-3.2-3B streaming generation: {e}")
yield {"token": f"[Error: {str(e)}]", "is_complete": True}
finally:
            print("Releasing Llama-3.2-3B GPU memory after inference")
if 'model' in locals():
if hasattr(model, 'cpu'):
model.cpu()
del model
if 'tokenizer' in locals():
del tokenizer
if 'model_inputs' in locals():
del model_inputs
            # 'generated_ids' is never created in the streaming path; drop the
            # streamer reference instead.
            if 'streamer' in locals():
                del streamer
# Multiple garbage collection passes
for _ in range(3):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()
                try:
                    # Private torch API; disables memory-history recording if it
                    # was enabled. Not available in every torch version.
                    torch.cuda.memory._record_memory_history(enabled=None)
                except Exception:
                    pass
print(f"GPU memory after aggressive cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, {torch.cuda.memory_reserved() / 1024**3:.2f} GB cached")
def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
"""
Private method for model and tokenizer loading.
Returns:
Tuple[AutoModelForCausalLM, AutoTokenizer]: Loaded model and tokenizer
Raises:
Exception: If model loading fails due to network, memory, or other issues
"""
try:
tokenizer = AutoTokenizer.from_pretrained(
self.model_id,
trust_remote_code=True
)
            # device_map="auto" places weights across available devices via
            # accelerate; torch_dtype="auto" uses the dtype from the checkpoint config
model = AutoModelForCausalLM.from_pretrained(
self.model_id,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True
)
return model, tokenizer
except Exception as e:
print(f"Error loading Llama-3.2-3B model: {e}")
raise e
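

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: Config.AGENT_MODELS / AGENT_CONFIGS are
    # populated and the model weights are accessible, e.g. via a Hugging Face
    # token with access to the checkpoint).
    manager = Llama32ModelManager()
    demo_messages = [
        {"role": "system", "content": "You are a forest ecology assistant."},
        {"role": "user", "content": "Why does canopy cover matter for biodiversity?"},
    ]

    # Blocking generation: returns the full response as a string.
    print(manager.generate_response(demo_messages))

    # Streaming generation: chunks arrive as they are produced.
    for chunk in manager.generate_response_streaming(demo_messages):
        if not chunk["is_complete"]:
            print(chunk["token"], end="", flush=True)
    print()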