jdesiree committed · Commit 54d8c57 (verified) · 1 Parent(s): 22e0558

Update app.py

Files changed (1): app.py (+22 -16)
app.py CHANGED
@@ -429,7 +429,13 @@ class Phi3MiniEducationalLLM(Runnable):
         prompt = str(input)
 
         try:
-            # Format using Phi-3 chat template
+            # Load model inside GPU context
+            model = self._load_model_if_needed()
+
+            # Clear GPU cache
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             text = self._format_chat_template(prompt)
 
             inputs = self.tokenizer(
@@ -440,8 +446,8 @@ class Phi3MiniEducationalLLM(Runnable):
                 max_length=3072
             )
 
-            # Move inputs to model device
-            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+            # Move inputs to model device - now model is not None
+            inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
             # Initialize TextIteratorStreamer
             streamer = TextIteratorStreamer(
@@ -458,7 +464,7 @@ class Phi3MiniEducationalLLM(Runnable):
                 "temperature": 0.7,
                 "top_p": 0.9,
                 "top_k": 50,
-                "repetition_penalty": 1.2,  # Slightly stronger to help with loop prevention
+                "repetition_penalty": 1.2,
                 "pad_token_id": self.tokenizer.eos_token_id,
                 "streamer": streamer,
                 "use_cache": True
@@ -466,7 +472,7 @@ class Phi3MiniEducationalLLM(Runnable):
 
             # Start generation in background
             generation_thread = threading.Thread(
-                target=self.model.generate,
+                target=model.generate,  # Use the loaded model
                 kwargs=generation_kwargs
             )
             generation_thread.start()
@@ -474,43 +480,43 @@ class Phi3MiniEducationalLLM(Runnable):
             # Track outputs
             generated_text = ""
             token_history = []
-            loop_window = 20  # Number of tokens to compare
-            loop_threshold = 3  # Allow N repetitions before aborting
-
+            loop_window = 20
+            loop_threshold = 3
+
             try:
                 for new_text in streamer:
                     if not new_text:
                         continue
-
+
                     generated_text += new_text
-
+
                     # Tokenize and track
                     tokens = self.tokenizer.tokenize(new_text)
                     token_history.extend(tokens)
-
+
                     # Check for loops
                     if len(token_history) >= 2 * loop_window:
                         recent = token_history[-loop_window:]
                         prev = token_history[-2*loop_window:-loop_window]
                         overlap = sum(1 for r, p in zip(recent, prev) if r == p)
-
+
                         if overlap >= loop_threshold:
                             logger.warning(f"Looping detected (overlap: {overlap}/{loop_window}). Aborting generation.")
                             yield "[Looping detected — generation stopped early]"
                             break
-
+
                     yield generated_text
             except Exception as e:
                 logger.error(f"Error in streaming iteration: {e}")
                 yield f"[Streaming error: {str(e)}]"
-
+
             generation_thread.join()
-
+
            end_stream_time = time.perf_counter()
             stream_time = end_stream_time - start_stream_time
             log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
             logger.info(f"Stream generation completed: {len(generated_text)} chars in {stream_time:.2f}s")
-
+
         except Exception as e:
             logger.error(f"Streaming generation error: {e}")
             end_stream_time = time.perf_counter()
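Note on the fix: the streaming path previously dereferenced self.model directly, both for device placement and for generate(). With lazy loading, self.model can still be None at that point, so the commit fetches the model handle via _load_model_if_needed() at the top of the try block (the "now model is not None" comment marks the fix) and clears the CUDA allocator cache before generating. Below is a minimal sketch of the lazy-loading pattern this relies on; _load_model_if_needed is not shown in this diff, so the body, model id, and dtype are assumptions, not the app's actual code.

    import torch
    from transformers import AutoModelForCausalLM

    class Phi3MiniEducationalLLM:
        def __init__(self, model_name="microsoft/Phi-3-mini-4k-instruct"):  # assumed id
            self.model_name = model_name
            self.model = None  # deferred: loaded on first use, inside the GPU context

        def _load_model_if_needed(self):
            # Load once, cache on the instance, and hand back the handle so
            # callers never dereference a still-None self.model.
            if self.model is None:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.bfloat16,  # assumed dtype
                    device_map="auto",           # requires accelerate
                )
            return self.model

torch.cuda.empty_cache() only returns blocks held by PyTorch's caching allocator to the driver; here it mainly frees headroom before a fresh load or a long generation on a shared GPU.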
 
 
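For context, the generation arrangement the hunks above rely on is the standard transformers streaming pattern: generate() blocks, so it runs on a worker thread while the caller iterates a TextIteratorStreamer on the main thread, and generation_thread.join() then makes sure the worker has finished before timings are logged. A self-contained sketch, with the model id assumed and the sampling settings trimmed:

    import threading
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_id = "microsoft/Phi-3-mini-4k-instruct"  # assumed
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Explain photosynthesis simply.", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() runs on the worker; the streamer yields decoded text chunks here.
    thread = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": 128, "streamer": streamer},
    )
    thread.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    thread.join()  # don't leave the worker running past this point

Note that app.py yields the accumulated generated_text on each iteration rather than the per-chunk delta, which suits front ends that re-render the whole message on every update.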
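The loop guard in the last hunk works on token overlap: it compares the most recent loop_window tokens against the loop_window tokens immediately before them and counts position-wise matches, aborting once the count reaches loop_threshold. Extracted as a standalone function for clarity (the name looks_like_loop is hypothetical):

    def looks_like_loop(token_history, window=20, threshold=3):
        # Compare the last `window` tokens with the `window` before them;
        # high position-wise overlap suggests the model is cycling.
        if len(token_history) < 2 * window:
            return False
        recent = token_history[-window:]
        prev = token_history[-2 * window:-window]
        overlap = sum(1 for r, p in zip(recent, prev) if r == p)
        return overlap >= threshold

    assert looks_like_loop(["a", "b"] * 20)                  # exact repetition trips it
    assert not looks_like_loop([str(i) for i in range(40)])  # distinct tokens pass

One caveat: the comment removed in this commit ("Allow N repetitions before aborting") misstated the check. The threshold counts matching token positions within a single window, not whole repetitions, so 3 coincidental matches out of 20 positions (common tokens lining up, for instance) are enough to stop generation early.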