Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,9 @@
|
|
| 1 |
# coding=utf-8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from src.logger import LoggerFactory
|
| 3 |
from src.prompt_concat import GetManualTestSamples, CreateTestDataset
|
| 4 |
from src.utils import decode_csv_to_json, load_json, save_to_json
|
|
@@ -23,12 +28,18 @@ import spaces
|
|
| 23 |
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
| 24 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
| 30 |
trust_remote_code=True)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
| 33 |
# warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 34 |
|
|
|
|
| 1 |
# coding=utf-8
|
| 2 |
+
from typing import Dict
|
| 3 |
+
from typing import List
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
from typing import Union
|
| 6 |
+
from pathlib import Path
|
| 7 |
from src.logger import LoggerFactory
|
| 8 |
from src.prompt_concat import GetManualTestSamples, CreateTestDataset
|
| 9 |
from src.utils import decode_csv_to_json, load_json, save_to_json
|
|
|
|
| 28 |
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
| 29 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 30 |
|
| 31 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
|
| 32 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
| 33 |
+
|
| 34 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
|
| 35 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map="auto",
|
| 36 |
trust_remote_code=True)
|
| 37 |
|
| 38 |
+
character_path = "./character"
|
| 39 |
+
|
| 40 |
+
def _resolve_path(path: Union[str, Path]) -> Path:
|
| 41 |
+
return Path(path).expanduser().resolve()
|
| 42 |
+
|
| 43 |
# logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
| 44 |
# warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 45 |
|