Update app.py
Browse files
app.py
CHANGED
|
@@ -56,13 +56,10 @@ model_name = "tiiuae/falcon-7b-instruct"
|
|
| 56 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 57 |
|
| 58 |
|
| 59 |
-
class ContentHandler(HuggingFaceEndpoint.CallbackHandler):
|
| 60 |
-
def __init__(self, model_name):
|
| 61 |
-
self.model_name = model_name
|
| 62 |
-
self.len_prompt = 0
|
| 63 |
|
|
|
|
| 64 |
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
| 65 |
-
|
| 66 |
input_str = json.dumps({
|
| 67 |
"inputs": prompt,
|
| 68 |
"parameters": {
|
|
@@ -73,7 +70,7 @@ class ContentHandler(HuggingFaceEndpoint.CallbackHandler):
|
|
| 73 |
}
|
| 74 |
})
|
| 75 |
return input_str.encode('utf-8')
|
| 76 |
-
|
| 77 |
def transform_output(self, output: bytes) -> str:
|
| 78 |
response_json = output.decode('utf-8')
|
| 79 |
res = json.loads(response_json)
|
|
@@ -81,13 +78,8 @@ class ContentHandler(HuggingFaceEndpoint.CallbackHandler):
|
|
| 81 |
ans = ans[:ans.rfind("Human")].strip()
|
| 82 |
return ans
|
| 83 |
|
| 84 |
-
content_handler = ContentHandler(model_name=model_name)
|
| 85 |
-
|
| 86 |
def load_chain():
|
| 87 |
-
llm =
|
| 88 |
-
model_name=model_name,
|
| 89 |
-
content_handler=content_handler,
|
| 90 |
-
)
|
| 91 |
memory = ConversationBufferMemory()
|
| 92 |
chain = ConversationChain(llm=llm, memory=memory)
|
| 93 |
return chain
|
|
|
|
| 56 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
class CustomHuggingFaceEndpoint(HuggingFaceEndpoint):
|
| 61 |
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
|
| 62 |
+
len_prompt = len(prompt)
|
| 63 |
input_str = json.dumps({
|
| 64 |
"inputs": prompt,
|
| 65 |
"parameters": {
|
|
|
|
| 70 |
}
|
| 71 |
})
|
| 72 |
return input_str.encode('utf-8')
|
| 73 |
+
|
| 74 |
def transform_output(self, output: bytes) -> str:
|
| 75 |
response_json = output.decode('utf-8')
|
| 76 |
res = json.loads(response_json)
|
|
|
|
| 78 |
ans = ans[:ans.rfind("Human")].strip()
|
| 79 |
return ans
|
| 80 |
|
|
|
|
|
|
|
| 81 |
def load_chain():
|
| 82 |
+
llm = CustomHuggingFaceEndpoint(model_name=model_name)
|
|
|
|
|
|
|
|
|
|
| 83 |
memory = ConversationBufferMemory()
|
| 84 |
chain = ConversationChain(llm=llm, memory=memory)
|
| 85 |
return chain
|