Patryk Studzinski committed on
Commit 42e3538 · 1 Parent(s): 80f351c

Fix gemma chat template fallback

Files changed (1)
  1. app/models/huggingface_local.py +36 -11
app/models/huggingface_local.py CHANGED
@@ -73,7 +73,9 @@ class HuggingFaceLocal(BaseLLM):
         if not self._initialized:
             raise RuntimeError(f"[{self.name}] Model not initialized")
 
-        # Format prompt
+        formatted_prompt = None
+
+        # Format prompt from chat messages
         if chat_messages:
             try:
                 formatted_prompt = self.tokenizer.apply_chat_template(
@@ -81,15 +83,16 @@ class HuggingFaceLocal(BaseLLM):
                     tokenize=False,
                     add_generation_prompt=True
                 )
-            except Exception:
-                # Fallback to raw prompt
-                if prompt:
-                    formatted_prompt = prompt
-                else:
-                    raise ValueError("Cannot generate without valid prompt")
-        elif prompt:
+            except Exception as e:
+                print(f"[{self.name}] apply_chat_template failed: {e}, using fallback")
+                # Fallback: manually format chat messages
+                formatted_prompt = self._format_chat_fallback(chat_messages)
+
+        # Use raw prompt if provided and no chat_messages
+        if formatted_prompt is None and prompt:
             formatted_prompt = prompt
-        else:
+
+        if formatted_prompt is None:
             raise ValueError("Either prompt or chat_messages required")
 
         # Generate
@@ -115,13 +118,35 @@ class HuggingFaceLocal(BaseLLM):
             response = full_text
 
             # Clean up special tokens
-            if response.endswith("<|im_end|>"):
-                response = response[:-len("<|im_end|>")]
+            for token in ["<|im_end|>", "<end_of_turn>", "<eos>", "</s>"]:
+                if response.endswith(token):
+                    response = response[:-len(token)]
 
             return response.strip()
 
         return ""
 
+    def _format_chat_fallback(self, chat_messages: List[Dict[str, str]]) -> str:
+        """
+        Fallback chat formatting for models without a proper chat template.
+        Works with Gemma and other models.
+        """
+        formatted = ""
+        for msg in chat_messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                formatted += f"{content}\n\n"
+            elif role == "user":
+                formatted += f"User: {content}\n"
+            elif role == "assistant":
+                formatted += f"Assistant: {content}\n"
+
+        # Add generation prompt
+        formatted += "Assistant:"
+        return formatted
+
     def get_info(self) -> Dict[str, Any]:
         """Return model info."""
         return {
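For context on the fix: Gemma's bundled chat template is stricter than most (notably, it rejects a `system` role), so `apply_chat_template` can raise even on a well-formed conversation; the new except branch recovers by hand-formatting the messages instead of bailing out to the raw prompt. A minimal standalone sketch of what `_format_chat_fallback` produces, using a hypothetical message list:

# Standalone sketch of the fallback formatting; the messages are hypothetical.
chat_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

formatted = ""
for msg in chat_messages:
    role = msg.get("role", "user")
    content = msg.get("content", "")
    if role == "system":
        formatted += f"{content}\n\n"        # system text becomes a plain preamble
    elif role == "user":
        formatted += f"User: {content}\n"
    elif role == "assistant":
        formatted += f"Assistant: {content}\n"
formatted += "Assistant:"                    # trailing generation prompt

print(formatted)
# You are a helpful assistant.
#
# User: What is the capital of France?
# Assistant: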
 
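The cleanup loop also broadens beyond the ChatML-style `<|im_end|>`: Gemma closes its turns with `<end_of_turn>`, and many models emit `<eos>` or `</s>`. A quick check of the loop against hypothetical decoded outputs (`strip_end_tokens` is an illustrative name, not part of the file):

def strip_end_tokens(response: str) -> str:
    # Same loop as the patched cleanup: drop a trailing end marker, then strip.
    for token in ["<|im_end|>", "<end_of_turn>", "<eos>", "</s>"]:
        if response.endswith(token):
            response = response[:-len(token)]
    return response.strip()

# Hypothetical decoded outputs from ChatML-style and Gemma-style models.
assert strip_end_tokens("Paris.<|im_end|>") == "Paris."
assert strip_end_tokens("Paris.<end_of_turn>") == "Paris."  # previously left in place
assert strip_end_tokens("Paris.") == "Paris."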