File size: 1,170 Bytes
388fd6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#!/usr/bin/env bash
#
# Bootstrap a local Python environment for this checkout.
#
# Creates a virtualenv at "$ROOT/.venv" (ROOT = absolute path of the directory
# containing this script), installs PyTorch for the requested flavor, installs
# code/Taotern_SSM and code/TaoTrain in editable mode, then prints the torch
# version and whether CUDA is visible.
#
# Usage: <this script> [FLAVOR]
#   FLAVOR  default  plain `pip install torch` (used when FLAVOR is omitted)
#           cpu, cu121, cu124, cu126, cu128
#                    wheels from https://download.pytorch.org/whl/<FLAVOR>
#
# Exits 2 on an unsupported FLAVOR; any other failing command aborts the
# script (set -e).
set -euo pipefail

ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TORCH_FLAVOR="${1:-default}"
VENV_PY="$ROOT/.venv/bin/python"

# Resolve the flavor to a wheel index BEFORE touching the filesystem or the
# network, so a typo fails immediately instead of after creating the venv and
# upgrading pip. An empty TORCH_INDEX_URL means "use pip's default index".
TORCH_INDEX_URL=""
case "$TORCH_FLAVOR" in
  default)
    ;;
  cpu|cu121|cu124|cu126|cu128)
    TORCH_INDEX_URL="https://download.pytorch.org/whl/$TORCH_FLAVOR"
    ;;
  *)
    echo "Unsupported torch flavor: $TORCH_FLAVOR" >&2
    echo "Use one of: default cpu cu121 cu124 cu126 cu128" >&2
    exit 2
    ;;
esac

python3 -m venv "$ROOT/.venv"
"$VENV_PY" -m pip install --upgrade pip setuptools wheel

# String + if/else instead of an optional-args array: expanding an empty array
# under `set -u` is an "unbound variable" error on bash < 4.4.
if [[ -n "$TORCH_INDEX_URL" ]]; then
  "$VENV_PY" -m pip install torch --index-url "$TORCH_INDEX_URL"
else
  "$VENV_PY" -m pip install torch
fi

"$VENV_PY" -m pip install -e "$ROOT/code/Taotern_SSM"
"$VENV_PY" -m pip install -e "$ROOT/code/TaoTrain"

echo
echo "Setup complete."
# Sanity check: import torch inside the venv and report CUDA visibility.
"$VENV_PY" - <<'PY'
import torch
print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("cuda device:", torch.cuda.get_device_name(0))
PY
echo
echo "Run fixed chat with:"
echo "  ./run_chat_fixed.sh"