jdesiree committed on
Commit
5e4b8e6
·
verified ·
1 Parent(s): f34d795

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -15
app.py CHANGED
@@ -1,34 +1,48 @@
1
- import spaces
2
- import gradio as gr
3
- from graph_tool import generate_plot
4
  import os
5
- import platform
6
- from dotenv import load_dotenv
7
- import logging
8
  import re
9
  import json
 
 
 
10
  import threading
 
 
 
 
11
  from datetime import datetime
 
12
  from typing import Annotated, Sequence, TypedDict, List, Optional, Any, Type
 
13
  from pydantic import BaseModel, Field
14
 
 
 
 
15
  # LangGraph imports
16
  from langgraph.graph import StateGraph, START, END
17
  from langgraph.graph.message import add_messages
18
  from langgraph.checkpoint.memory import MemorySaver
19
  from langgraph.prebuilt import ToolNode
20
 
21
- # Updated LangChain imports
22
  from langchain_core.tools import tool
23
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage, BaseMessage
24
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
25
  from langchain_core.runnables import Runnable
26
  from langchain_core.runnables.utils import Input, Output
27
 
28
- from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM, BitsAndBytesConfig
29
- import torch
30
- import time
31
- import warnings
 
 
 
 
 
 
 
 
32
 
33
  # Updated environment variables
34
  os.environ['HF_HOME'] = '/tmp/huggingface'
@@ -405,6 +419,18 @@ class Phi3MiniEducationalLLM(Runnable):
405
  # Fallback to manual Phi-3 format
406
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
407
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  @spaces.GPU(duration=180)
409
  def invoke(self, input: Input, config=None) -> Output:
410
  """Main invoke method optimized for 4-bit quantized Phi-3-mini"""
@@ -432,16 +458,17 @@ class Phi3MiniEducationalLLM(Runnable):
432
 
433
  # FIX: Proper tokenization with error handling
434
  try:
 
435
  inputs = self.tokenizer(
436
  text,
437
  return_tensors="pt",
438
  padding=True,
439
  truncation=True,
440
- max_length=4096
441
  )
442
 
443
  # Ensure inputs are properly formatted
444
- if not hasattr(inputs, 'input_ids'):
445
  logger.error("Tokenizer did not return input_ids")
446
  return "I encountered an error processing your request. Please try again."
447
 
@@ -462,7 +489,7 @@ class Phi3MiniEducationalLLM(Runnable):
462
  outputs = model.generate(
463
  input_ids=inputs['input_ids'],
464
  attention_mask=inputs.get('attention_mask', None),
465
- max_new_tokens=1200,
466
  do_sample=True,
467
  temperature=0.7,
468
  top_p=0.9,
@@ -470,7 +497,8 @@ class Phi3MiniEducationalLLM(Runnable):
470
  repetition_penalty=1.1,
471
  pad_token_id=self.tokenizer.eos_token_id,
472
  use_cache=False,
473
- past_key_values=None
 
474
  )
475
  except Exception as generation_error:
476
  logger.error(f"Generation error: {generation_error}")
@@ -480,6 +508,13 @@ class Phi3MiniEducationalLLM(Runnable):
480
  try:
481
  new_tokens = outputs[0][len(inputs['input_ids'][0]):]
482
  result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
 
 
 
 
 
 
 
483
  except Exception as decode_error:
484
  logger.error(f"Decoding error: {decode_error}")
485
  return "I encountered an error processing the response. Please try again."
