Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ from transformers import TextIteratorStreamer
|
|
| 3 |
from threading import Thread
|
| 4 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
| 5 |
import torch
|
|
|
|
| 6 |
import os
|
| 7 |
model_name = "microsoft/Phi-3-medium-128k-instruct"
|
| 8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
@@ -16,6 +17,7 @@ class StopOnTokens(StoppingCriteria):
|
|
| 16 |
if input_ids[0][-1] == stop_id:
|
| 17 |
return True
|
| 18 |
return False
|
|
|
|
| 19 |
def predict(message, history):
|
| 20 |
history_transformer_format = history + [[message, ""]]
|
| 21 |
stop = StopOnTokens()
|
|
|
|
| 3 |
from threading import Thread
|
| 4 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
| 5 |
import torch
|
| 6 |
+
import spaces
|
| 7 |
import os
|
| 8 |
model_name = "microsoft/Phi-3-medium-128k-instruct"
|
| 9 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 17 |
if input_ids[0][-1] == stop_id:
|
| 18 |
return True
|
| 19 |
return False
|
| 20 |
+
@spaces.GPU()
|
| 21 |
def predict(message, history):
|
| 22 |
history_transformer_format = history + [[message, ""]]
|
| 23 |
stop = StopOnTokens()
|