navidfalah committed on
Commit
e12356b
·
verified ·
1 Parent(s): 5f925e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -134
app.py CHANGED
@@ -98,181 +98,137 @@ except Exception as e:
98
  if tokenizer.pad_token is None:
99
  tokenizer.pad_token = tokenizer.eos_token
100
 
101
- print(f"Loading your fine-tuned Mistral model from {model_path}...")
102
  try:
103
- # Load your fine-tuned model weights
104
- model = AutoModelForCausalLM.from_pretrained(
105
- model_path,
106
  torch_dtype=torch.float16,
107
  device_map="auto",
108
- trust_remote_code=True,
109
- low_cpu_mem_usage=True,
110
- local_files_only=True
111
  )
112
- print("Fine-tuned Mistral model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  except Exception as e:
115
- print(f"Error loading fine-tuned model from {model_path}: {e}")
116
- print("Trying without local_files_only...")
117
  try:
118
- model = AutoModelForCausalLM.from_pretrained(
119
- model_path,
 
120
  torch_dtype=torch.float16,
121
  device_map="auto",
122
- trust_remote_code=True,
123
  low_cpu_mem_usage=True
124
  )
125
- print("Fine-tuned Mistral model loaded successfully!")
 
126
  except Exception as e2:
127
- print(f"Cannot load fine-tuned model: {e2}")
128
- print("Exiting - cannot proceed without your fine-tuned model")
129
  exit(1)
130
 
131
  def chat_function(message):
132
  if not message or not message.strip():
133
- return "Please enter a message to get started!"
134
 
135
- # Limit input length
136
- if len(message) > 300:
137
- return "Message too long! Please keep it under 300 characters."
 
138
 
139
  try:
140
  # Use flexible prompt format based on tokenizer type
141
  if hasattr(tokenizer, 'chat_template') or 'mistral' in tokenizer.name_or_path.lower():
142
  # Use Mistral format if it's actually Mistral
143
- prompt = f"<s>[INST] {message.strip()} [/INST]"
144
  else:
145
  # Use simple format for other tokenizers
146
- prompt = f"Human: {message.strip()}\nAssistant:"
147
 
148
  # Tokenize input
149
- try:
150
- inputs = tokenizer(
151
- prompt,
152
- return_tensors='pt',
153
- truncation=True,
154
- max_length=512,
155
- padding=True
156
- )
157
- input_ids = inputs['input_ids']
158
- attention_mask = inputs.get('attention_mask', None)
159
-
160
- except Exception as e:
161
- print(f"Tokenization error: {e}")
162
- return f"Error processing your message: {str(e)}"
163
-
164
- # Validate input
165
- if input_ids.shape[-1] == 0:
166
- return "Error: Empty input after encoding"
167
-
168
- print(f"Input shape: {input_ids.shape}")
169
 
170
  # Move to model device
171
- try:
172
- device = next(model.parameters()).device
173
- input_ids = input_ids.to(device)
174
- if attention_mask is not None:
175
- attention_mask = attention_mask.to(device)
176
- except Exception as e:
177
- print(f"Device move error: {e}")
178
 
179
  # Generate response
180
- try:
181
- with torch.no_grad():
182
- # Clear cache to prevent memory issues
183
- if torch.cuda.is_available():
184
- torch.cuda.empty_cache()
185
-
186
- # Conservative generation parameters
187
- generation_kwargs = {
188
- 'input_ids': input_ids,
189
- 'max_new_tokens': 150,
190
- 'temperature': 0.7,
191
- 'do_sample': True,
192
- 'pad_token_id': tokenizer.pad_token_id,
193
- 'eos_token_id': tokenizer.eos_token_id,
194
- 'num_return_sequences': 1,
195
- 'repetition_penalty': 1.1,
196
- 'top_p': 0.9,
197
- 'use_cache': True,
198
- 'num_beams': 1,
199
- }
200
-
201
- # Add attention mask if available
202
- if attention_mask is not None:
203
- generation_kwargs['attention_mask'] = attention_mask
204
-
205
- print(f"Generating with input_ids shape: {input_ids.shape}")
206
- outputs = model.generate(**generation_kwargs)
207
- print(f"Generated output shape: {outputs.shape}")
208
-
209
- except Exception as e:
210
- print(f"Generation error: {e}")
211
- # Try with minimal settings
212
- try:
213
- print("Trying with minimal settings...")
214
- outputs = model.generate(
215
- input_ids,
216
- max_new_tokens=80,
217
- do_sample=False, # Greedy decoding
218
- pad_token_id=tokenizer.pad_token_id,
219
- eos_token_id=tokenizer.eos_token_id,
220
- )
221
- print(f"Minimal generation output shape: {outputs.shape}")
222
- except Exception as e2:
223
- print(f"Minimal generation also failed: {e2}")
224
- return f"Error generating response: {str(e)}"
225
-
226
- # Decode response
227
- try:
228
- # Extract only the new tokens (response part)
229
- if outputs.shape[1] > input_ids.shape[1]:
230
- response_ids = outputs[0][input_ids.shape[1]:]
231
- response = tokenizer.decode(response_ids, skip_special_tokens=True)
232
- else:
233
- # Fallback: decode full output and remove prompt
234
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
235
- response = full_response.replace(prompt.replace("<s>", "").replace("</s>", ""), "").strip()
236
 
237
- except Exception as e:
238
- print(f"Decoding error: {e}")
239
- try:
240
- # Last resort: decode full output
241
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
242
- response = full_response
243
- except:
244
- return f"Error decoding response: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- # Clean up the response based on tokenizer type
247
  response = response.strip()
248
 
249
- # Remove prompt artifacts based on what we used
250
- if "[/INST]" in response:
251
- response = response.split("[/INST]")[-1].strip()
252
- if "[INST]" in response:
253
- response = response.split("[INST]")[0].strip()
254
- if "Assistant:" in response:
255
- response = response.split("Assistant:")[-1].strip()
256
- if "Human:" in response:
257
- response = response.split("Human:")[0].strip()
258
 
259
- # Remove input message if it appears in response
260
- if message.strip() in response:
261
- response = response.replace(message.strip(), "").strip()
262
 
263
- # Limit response length
264
- if len(response) > 1000:
265
- response = response[:1000] + "..."
266
 
267
- # Ensure we have a meaningful response
268
- if len(response.strip()) < 5:
269
- response = "I understand your message. How can I help you with that?"
270
 
271
  return response
272
 
273
  except Exception as e:
274
- print(f"Unexpected error: {e}")
275
- return f"Sorry, I encountered an unexpected error: {str(e)}"
276
 
277
  def clear_chat():
278
  return ""
 
98
  if tokenizer.pad_token is None:
99
  tokenizer.pad_token = tokenizer.eos_token
100
 
101
+ print(f"Loading your model from {original_mistral_model}...")
102
  try:
103
+ # Load your model from Hugging Face
104
+ base_model = AutoModelForCausalLM.from_pretrained(
105
+ original_mistral_model,
106
  torch_dtype=torch.float16,
107
  device_map="auto",
108
+ low_cpu_mem_usage=True
 
 
109
  )
110
+ print("Your model loaded successfully!")
111
+
112
+ # Check if PEFT is available and try to load local adapter
113
+ if PeftModel is not None and PeftConfig is not None:
114
+ try:
115
+ print(f"Trying to load local LoRA adapter from {adapter_path}...")
116
+ model = PeftModel.from_pretrained(
117
+ base_model,
118
+ adapter_path,
119
+ torch_dtype=torch.float16
120
+ )
121
+ print("Local LoRA adapter loaded successfully!")
122
+ except Exception as adapter_error:
123
+ print(f"Could not load local adapter: {adapter_error}")
124
+ print("Using your base model without additional adapter")
125
+ model = base_model
126
+ else:
127
+ print("PEFT not available - using your base model")
128
+ model = base_model
129
 
130
  except Exception as e:
131
+ print(f"Error loading your model: {e}")
132
+ print("Trying to load original Mistral as fallback...")
133
  try:
134
+ # Fallback to original Mistral
135
+ base_model = AutoModelForCausalLM.from_pretrained(
136
+ "mistralai/Mistral-7B-Instruct-v0.1",
137
  torch_dtype=torch.float16,
138
  device_map="auto",
 
139
  low_cpu_mem_usage=True
140
  )
141
+ print("Fallback Mistral model loaded!")
142
+ model = base_model
143
  except Exception as e2:
144
+ print(f"Cannot load any model: {e2}")
145
+ print("Exiting - cannot proceed without model")
146
  exit(1)
147
 
148
  def chat_function(message):
149
  if not message or not message.strip():
150
+ return "Please enter a message."
151
 
152
+ # Clean and limit input
153
+ message = message.strip()
154
+ if len(message) > 500:
155
+ return "Message too long! Please keep it under 500 characters."
156
 
157
  try:
158
  # Use flexible prompt format based on tokenizer type
159
  if hasattr(tokenizer, 'chat_template') or 'mistral' in tokenizer.name_or_path.lower():
160
  # Use Mistral format if it's actually Mistral
161
+ prompt = f"<s>[INST] {message} [/INST]"
162
  else:
163
  # Use simple format for other tokenizers
164
+ prompt = f"User: {message}\nAssistant:"
165
 
166
  # Tokenize input
167
+ inputs = tokenizer(
168
+ prompt,
169
+ return_tensors='pt',
170
+ truncation=True,
171
+ max_length=400,
172
+ padding=True
173
+ )
174
+ input_ids = inputs['input_ids']
175
+ attention_mask = inputs.get('attention_mask', None)
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  # Move to model device
178
+ device = next(model.parameters()).device
179
+ input_ids = input_ids.to(device)
180
+ if attention_mask is not None:
181
+ attention_mask = attention_mask.to(device)
 
 
 
182
 
183
  # Generate response
184
+ with torch.no_grad():
185
+ if torch.cuda.is_available():
186
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
+ outputs = model.generate(
189
+ input_ids,
190
+ max_new_tokens=200,
191
+ temperature=0.7,
192
+ do_sample=True,
193
+ top_p=0.9,
194
+ pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
195
+ eos_token_id=tokenizer.eos_token_id,
196
+ attention_mask=attention_mask,
197
+ repetition_penalty=1.1
198
+ )
199
+
200
+ # Extract and clean response
201
+ if outputs.shape[1] > input_ids.shape[1]:
202
+ response_ids = outputs[0][input_ids.shape[1]:]
203
+ response = tokenizer.decode(response_ids, skip_special_tokens=True)
204
+ else:
205
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
206
+ response = response.replace(prompt, "").strip()
207
 
208
+ # Clean up response
209
  response = response.strip()
210
 
211
+ # Remove prompt artifacts
212
+ for artifact in ["[/INST]", "[INST]", "Assistant:", "User:", "Human:"]:
213
+ if artifact in response:
214
+ response = response.split(artifact)[-1].strip()
 
 
 
 
 
215
 
216
+ # Remove input if it appears in response
217
+ if message.lower() in response.lower():
218
+ response = response.replace(message, "").strip()
219
 
220
+ # Ensure reasonable length
221
+ if len(response) > 800:
222
+ response = response[:800] + "..."
223
 
224
+ # Fallback if empty
225
+ if len(response.strip()) < 3:
226
+ response = "I understand. How can I help you?"
227
 
228
  return response
229
 
230
  except Exception as e:
231
+ return f"Error: {str(e)}"
 
232
 
233
  def clear_chat():
234
  return ""