File size: 8,857 Bytes
4f24301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import gc
from typing import Tuple, Dict, Any, Optional, List, Generator
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer
from threading import Thread

from deepforest_agent.conf.config import Config


class Llama32ModelManager:
    """
    Manages Llama-3.2-3B-Instruct model instances for text generation tasks.

    The model is loaded from scratch for every inference call and torn down
    in a ``finally`` block, so GPU memory is not held between calls.

    Attributes:
        model_id (str): HuggingFace model identifier
        load_count (int): Number of times model has been loaded
    """

    def __init__(self, model_id: str = Config.AGENT_MODELS["ecology_analysis"]):
        """
        Initialize the Llama-3.2-3B model manager.

        Args:
            model_id (str, optional): HuggingFace model identifier.
                                    Defaults to the "ecology_analysis" entry of
                                    Config.AGENT_MODELS.
        """
        self.model_id = model_id
        self.load_count = 0

    def generate_response(
        self, 
        messages: List[Dict[str, str]],
        max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
        temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
        top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """
        Generate text response using Llama-3.2-3B-Instruct.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling
            tools (Optional[List[Dict[str, Any]]]): List of tools (not used for Llama)

        Returns:
            str: Generated response text

        Raises:
            Exception: If generation fails due to model issues, memory, or other errors
        """
        print(f"Loading Llama-3.2-3B for inference #{self.load_count + 1}")

        model, tokenizer = self._load_model()
        self.load_count += 1

        # Pre-bind so the finally block can delete these unconditionally,
        # even when an exception fires before they are assigned.
        model_inputs = None
        generated_ids = None

        try:
            # Llama uses standard chat template without xml_tools
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

            generated_ids = model.generate(
                model_inputs.input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

            # Strip the prompt tokens so only the newly generated tail remains.
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]

            return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        except Exception as e:
            print(f"Error during Llama-3.2-3B text generation: {e}")
            raise  # bare raise preserves the original traceback

        finally:
            print("Releasing Llama-3.2-3B GPU memory after inference")
            # Move weights off the GPU before dropping the last references.
            if hasattr(model, 'cpu'):
                model.cpu()
            del model, tokenizer, model_inputs, generated_ids
            self._release_gpu_memory()

    def generate_response_streaming(
        self, 
        messages: List[Dict[str, str]],
        max_new_tokens: int = Config.AGENT_CONFIGS["ecology_analysis"]["max_new_tokens"],
        temperature: float = Config.AGENT_CONFIGS["ecology_analysis"]["temperature"],
        top_p: float = Config.AGENT_CONFIGS["ecology_analysis"]["top_p"],
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Generate text response with streaming (token by token).

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling

        Yields:
            Dict[str, Any]: Dictionary containing:
                - token: The generated token/text chunk
                - is_complete: Whether generation is finished

        Raises:
            Exception: If generation fails due to model issues, memory, or other errors
        """
        print(f"Loading Llama-3.2-3B for streaming inference #{self.load_count + 1}")

        model, tokenizer = self._load_model()
        self.load_count += 1

        # Pre-bind so cleanup in finally is unconditional.
        model_inputs = None
        thread = None

        try:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

            streamer = TextIteratorStreamer(
                tokenizer,
                timeout=60.0,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = {
                "input_ids": model_inputs.input_ids,
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "pad_token_id": tokenizer.eos_token_id,
                "streamer": streamer
            }

            # daemon=True so an abandoned generator cannot keep the process
            # alive on shutdown.
            thread = Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
            thread.start()

            for new_text in streamer:
                yield {"token": new_text, "is_complete": False}

            thread.join()
            yield {"token": "", "is_complete": True}

        except Exception as e:
            print(f"Error during Llama-3.2-3B streaming generation: {e}")
            yield {"token": f"[Error: {str(e)}]", "is_complete": True}

        finally:
            # Ensure the background generation thread has stopped before the
            # model is torn down; bounded join so cleanup cannot hang forever.
            if thread is not None and thread.is_alive():
                thread.join(timeout=60.0)
            print("Releasing Llama-3.2-3B GPU memory after inference")
            if hasattr(model, 'cpu'):
                model.cpu()
            del model, tokenizer, model_inputs
            self._release_gpu_memory()

    def _release_gpu_memory(self) -> None:
        """
        Aggressively reclaim host and GPU memory after an inference.

        Runs several garbage-collection passes to break reference cycles,
        then (when CUDA is available) flushes the allocator caches and logs
        the remaining GPU memory usage.
        """
        # Multiple garbage collection passes
        for _ in range(3):
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()
            try:
                # Private torch API — stop memory-history recording if active;
                # tolerate it being absent or failing on this torch version.
                torch.cuda.memory._record_memory_history(enabled=None)
            except Exception:
                pass
            print(
                f"GPU memory after aggressive cleanup: "
                f"{torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated, "
                f"{torch.cuda.memory_reserved() / 1024**3:.2f} GB cached"
            )

    def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
        """
        Private method for model and tokenizer loading.

        Returns:
            Tuple[AutoModelForCausalLM, AutoTokenizer]: Loaded model and tokenizer

        Raises:
            Exception: If model loading fails due to network, memory, or other issues
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True
            )

            # Llama models may need specific configurations
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype="auto",
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            return model, tokenizer

        except Exception as e:
            print(f"Error loading Llama-3.2-3B model: {e}")
            raise