nvinay1803 commited on
Commit
c465bdd
·
verified ·
1 Parent(s): f85115e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -12
app.py CHANGED
@@ -56,13 +56,10 @@ model_name = "tiiuae/falcon-7b-instruct"
56
  tokenizer = AutoTokenizer.from_pretrained(model_name)
57
 
58
 
59
- class ContentHandler(HuggingFaceEndpoint.CallbackHandler):
60
- def __init__(self, model_name):
61
- self.model_name = model_name
62
- self.len_prompt = 0
63
 
 
64
  def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
65
- self.len_prompt = len(prompt)
66
  input_str = json.dumps({
67
  "inputs": prompt,
68
  "parameters": {
@@ -73,7 +70,7 @@ class ContentHandler(HuggingFaceEndpoint.CallbackHandler):
73
  }
74
  })
75
  return input_str.encode('utf-8')
76
-
77
  def transform_output(self, output: bytes) -> str:
78
  response_json = output.decode('utf-8')
79
  res = json.loads(response_json)
@@ -81,13 +78,8 @@ class ContentHandler(HuggingFaceEndpoint.CallbackHandler):
81
  ans = ans[:ans.rfind("Human")].strip()
82
  return ans
83
 
84
- content_handler = ContentHandler(model_name=model_name)
85
-
86
  def load_chain():
87
- llm = HuggingFaceEndpoint(
88
- model_name=model_name,
89
- content_handler=content_handler,
90
- )
91
  memory = ConversationBufferMemory()
92
  chain = ConversationChain(llm=llm, memory=memory)
93
  return chain
 
56
  tokenizer = AutoTokenizer.from_pretrained(model_name)
57
 
58
 
 
 
 
 
59
 
60
+ class CustomHuggingFaceEndpoint(HuggingFaceEndpoint):
61
  def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
62
+ len_prompt = len(prompt)
63
  input_str = json.dumps({
64
  "inputs": prompt,
65
  "parameters": {
 
70
  }
71
  })
72
  return input_str.encode('utf-8')
73
+
74
  def transform_output(self, output: bytes) -> str:
75
  response_json = output.decode('utf-8')
76
  res = json.loads(response_json)
 
78
  ans = ans[:ans.rfind("Human")].strip()
79
  return ans
80
 
 
 
81
  def load_chain():
82
+ llm = CustomHuggingFaceEndpoint(model_name=model_name)
 
 
 
83
  memory = ConversationBufferMemory()
84
  chain = ConversationChain(llm=llm, memory=memory)
85
  return chain