Upload 8 files
Browse files
- Dockerfile +9 -24
- Dockerfile copy +44 -0
- Dockerfile3 +25 -0
- app.py +147 -0
- hfdoc +43 -0
- main.py +160 -0
- memo +34 -0
- model_loader.py +181 -0
Dockerfile
CHANGED
@@ -1,7 +1,6 @@
-# ARG CUDA_IMAGE="12.1.0-devel-ubuntu22.04"
-# FROM nvidia/cuda:${CUDA_IMAGE}
 ARG CUDA_IMAGE="12.5.0-devel-ubuntu22.04"
 FROM nvidia/cuda:${CUDA_IMAGE}
+
 # We need to set the host to 0.0.0.0 to allow outside access
 ENV HOST 0.0.0.0

@@ -9,37 +8,23 @@ RUN apt-get update && apt-get upgrade -y \
     && apt-get install -y git build-essential \
     python3 python3-pip gcc wget \
     ocl-icd-opencl-dev opencl-headers clinfo \
-    libclblast-dev libopenblas-dev
+    libclblast-dev libopenblas-dev \
     && mkdir -p /etc/OpenCL/vendors && echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd

-RUN useradd -m -u 1000 gee
-USER gee
-ENV HOME=/home/gee \
-    PATH=/home/gee/.local/bin:$PATH
-
-ENV HF_HOME=$HOME/app/.cache/huggingface
-WORKDIR $HOME/app
-
-COPY --chown=gee . $HOME/app
+COPY . .
+
+RUN nvidia-smi

 # setting build related env vars
 ENV CUDA_DOCKER_ARCH=all
 ENV GGML_CUDA=1

 # Install dependencies
 RUN python3 -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette pydantic-settings starlette-context

 # Install llama-cpp-python (build with cuda)
-RUN pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu125
-
-# CMD python3 -m llama_cpp.server
-
-CMD ["python3","llm.py"]
+RUN CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
+# --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121
+
+# Run the server
+CMD python3 -m llama_cpp.server
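With this change the image builds llama-cpp-python with CUDA support (GGML_CUDA=1) and starts llama_cpp.server, which exposes an OpenAI-compatible HTTP API (port 8000 by default). A minimal client sketch in Python, assuming the container is running with that port published on localhost; the model name is a placeholder, since the server answers for whichever model it loaded:

import requests

# Assumed endpoint: llama_cpp.server serves an OpenAI-compatible
# /v1/chat/completions route on port 8000 unless configured otherwise.
url = "http://localhost:8000/v1/chat/completions"

payload = {
    "model": "local-model",  # placeholder; the server uses its loaded model
    "messages": [{"role": "user", "content": "Hello! Who are you?"}],
    "max_tokens": 64,
    "temperature": 0.7,
}

resp = requests.post(url, json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])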
Dockerfile copy
ADDED
@@ -0,0 +1,44 @@
# Dockerfile
FROM python:3.11-slim

ENV PYTHONUNBUFFERED=1
ENV GRADIO_SERVER_NAME="0.0.0.0"


ARG MODEL_ID="Qwen/Qwen3-8B"
ENV MODEL_ID=${MODEL_ID}

# Quantization defaults (can be overridden via the Space's env settings)
ENV LOAD_IN_4BIT="true"
ENV LOAD_IN_8BIT="false"

# Install dependencies
COPY requirements.txt requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the code
COPY main.py .
COPY model_loader.py .
COPY schemas.py .

# A user with proper write permissions has to be set up, otherwise errors seem to occur.
RUN useradd -m -u 1000 gee
USER gee
ENV HOME=/home/gee \
    PATH=/home/gee/.local/bin:$PATH

ENV HF_HOME=$HOME/app/.cache/huggingface
WORKDIR $HOME/app


COPY --chown=gee . $HOME/app

# Expose the port Uvicorn listens on.
# On Hugging Face Spaces the default is usually 7860, but 8000 is fine for an API server.
# Keep this in sync with app_port in README.md.
EXPOSE 8000

# Application startup command

# CMD uvicorn main:app --host 0.0.0.0 --port ${PORT:-8000} --workers 1
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
Dockerfile3
ADDED
@@ -0,0 +1,25 @@
# Base image: CUDA 11.8
FROM docker.io/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04@sha256:8f9dd0d09d3ad3900357a1cf7f887888b5b74056636cd6ef03c160c3cd4b1d95

# Install basic tools such as Python and pip
RUN apt-get update && apt-get install -y \
    python3 \
    python3-pip \
    git \
    # sometimes needed to build llama-cpp-python (e.g. when no wheel is available)
    # build-essential cmake \
    && rm -rf /var/lib/apt/lists/*
# (Recommended) create a directory for model downloads and set its permissions
RUN mkdir /models && chmod 777 /models
VOLUME /models
RUN pip install llama-cpp-python[server] \
    --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu118
RUN pip install huggingface-hub

ENV MODEL_ID="unsloth/Qwen3-30B-A3B-GGUF"

COPY . .
RUN ln -s /usr/bin/python3 /usr/bin/python

# python -m llama_cpp.server --model /path/to/your/model.gguf --host 0.0.0.0 --port 8000
CMD ["python", "-m", "llama_cpp.server", "--hf_model_repo_id", "unsloth/Qwen3-30B-A3B-GGUF", "--model", "*Q4_0.gguf", "--host", "0.0.0.0", "--port", "8000"]
app.py
ADDED
@@ -0,0 +1,147 @@
# app.py
import gradio as gr
import subprocess
import os
import time  # time is not used directly, but the commented-out parts may need it

# --- JGLUE evaluation task runner ---
def run_jglue_evaluation_task(model_name, task_name, other_param, max_samples, progress=gr.Progress()):
    log_output = ""
    # Path to the evaluation script (assumed to sit in the same directory as app.py)
    script_path = "jglue_script.py"

    try:
        # Build the command list
        command = [
            "python", script_path,
            "--model_name_or_path", str(model_name),
            "--task_name", str(task_name),  # taken from the UI dropdown
            # "--dataset_path", str(dataset_path),  # may be unnecessary if the JGLUE script loads from the Hub
            "--output_dir", "./evaluation_results",  # results directory inside the Space
            "--eval_batch_size", "8",  # example; could be exposed in the UI
            "--max_seq_length", "128",  # example; could be exposed in the UI
            # "--other_param_for_script", str(other_param),  # match whatever argument name the script expects
        ]
        if max_samples is not None and int(max_samples) > 0:
            command.extend(["--max_eval_samples", str(int(max_samples))])

        log_output += f"Command: {' '.join(command)}\n\n"
        # progress(0, desc="Preparing evaluation script...")
        yield log_output  # show the initial log immediately

        # Start the subprocess
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout
            text=True,
            bufsize=1,  # line buffered
            universal_newlines=True,
            encoding='utf-8'  # set the encoding explicitly
        )

        # Read the log in real time and update the progress bar
        line_count = 0
        # max_expected_lines = 200  # only a rough guide; a more dynamic progress update would be better

        # Passing '' as the sentinel to iter() keeps reading until the process exits
        for line in iter(process.stdout.readline, ''):
            if not line:  # guard against consecutive empty reads (unlikely)
                # optionally check whether the process is still alive
                # if process.poll() is not None: break
                continue
            print(line, end='', flush=True)  # also stream to the Docker logs in real time
            log_output += line
            line_count += 1
            # Progress update (here simply once per line; something more meaningful would be better,
            # e.g. have the script print markers like "PROGRESS: 25%" and parse them)
            progress(min(0.01 * line_count, 0.95), desc=f"Running evaluation... (log line {line_count})")  # cap at 0.95; set 1.0 on completion
            yield log_output  # stream so the UI updates in real time

        process.stdout.close()
        return_code = process.wait()  # wait for the process to finish

        if return_code == 0:
            log_output += "\n\nEvaluation finished successfully."
            progress(1.0, desc="Evaluation complete!")
        else:
            log_output += f"\n\nThe evaluation script exited with error code {return_code}."
            progress(1.0, desc="Evaluation error")

        yield log_output  # send the final log

    except FileNotFoundError:
        log_output += f"\n\nError: evaluation script '{script_path}' was not found."
        progress(1.0, desc="Script error")
        yield log_output
    except Exception as e:
        log_output += f"\n\nAn unexpected error occurred: {e}"
        progress(1.0, desc="Fatal error")
        import traceback
        log_output += "\n\n--- Traceback ---\n" + traceback.format_exc()
        yield log_output


# --- Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# JGLUE Evaluation Platform")
    gr.Markdown("Runs the JGLUE evaluation script inside the Docker container.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Evaluation settings")
            model_name_input = gr.Textbox(
                label="Model name or path to evaluate",
                placeholder="e.g. cl-tohoku/bert-base-japanese-whole-word-masking",
                value="cl-tohoku/bert-base-japanese-whole-word-masking"
            )
            task_name_input = gr.Dropdown(
                label="JGLUE task name",
                choices=["marc_ja", "jsts", "jnli", "jcommonsense_qa"],
                value="marc_ja"
            )
            other_param_input = gr.Slider(
                label="Other parameter (interpreted by the script)",
                minimum=1, maximum=10, value=5, step=1
            )
            max_samples_input = gr.Number(
                label="Max number of evaluation samples (0 or empty for all)",
                value=100,
                minimum=0, step=10, precision=0
            )
            submit_button = gr.Button("Start evaluation", variant="primary", icon="▶️")

        with gr.Column(scale=2):
            gr.Markdown("### Logs and results")
            # gr.Progress is not passed through inputs; the tracker is injected via the
            # progress=gr.Progress() default argument of the function above.
            progress_component = gr.Progress()
            output_log = gr.Textbox(
                label="Log output",
                lines=20,
                interactive=False,
                max_lines=200,
                show_copy_button=True
            )

    # On click, run the evaluation function and stream its output
    submit_button.click(
        fn=run_jglue_evaluation_task,
        # progress_component is deliberately not listed in inputs
        inputs=[model_name_input, task_name_input, other_param_input, max_samples_input],
        outputs=[output_log]
    )

# --- Launch the Gradio app ---
if __name__ == "__main__":
    print("DEBUG: app.py - Inside __main__ block. Attempting to launch Gradio app...", flush=True)
    try:
        # .queue() makes it easier to handle multiple requests and long-running tasks
        demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
        # debug=True produces more verbose Gradio/Uvicorn logs (handy during development)
    except Exception as e:
        print(f"DEBUG: app.py - Error during demo.launch(): {e}", flush=True)
        import traceback
        traceback.print_exc()
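The per-line progress update in run_jglue_evaluation_task is only a rough heuristic. A minimal sketch of the marker-parsing idea mentioned in its comments, assuming the evaluation script were to print lines such as "PROGRESS: 25%" (that marker format is an assumption; jglue_script.py is not part of this upload):

import re

# Hypothetical marker format: the script would print e.g. "PROGRESS: 25%".
PROGRESS_RE = re.compile(r"PROGRESS:\s*(\d{1,3})%")

def parse_progress(line: str):
    """Return a fraction in [0.0, 1.0] if the line carries a progress marker, else None."""
    match = PROGRESS_RE.search(line)
    if match is None:
        return None
    return min(int(match.group(1)), 100) / 100.0

Inside the log-reading loop, a non-None result could then drive progress(frac, desc=...) instead of the line counter.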
hfdoc
ADDED
@@ -0,0 +1,43 @@
# ARG CUDA_IMAGE="12.1.0-devel-ubuntu22.04"
# FROM nvidia/cuda:${CUDA_IMAGE}
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04

# We need to set the host to 0.0.0.0 to allow outside access
ENV HOST 0.0.0.0

RUN apt-get update && apt-get upgrade -y \
    && apt-get install -y git build-essential \
    python3 python3-pip gcc wget \
    ocl-icd-opencl-dev opencl-headers clinfo \
    libclblast-dev libopenblas-dev \
    && mkdir -p /etc/OpenCL/vendors && echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd


# A user with proper write permissions has to be set up, otherwise errors seem to occur.
RUN useradd -m -u 1000 gee
USER gee
ENV HOME=/home/gee \
    PATH=/home/gee/.local/bin:$PATH

ENV HF_HOME=$HOME/app/.cache/huggingface
WORKDIR $HOME/app


COPY --chown=gee . $HOME/app


# setting build related env vars
ENV CUDA_DOCKER_ARCH=all
ENV GGML_CUDA=1

# Install dependencies
RUN python3 -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette pydantic-settings starlette-context huggingface_hub hf_xet

# Install llama-cpp-python (build with cuda)
RUN pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121

# # Run the server
# CMD python3 -m llama_cpp.server

# python -m llama_cpp.server --model /path/to/your/model.gguf --host 0.0.0.0 --port 8000
CMD ["python3", "-W", "ignore", "-m", "llama_cpp.server", "--hf_model_repo_id", "unsloth/Qwen3-30B-A3B-GGUF", "--model", "*Q4_0.gguf", "--host", "0.0.0.0", "--port", "8000"]
main.py
ADDED
@@ -0,0 +1,160 @@
# main.py
from fastapi import FastAPI, HTTPException, Request as FastAPIRequest
from fastapi.responses import JSONResponse
import uvicorn
import os
import uuid  # for id generation (could be moved into schemas)
import time

# Imports from local modules
from model_loader import (
    load_model,
    generate_text,
    MODEL_ID as LOADED_MODEL_ID,
)  # also import MODEL_ID
import model_loader
from schemas import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionMessage,
    ChatCompletionResponseMessage,
    ChatCompletionChoice,
    # Usage,
)


app = FastAPI(
    title="OpenAI Compatible LLM API",
    description=f"Provides an OpenAI-compatible API endpoint for the model: {os.environ.get('MODEL_ID', 'default_model_id_from_env')}",
    version="0.1.0",
)


# --- Event handlers ---
@app.on_event("startup")
async def startup_event():
    """
    Load the model when the application starts.
    """
    print("Application startup: Loading model...")
    try:
        load_model()  # call into model_loader.py
        print(f"Model '{LOADED_MODEL_ID}' should be loaded now.")
    except RuntimeError as e:
        print(f"Fatal Error during application startup: {e}")
        # Either abort the app here or report the failure via the health-check endpoint.
        # Note that Uvicorn itself may still start successfully.


# --- API endpoints ---
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """
    OpenAI-compatible chat completion endpoint.
    """
    print(
        f"Received request for model: {request.model} (Actual model: {LOADED_MODEL_ID})"
    )
    print(f"Messages: {request.messages}")

    if (
        model_loader.model is None or model_loader.tokenizer is None
    ):  # check the globals in model_loader
        raise HTTPException(
            status_code=503,
            detail=f"Model '{LOADED_MODEL_ID}' is not available. Check server logs.",
        )

    # Use the last user message as the prompt (handling richer conversation history is TODO).
    # OpenAI's messages field is a list, so whether to take the last user message or join
    # the whole history into one prompt depends on the format the model expects.
    # Here we take the simple route of using the content of the last user message.
    user_prompt = ""
    if request.messages and request.messages[-1].role == "user":
        user_prompt = request.messages[-1].content
    elif (
        request.messages
    ):  # even if the last message is not from the user, grab some text
        user_prompt = "\n".join(
            [msg.content for msg in request.messages if msg.content]
        )

    if not user_prompt:
        raise HTTPException(status_code=400, detail="No user prompt found in messages.")

    try:
        # Call the inference function in model_loader.py
        generated_content = generate_text(
            prompt=user_prompt,
            max_new_tokens=request.max_tokens
            if request.max_tokens is not None
            else 1024,  # HF uses max_new_tokens
            temperature=request.temperature if request.temperature is not None else 0.7,
            top_p=request.top_p if request.top_p is not None else 0.9,
            # can be extended to pass repetition_penalty and other parameters
        )

        # Build an OpenAI-compatible response
        response_message = ChatCompletionResponseMessage(
            role="assistant", content=generated_content
        )
        choice = ChatCompletionChoice(
            index=0, message=response_message, finish_reason="stop"
        )
        # usage is a dummy (accurate token counts would need separate computation)
        # usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)

        return ChatCompletionResponse(
            id="chatcmpl-" + uuid.uuid4().hex,  # generate a unique id
            object="chat.completion",
            created=int(time.time()),
            model=LOADED_MODEL_ID,  # the model that was actually used
            choices=[choice],
            # usage=usage,
        )

    except RuntimeError as e:
        print(f"RuntimeError during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    except Exception as e:
        print(f"Unexpected error during generation: {e}")
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")


@app.get("/health")
async def health_check():
    """
    Health-check endpoint. Returns OK if the model is loaded.
    """
    if model_loader.model is not None and model_loader.tokenizer is not None:
        return {"status": "ok", "model_loaded": LOADED_MODEL_ID}
    else:
        return JSONResponse(
            status_code=503,
            content={
                "status": "error",
                "message": f"Model {LOADED_MODEL_ID} not loaded or failed to load.",
            },
        )


@app.get("/")
async def root():
    return {
        "message": f"OpenAI Compatible API for model: {LOADED_MODEL_ID}. Use POST /v1/chat/completions."
    }


# --- Running with Uvicorn (for local testing; overridden by the Dockerfile CMD) ---
if __name__ == "__main__":
    # Read the model ID from the environment (set it when testing locally)
    # e.g. export MODEL_ID="google/gemma-2b-it"
    #      python main.py
    port = int(os.environ.get("PORT", 8000))  # keep in sync with EXPOSE/CMD in the Dockerfile
    print(
        f"Starting Uvicorn server on port {port} for model '{os.environ.get('MODEL_ID', 'default_model_id_from_env')}'"
    )
    uvicorn.run(app, host="0.0.0.0", port=port)
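main.py imports its request/response models from schemas.py, which the Dockerfile copies (COPY schemas.py .) but which is not part of this upload. A minimal sketch of what those Pydantic models could look like, inferred only from how main.py uses them (field optionality and defaults are assumptions):

# schemas.py (hypothetical sketch; the real file is not included here)
from typing import List, Optional
from pydantic import BaseModel


class ChatCompletionMessage(BaseModel):
    role: str  # "system", "user" or "assistant"
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatCompletionMessage]
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None


class ChatCompletionResponseMessage(BaseModel):
    role: str
    content: str


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatCompletionResponseMessage
    finish_reason: Optional[str] = None


class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[ChatCompletionChoice]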
memo
ADDED
@@ -0,0 +1,34 @@
GGUF: comes in many quantizations (Q4, Q5, ...)
GPTQ: mostly 4-bit quantization
AWQ: mainly 4-bit

The 32B model is available as GGUF in both forms:
https://huggingface.co/BlackBeenie/Qwen3-32B-Q4_K_M-GGUF
https://huggingface.co/kaitchup/Qwen3-32B-autoround-4bit-gptq
https://huggingface.co/BenevolenceMessiah/Qwen3-32B-Q8_0-GGUF

https://huggingface.co/charlesthefool/Qwen3-30B-A3B-Q4_K_M-GGUF
https://huggingface.co/BenevolenceMessiah/Qwen3-30B-A3B-Q8_0-GGUF


https://huggingface.co/mlx-community/Qwen3-32B-8bit
https://huggingface.co/unsloth/Qwen3-32B-bnb-4bit

https://huggingface.co/mlx-community/Qwen3-30B-A3B-8bit/blob/main/config.json
https://huggingface.co/unsloth/Qwen3-30B-A3B-bnb-4bit/tree/main

To run a serious model with CUDA, the Docker base image has to be chosen carefully.

Writing the Dockerfile

curl -X POST "https://yheye43-Eval-Qwen3-30B-A3B-GPTQ-Int4.hf.space/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "stabilityai/japanese-stablelm-instruct-gamma-7b",
    "messages": [
      {"role": "user", "content": "日本の首都はどこですか? /nothink"}
    ],
    "max_tokens": 50,
    "temperature": 0.7
  }'
{"id":"chatcmpl-6b73cd9660694171aa1064b33a14e8d9","object":"chat.completion","created":1747841643,"model":"Qwen/Qwen3-8B","choices":[{"index":0,"message":{"role":"ass
model_loader.py
ADDED
@@ -0,0 +1,181 @@
# model_loader.py
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os

# --- Globals (populated when the application starts) ---
model = None
tokenizer = None
MODEL_ID = os.environ.get(
    "MODEL_ID", "Qwen/Qwen3-30B-A3B"
)  # take the model ID from the environment, with a default fallback
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

LOAD_IN_4BIT = os.environ.get("LOAD_IN_4BIT", "false").lower() == "true"
LOAD_IN_8BIT = os.environ.get("LOAD_IN_8BIT", "false").lower() == "true"

# Prevent 4-bit and 8-bit from being enabled at the same time (one of them, or neither)
if LOAD_IN_4BIT and LOAD_IN_8BIT:
    print(
        "Warning: Both LOAD_IN_4BIT and LOAD_IN_8BIT are set to true. Prioritizing 4-bit."
    )
    LOAD_IN_8BIT = False
elif not LOAD_IN_4BIT and not LOAD_IN_8BIT:
    print(
        "Info: No explicit quantization (4-bit/8-bit) requested via environment variables. Loading in default precision (e.g., bfloat16 on GPU)."
    )


def load_model():
    """
    Load the model and tokenizer when the application starts.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        quantization_info = "No Quantization"
        if LOAD_IN_4BIT:
            quantization_info = "4-bit Quantization"
        elif LOAD_IN_8BIT:
            quantization_info = "8-bit Quantization"

        print(
            f"Loading model: {MODEL_ID} on device: {DEVICE} with {quantization_info}..."
        )
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            model_kwargs = {
                "trust_remote_code": True
            }  # usually kept True
            quantization_config = None
            if DEVICE == "cuda":
                model_kwargs["device_map"] = "auto"
                if LOAD_IN_4BIT:
                    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
                    model_kwargs["torch_dtype"] = "auto"  # compute dtype used alongside 4-bit
                    # finer-grained bitsandbytes settings (e.g. bnb_4bit_compute_dtype) could also be exposed via env vars
                elif LOAD_IN_8BIT:
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                    # with 8-bit, torch_dtype is usually chosen automatically, but it can be set explicitly
                else:  # GPU without quantization
                    model_kwargs["torch_dtype"] = torch.bfloat16

                # model = AutoModelForCausalLM.from_pretrained(
                #     MODEL_ID,
                #     torch_dtype=torch.bfloat16,  # or torch.float16
                #     load_in_4bit=True,  # load with 4-bit quantization (requires bitsandbytes)
                #     # load_in_8bit=True,  # for 8-bit quantization
                #     device_map="auto",  # assign to GPUs automatically
                #     trust_remote_code=True,  # required by some models
                # )
            else:  # CPU (quantization is really meant for GPU, but handle it anyway)
                # bitsandbytes quantization on CPU is limited or not recommended
                if LOAD_IN_4BIT or LOAD_IN_8BIT:
                    print(
                        "Warning: bitsandbytes quantization (4-bit/8-bit) is primarily for GPU. Attempting on CPU may be slow or unstable."
                    )
                # model_kwargs["device_map"] = {"": "cpu"}  # pin to CPU explicitly
                pass  # handled by .to(DEVICE) below

            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID, **model_kwargs, quantization_config=quantization_config
            )

            if DEVICE == "cpu" and not (
                LOAD_IN_4BIT or LOAD_IN_8BIT
            ):  # CPU without quantization
                model = model.to(DEVICE)

            model.eval()  # evaluation mode
            print(f"Model {MODEL_ID} loaded successfully.")
        except Exception as e:
            print(f"Error loading model {MODEL_ID}: {e}")
            # On failure, model and tokenizer stay None;
            # the application's health check can use that to report the problem.
            raise RuntimeError(f"Failed to load model: {e}")


def generate_text(
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.3,
    top_p: float = 0.9,
    repetition_penalty: float = 1.0,
) -> str:
    """
    Generate text with the loaded model.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded. Cannot generate text.")

    try:
        # The prompt format needs to be adjusted per model;
        # instruct models in particular often have a specific template.
        # Here only the user prompt is used.
        # inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

        # Apply a generic chat-style prompt (adjust per model)
        # Example: the StableLM Instruct Gamma prompt format
        # See: https://huggingface.co/stabilityai/japanese-stablelm-instruct-gamma-7b
        messages = [{"role": "user", "content": prompt}]
        # Many models support tokenizer.apply_chat_template
        try:
            prompt_formatted = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                # Thinking mode can be toggled here
                # enable_thinking=False,
            )
        except Exception:
            # Manual formatting for older or unusual models without apply_chat_template;
            # check the model's documentation for the correct format.
            print(
                f"Warning: tokenizer.apply_chat_template failed for {MODEL_ID}. Using raw prompt or basic formatting."
            )
            if (
                "stablelm-instruct" in MODEL_ID.lower() or "elyza" in MODEL_ID.lower()
            ):  # ELYZA / StableLM style
                prompt_formatted = f"ユーザー: {prompt}\nシステム: "
            elif (
                "qwen" in MODEL_ID.lower() and "chat" in MODEL_ID.lower()
            ):  # Qwen-Chat style
                prompt_formatted = (
                    f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
                )
            else:  # fall back to the raw prompt
                prompt_formatted = prompt

        inputs = tokenizer(
            prompt_formatted, return_tensors="pt", add_special_tokens=False
        ).to(DEVICE)  # whether to add special tokens depends on the template

        # Text generation
        # pad_token_id is commonly set to eos_token_id (suppresses a warning)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True
            if temperature > 0
            else False,  # sample only when temperature > 0
            "pad_token_id": tokenizer.pad_token_id,
        }

        outputs = model.generate(**inputs, **generation_kwargs)

        # Decode only the newly generated part (strip the input prompt)
        # inputs.input_ids.shape[1] is the input length in tokens
        output_text = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )
        return output_text.strip()

    except Exception as e:
        print(f"Error during text generation: {e}")
        # traceback.print_exc()  # detailed error output
        raise RuntimeError(f"Text generation failed: {e}")
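A quick local smoke test for model_loader.py, assuming a machine with enough memory; the small model ID below is only an example and not part of this commit:

# smoke_test.py (hypothetical): set MODEL_ID before importing model_loader,
# because model_loader reads the environment variable at import time.
import os

os.environ.setdefault("MODEL_ID", "Qwen/Qwen3-0.6B")  # example value, override as needed

import model_loader

model_loader.load_model()
reply = model_loader.generate_text("Say hello in one short sentence.", max_new_tokens=32)
print(reply)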