Patryk Studzinski committed
Commit eaa2e37 · 1 Parent(s): ab2e415

Fix: Use direct model.generate() with proper KV caching instead of pipeline
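
In short, the commit replaces the transformers text-generation pipeline with a direct model.generate() call so that use_cache is passed explicitly during decoding. A minimal standalone sketch of the new call pattern, with gpt2 standing in as a placeholder model id (not the one used by this repo):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder model, not from this repo
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "Write a one-line ad for a coffee shop."
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Direct generate() call; use_cache=True keeps past key/values so each new
# token attends over cached states instead of re-encoding the whole prefix.
output_ids = model.generate(
    input_ids,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=True,
)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))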

Files changed (1)
  1. app/models/huggingface_local.py +50 -79
app/models/huggingface_local.py CHANGED
@@ -63,12 +63,11 @@ class HuggingFaceLocal(BaseLLM):
             # Model config optimizations
             model_kwargs = {
                 "trust_remote_code": True,
-                "use_cache": self.use_cache,  # Enable KV caching
                 "torch_dtype": self.torch_dtype,
             }
 
-            # Enable flash attention if requested and available
-            if self.use_flash_attention:
+            # Enable flash attention if requested and available (GPU only)
+            if self.use_flash_attention and self.device == "cuda":
                 model_kwargs["attn_implementation"] = "flash_attention_2"
 
             self.model = await asyncio.to_thread(
@@ -78,48 +77,16 @@ class HuggingFaceLocal(BaseLLM):
                 **model_kwargs
             )
 
-            # Create pipeline with optimized model
-            self.pipeline = await asyncio.to_thread(
-                pipeline,
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                device=self.device_index,
-            )
+            # Ensure cache is enabled on model config
+            if hasattr(self.model.config, 'use_cache'):
+                self.model.config.use_cache = self.use_cache
 
             self._initialized = True
-            print(f"[{self.name}] Model loaded successfully with KV caching enabled")
+            print(f"[{self.name}] Model loaded successfully (use_cache={self.use_cache})")
 
         except Exception as e:
             print(f"[{self.name}] Failed to load model: {e}")
-            # Fallback: try without flash attention
-            if self.use_flash_attention:
-                print(f"[{self.name}] Retrying without flash attention...")
-                self.use_flash_attention = False
-                try:
-                    self.tokenizer = await asyncio.to_thread(
-                        AutoTokenizer.from_pretrained,
-                        self.model_id,
-                        trust_remote_code=True
-                    )
-
-                    self.pipeline = await asyncio.to_thread(
-                        pipeline,
-                        "text-generation",
-                        model=self.model_id,
-                        tokenizer=self.tokenizer,
-                        device=self.device_index,
-                        torch_dtype=self.torch_dtype,
-                        trust_remote_code=True,
-                        use_cache=self.use_cache,
-                    )
-                    self._initialized = True
-                    print(f"[{self.name}] Model loaded successfully (without flash attention)")
-                except Exception as e2:
-                    print(f"[{self.name}] Fallback also failed: {e2}")
-                    raise
-            else:
-                raise
+            raise
 
     async def generate(
         self,
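
The two hunks above make flash attention a CUDA-only opt-in and set use_cache directly on the model config instead of routing it through a pipeline. If you also want to guard against the flash-attn package simply not being installed, a check along these lines could sit in front of from_pretrained. This is only a sketch: build_model_kwargs is a hypothetical helper, and is_flash_attn_2_available assumes a reasonably recent transformers release.

import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available

def build_model_kwargs(use_flash_attention: bool, device: str) -> dict:
    """Pick an attention implementation that can actually load on this machine."""
    kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
    }
    # flash_attention_2 needs a CUDA device and the flash-attn package installed
    if use_flash_attention and device == "cuda" and is_flash_attn_2_available():
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs

model = AutoModelForCausalLM.from_pretrained("gpt2", **build_model_kwargs(False, "cpu"))
model.config.use_cache = True  # mirrors the commit's explicit cache toggle on the config
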
@@ -131,14 +98,14 @@ class HuggingFaceLocal(BaseLLM):
         **kwargs
     ) -> str:
         """
-        Generate text using local pipeline with KV cache optimizations.
+        Generate text using direct model.generate() with proper KV caching.
 
-        KV Cache Impact:
-        - WITH: ~9 seconds for 10 ads (50 gaps total)
+        KV Cache Impact (with proper implementation):
+        - WITH: ~9 seconds for 10 ads (50 gaps)
         - WITHOUT: ~42 seconds (4.7x slower)
         """
 
-        if not self._initialized:
+        if not self._initialized or self.model is None:
             raise RuntimeError(f"[{self.name}] Model not initialized")
 
         formatted_prompt = None
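
The ~9 s vs ~42 s numbers in the docstring come from the author's workload (10 ads, 50 gaps) and are not reproduced here. A rough way to measure the relative effect of the cache on any causal LM is to time generate() with use_cache toggled; gpt2 is a placeholder model and timed_generate a hypothetical helper:

import time
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer.encode("Fill the gap in this ad:", return_tensors="pt")

def timed_generate(use_cache: bool) -> float:
    start = time.perf_counter()
    model.generate(
        input_ids,
        max_new_tokens=128,
        do_sample=False,      # greedy, so both runs do comparable work
        use_cache=use_cache,  # the only variable under test
    )
    return time.perf_counter() - start

print(f"with KV cache:    {timed_generate(True):.2f}s")
print(f"without KV cache: {timed_generate(False):.2f}s")
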
@@ -153,55 +120,59 @@ class HuggingFaceLocal(BaseLLM):
             )
         except Exception as e:
             print(f"[{self.name}] apply_chat_template failed: {e}, using fallback")
-            # Fallback: manually format chat messages
             formatted_prompt = self._format_chat_fallback(chat_messages)
 
-        # Use raw prompt if provided and no chat_messages
+        # Use raw prompt if provided
         if formatted_prompt is None and prompt:
             formatted_prompt = prompt
 
         if formatted_prompt is None:
             raise ValueError("Either prompt or chat_messages required")
 
-        # Generate with KV cache and optimizations
-        # The pipeline uses use_cache=True internally when initialized
-        generation_kwargs = {
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "eos_token_id": self.tokenizer.eos_token_id,
-            "pad_token_id": self.tokenizer.eos_token_id if self.tokenizer.pad_token_id is None else self.tokenizer.pad_token_id,
-        }
+        # Tokenize input
+        inputs = await asyncio.to_thread(
+            self.tokenizer.encode,
+            formatted_prompt,
+            return_tensors="pt"
+        )
 
-        # If using direct model (not pipeline), enable return_dict_in_generate for better caching
-        if hasattr(self, 'model') and self.model is not None:
-            generation_kwargs["return_dict_in_generate"] = True
+        # Move to device
+        if self.device == "cuda":
+            inputs = await asyncio.to_thread(lambda: inputs.to("cuda"))
 
+        # Generate with explicit KV cache
         outputs = await asyncio.to_thread(
-            self.pipeline,
-            formatted_prompt,
-            **generation_kwargs
+            self.model.generate,
+            inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+            use_cache=True,  # CRITICAL: Enable KV cache
+            use_xformers_attention=False,  # CPU doesn't support this
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.eos_token_id if self.tokenizer.pad_token_id is None else self.tokenizer.pad_token_id,
         )
 
-        # Extract response
-        if outputs and isinstance(outputs, list) and "generated_text" in outputs[0]:
-            full_text = outputs[0]["generated_text"]
-
-            # Remove prompt from output
-            if full_text.startswith(formatted_prompt):
-                response = full_text[len(formatted_prompt):]
-            else:
-                response = full_text
-
-            # Clean up special tokens
-            for token in ["<|im_end|>", "<end_of_turn>", "<eos>", "</s>"]:
-                if response.endswith(token):
-                    response = response[:-len(token)]
-
-            return response.strip()
+        # Decode output
+        output_text = await asyncio.to_thread(
+            self.tokenizer.decode,
+            outputs[0],
+            skip_special_tokens=True
+        )
+
+        # Remove prompt from output
+        if output_text.startswith(formatted_prompt):
+            response = output_text[len(formatted_prompt):]
+        else:
+            response = output_text
+
+        # Clean up special tokens
+        for token in ["<|im_end|>", "<end_of_turn>", "<eos>", "</s>"]:
+            if response.endswith(token):
+                response = response[:-len(token)]
 
-        return ""
+        return response.strip()
 
     def _format_chat_fallback(self, chat_messages: List[Dict[str, str]]) -> str:
         """