@@ -489,6 +524,7 @@ class Phi3MiniEducationalLLM(Runnable):
489
  log_metric(f"LLM Invoke time (4-bit): {invoke_time:0.4f} seconds. Input length: {len(prompt)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
490
 
491
  return result if result else "I'm still learning how to respond to that properly."
 
492
 
493
  except Exception as e:
494
  logger.error(f"Generation error with 4-bit model: {e}")
 
 
 
 
1
  import os
 
 
 
2
  import re
3
  import json
4
+ import time
5
+ import torch
6
+ import gradio as gr
7
  import threading
8
+ import logging
9
+ import platform
10
+ import warnings
11
+
12
  from datetime import datetime
13
+ from dotenv import load_dotenv
14
  from typing import Annotated, Sequence, TypedDict, List, Optional, Any, Type
15
+
16
  from pydantic import BaseModel, Field
17
 
18
+ # Hugging Face Spaces (ZeroGPU) package, provides the @spaces.GPU decorator
19
+ import spaces
20
+
21
  # LangGraph imports
22
  from langgraph.graph import StateGraph, START, END
23
  from langgraph.graph.message import add_messages
24
  from langgraph.checkpoint.memory import MemorySaver
25
  from langgraph.prebuilt import ToolNode
26
 
27
+ # LangChain Core imports
28
  from langchain_core.tools import tool
29
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage, BaseMessage
30
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
31
  from langchain_core.runnables import Runnable
32
  from langchain_core.runnables.utils import Input, Output
33
 
34
+ # Transformers imports
35
+ from transformers import (
36
+ AutoTokenizer,
37
+ AutoModelForCausalLM,
38
+ TextIteratorStreamer,
39
+ StoppingCriteria,
40
+ StoppingCriteriaList,
41
+ BitsAndBytesConfig,
42
+ )
43
+
44
+ from graph_tool import generate_plot
45
+
46
 
47
  # Updated environment variables
48
  os.environ['HF_HOME'] = '/tmp/huggingface'
 
419
  # Fallback to manual Phi-3 format
420
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
421
 
422
class StopOnSequence(StoppingCriteria):
    """Stopping criterion that halts generation once a stop text has been emitted.

    The stop text is tokenized once at construction time. On every generation
    step the trailing token ids of the first row of ``input_ids`` are compared
    against those ids.
    """

    def __init__(self, tokenizer, stop_sequence):
        """Tokenize ``stop_sequence`` (without special tokens) for later matching.

        Args:
            tokenizer: Object exposing ``encode(text, add_special_tokens=False)``;
                it is expected to return a list of token ids, since the result
                is compared against a ``.tolist()`` value in ``__call__``.
            stop_sequence: Text whose token ids mark the point to stop at.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_sequence = tokenizer.encode(stop_sequence, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the first row of ``input_ids`` ends with the stop ids.

        Args:
            input_ids: Token ids generated so far, indexed as ``[row, position]``.
            scores: Unused; accepted to match the stopping-criteria call signature.
            **kwargs: Unused extra arguments.

        Returns:
            bool: True if the trailing ids of row 0 equal the stop sequence.
        """
        stop_len = len(self.stop_sequence)
        # An empty stop sequence can never match. Without this guard the
        # ``-0`` slice below would select the whole row instead of nothing.
        if stop_len == 0:
            return False
        # Only row 0 is inspected, exactly as in the original implementation.
        return input_ids[0, -stop_len:].tolist() == self.stop_sequence
431
+
432
+ stop_criteria = StoppingCriteriaList([StopOnSequence(self.tokenizer, "User:")])
433
+
434
  @spaces.GPU(duration=180)
435
  def invoke(self, input: Input, config=None) -> Output:
436
  """Main invoke method optimized for 4-bit quantized Phi-3-mini"""
 
458
 
459
  # FIX: Proper tokenization with error handling
460
  try:
461
+ max_input_length = 4096 - 300
462
  inputs = self.tokenizer(
463
  text,
464
  return_tensors="pt",
465
  padding=True,
466
  truncation=True,
467
+ max_length=max_input_length
468
  )
469
 
470
  # Ensure inputs are properly formatted
471
+ if 'input_ids' not in inputs:
472
  logger.error("Tokenizer did not return input_ids")
473
  return "I encountered an error processing your request. Please try again."
474
 
 
489
  outputs = model.generate(
490
  input_ids=inputs['input_ids'],
491
  attention_mask=inputs.get('attention_mask', None),
492
+ max_new_tokens=300,
493
  do_sample=True,
494
  temperature=0.7,
495
  top_p=0.9,
 
497
  repetition_penalty=1.1,
498
  pad_token_id=self.tokenizer.eos_token_id,
499
  use_cache=False,
500
+ past_key_values=None,
501
+ stopping_criteria=stop_criteria
502
  )
503
  except Exception as generation_error:
504
  logger.error(f"Generation error: {generation_error}")
 
508
  try:
509
  new_tokens = outputs[0][len(inputs['input_ids'][0]):]
510
  result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
511
+
512
+ # Soft stop cleanup
513
+ for stop_word in ["User:", "\n\n", "###"]:
514
+ if stop_word in result:
515
+ result = result.split(stop_word)[0].strip()
516
+ break
517
+
518
  except Exception as decode_error:
519
  logger.error(f"Decoding error: {decode_error}")
520
  return "I encountered an error processing the response. Please try again."
 
524
  log_metric(f"LLM Invoke time (4-bit): {invoke_time:0.4f} seconds. Input length: {len(prompt)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
525
 
526
  return result if result else "I'm still learning how to respond to that properly."
527
+
528
 
529
  except Exception as e:
530
  logger.error(f"Generation error with 4-bit model: {e}")