Brianpuz committed on
Commit 8c92c76 · 1 Parent(s): fbb861f
Files changed (1)
  1. app.py +80 -18
app.py CHANGED
@@ -316,23 +316,69 @@ class AbliterationProcessor:
                 return_tensors="pt"
             )
 
-            # Generate response with streaming like abliterated_optimized.py
-            from transformers import TextStreamer
+            # Generate response without streaming for now (will be handled by Gradio)
+            gen = self.model.generate(
+                toks.to(self.model.device),
+                max_new_tokens=2048,
+                temperature=0.7,
+                do_sample=True,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+
+            # Decode response
+            decoded = self.tokenizer.batch_decode(
+                gen[0][len(toks[0]):],
+                skip_special_tokens=True
+            )
+
+            response = "".join(decoded).strip()
+            return response, history + [[message, response]]
+
+        except Exception as e:
+            return f"❌ Chat error: {str(e)}", history
+
+    def chat_stream(self, message, history):
+        """Streaming chat functionality"""
+        if self.model is None or self.tokenizer is None:
+            yield "⚠️ Please load a model first!"
+            return
+
+        try:
+            # Build conversation history
+            conversation = []
+            for msg in history:
+                if isinstance(msg, dict) and "role" in msg and "content" in msg:
+                    conversation.append(msg)
+                elif isinstance(msg, list) and len(msg) == 2:
+                    conversation.append({"role": "user", "content": msg[0]})
+                    if msg[1]:
+                        conversation.append({"role": "assistant", "content": msg[1]})
+
+            # Add current message
+            conversation.append({"role": "user", "content": message})
 
-            # Create a custom streamer that captures all output
-            captured_output = []
+            # Generate tokens
+            toks = self.tokenizer.apply_chat_template(
+                conversation=conversation,
+                add_generation_prompt=True,
+                return_tensors="pt"
+            )
 
-            class CustomStreamer(TextStreamer):
+            # Stream response
+            from transformers import TextStreamer
+
+            class StreamingTextStreamer(TextStreamer):
                 def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
                     super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
-                    self.captured = []
+                    self.current_text = ""
 
                 def on_finalized_text(self, text: str, stream_end: bool = False):
-                    self.captured.append(text)
-                    super().on_finalized_text(text, stream_end)
+                    self.current_text += text
+                    yield self.current_text
 
-            streamer = CustomStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+            streamer = StreamingTextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
 
+            # Generate with streaming
             gen = self.model.generate(
                 toks.to(self.model.device),
                 max_new_tokens=2048,
@@ -342,12 +388,12 @@ class AbliterationProcessor:
                 streamer=streamer
             )
 
-            # Get the complete response from streamer
-            response = "".join(streamer.captured).strip()
-            return response, history + [[message, response]]
-
+            # Yield each chunk
+            for chunk in streamer.on_finalized_text("", False):
+                yield chunk
+
         except Exception as e:
-            return f"❌ Chat error: {str(e)}", history
+            yield f"❌ Chat error: {str(e)}"
 
 def get_new_model_card(original_card: ModelCard, original_model_id: str, new_repo_url: str) -> ModelCard:
     """Create new model card"""
@@ -540,16 +586,32 @@ def create_interface():
 
     def bot(history):
         if history and history[-1]["role"] == "user":
-            response, _ = processor.chat(history[-1]["content"], history[:-1])
-            history.append({"role": "assistant", "content": response})
+            # Start with empty assistant message
+            history.append({"role": "assistant", "content": ""})
+
+            # Get the full response
+            response, _ = processor.chat(history[-2]["content"], history[:-2])
+
+            # Update the assistant message with the full response
+            history[-1]["content"] = response
         return history
 
+    def bot_stream(history):
+        if history and history[-1]["role"] == "user":
+            # Start with empty assistant message
+            history.append({"role": "assistant", "content": ""})
+
+            # Get streaming response
+            for response_chunk in processor.chat_stream(history[-2]["content"], history[:-2]):
+                history[-1]["content"] = response_chunk
+                yield history
+
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, chatbot, chatbot
+        bot_stream, chatbot, chatbot
     )
 
     send_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, chatbot, chatbot
+        bot_stream, chatbot, chatbot
     )
 
     clear.click(lambda: [], None, chatbot, queue=False)
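For reference, a self-contained sketch (not from this commit) of the event wiring that bot_stream relies on: a generator callback chained with .then() pushes each yielded history back into the Chatbot. The role/content dicts used above imply a messages-format Chatbot, so the sketch constructs one explicitly; the component names mirror create_interface, but the demo itself and its user handler are illustrative:

import time
import gradio as gr

def user(message, history):
    # Minimal user handler (the app's own implementation is not part of this diff):
    # append the user turn and clear the textbox.
    return "", history + [{"role": "user", "content": message}]

def bot_stream(history):
    # Stand-in for processor.chat_stream: emits a growing assistant message.
    history.append({"role": "assistant", "content": ""})
    for word in "this reply is streamed word by word".split():
        history[-1]["content"] += word + " "
        time.sleep(0.1)
        yield history          # each yield re-renders the Chatbot

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_stream, chatbot, chatbot
    )

if __name__ == "__main__":
    demo.launch()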