jdesiree committed · Commit cd3695f (verified) · Parent(s): 3010de8

Update app.py

Files changed (1)
  1. app.py +92 -71
app.py CHANGED
@@ -415,85 +415,106 @@ class Phi3MiniEducationalLLM(Runnable):
             return f"[Error generating response: {str(e)}]"
 
     def stream_generate(self, input: Input, config=None):
-        """Streaming generation using TextIteratorStreamer"""
-        start_stream_time = time.perf_counter()
-        current_time = datetime.now()
-        logger.info("Starting stream_generate with TextIteratorStreamer...")
-
-        # Handle both string and dict inputs
-        if isinstance(input, dict):
-            prompt = input.get('input', str(input))
-        else:
-            prompt = str(input)
-
-        try:
-            # Format using Phi-3 chat template
-            text = self._format_chat_template(prompt)
-
-            inputs = self.tokenizer(
-                text,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=3072
-            )
-
-            # Move inputs to model device
-            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-            # Initialize TextIteratorStreamer
-            streamer = TextIteratorStreamer(
-                self.tokenizer,
-                skip_prompt=True,
-                skip_special_tokens=True
-            )
-
-            # Generation parameters
-            generation_kwargs = {
-                **inputs,
-                "max_new_tokens": 800,
-                "do_sample": True,
-                "temperature": 0.7,
-                "top_p": 0.9,
-                "top_k": 50,
-                "repetition_penalty": 1.1,
-                "pad_token_id": self.tokenizer.eos_token_id,
-                "streamer": streamer,
-                "use_cache": True
-            }
-
-            # Start generation in a separate thread
-            generation_thread = threading.Thread(
-                target=self.model.generate,
-                kwargs=generation_kwargs
-            )
-            generation_thread.start()
-
-            # Yield tokens as they become available
-            generated_text = ""
-            try:
-                for new_text in streamer:
-                    if new_text:  # Only yield non-empty strings
-                        generated_text += new_text
-                        yield generated_text
-            except Exception as e:
-                logger.error(f"Error in streaming iteration: {e}")
-                yield f"[Streaming error: {str(e)}]"
-
-            # Wait for generation to complete
-            generation_thread.join()
-
-            end_stream_time = time.perf_counter()
-            stream_time = end_stream_time - start_stream_time
-            log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
-            logger.info(f"Stream generation completed: {len(generated_text)} chars in {stream_time:.2f}s")
-
-        except Exception as e:
-            logger.error(f"Streaming generation error: {e}")
-            end_stream_time = time.perf_counter()
-            stream_time = end_stream_time - start_stream_time
-            log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
-            yield f"[Error in streaming generation: {str(e)}]"
+        """Streaming generation using TextIteratorStreamer with loop detection and early escape."""
+        start_stream_time = time.perf_counter()
+        current_time = datetime.now()
+        logger.info("Starting stream_generate with TextIteratorStreamer and loop detection...")
+
+        if isinstance(input, dict):
+            prompt = input.get('input', str(input))
+        else:
+            prompt = str(input)
+
+        try:
+            # Format using Phi-3 chat template
+            text = self._format_chat_template(prompt)
+
+            inputs = self.tokenizer(
+                text,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=3072
+            )
+
+            # Move inputs to model device
+            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+            # Initialize TextIteratorStreamer
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
+
+            # Generation parameters
+            generation_kwargs = {
+                **inputs,
+                "max_new_tokens": 800,
+                "do_sample": True,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.2,  # Slightly stronger to help with loop prevention
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "streamer": streamer,
+                "use_cache": True
+            }
+
+            # Start generation in background
+            generation_thread = threading.Thread(
+                target=self.model.generate,
+                kwargs=generation_kwargs
+            )
+            generation_thread.start()
+
+            # Track outputs
+            generated_text = ""
+            token_history = []
+            loop_window = 20  # Number of tokens to compare
+            loop_threshold = 3  # Allow N repetitions before aborting
+
+            try:
+                for new_text in streamer:
+                    if not new_text:
+                        continue
+
+                    generated_text += new_text
+
+                    # Tokenize and track
+                    tokens = self.tokenizer.tokenize(new_text)
+                    token_history.extend(tokens)
+
+                    # Check for loops
+                    if len(token_history) >= 2 * loop_window:
+                        recent = token_history[-loop_window:]
+                        prev = token_history[-2*loop_window:-loop_window]
+                        overlap = sum(1 for r, p in zip(recent, prev) if r == p)
+
+                        if overlap >= loop_threshold:
+                            logger.warning(f"Looping detected (overlap: {overlap}/{loop_window}). Aborting generation.")
+                            yield "[Looping detected — generation stopped early]"
+                            break
+
+                    yield generated_text
+            except Exception as e:
+                logger.error(f"Error in streaming iteration: {e}")
+                yield f"[Streaming error: {str(e)}]"
+
+            generation_thread.join()
+
+            end_stream_time = time.perf_counter()
+            stream_time = end_stream_time - start_stream_time
+            log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
+            logger.info(f"Stream generation completed: {len(generated_text)} chars in {stream_time:.2f}s")
+
+        except Exception as e:
+            logger.error(f"Streaming generation error: {e}")
+            end_stream_time = time.perf_counter()
+            stream_time = end_stream_time - start_stream_time
+            log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
+            yield f"[Error in streaming generation: {str(e)}]"
+
 
     @property
     def InputType(self) -> Type[Input]:
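
The substantive part of this change is the windowed-overlap heuristic inside the token loop: keep a running token history, compare the most recent loop_window tokens against the window immediately before them, and abort once enough positions match. A minimal standalone sketch of that heuristic follows; the function looks_like_loop and the string-token demo are illustrative stand-ins, not code from app.py:

def looks_like_loop(token_history, loop_window=20, loop_threshold=3):
    """Heuristic loop detector: count position-for-position matches
    between the last loop_window tokens and the window before them."""
    if len(token_history) < 2 * loop_window:
        return False  # not enough history for two full windows
    recent = token_history[-loop_window:]
    prev = token_history[-2 * loop_window:-loop_window]
    overlap = sum(1 for r, p in zip(recent, prev) if r == p)
    return overlap >= loop_threshold

# Plain strings stand in for tokenizer output in this demo:
looping = ["the", "cat", "sat", "."] * 10    # 4-token cycle, 40 tokens
varied = [f"tok{i}" for i in range(40)]      # 40 distinct tokens
assert looks_like_loop(looping)              # windows agree at all 20 positions
assert not looks_like_loop(varied)           # zero positional matches

Note how permissive the committed thresholds are: with loop_threshold = 3 and loop_window = 20, two windows only need to agree at 3 of 20 positions, so common tokens ("the", punctuation) landing at the same offsets in otherwise different text could trigger a false abort. Raising the threshold toward the window size, or leaning on the built-in no_repeat_ngram_size generation parameter in transformers, would make the guard stricter or replace the hand-rolled check entirely.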