Fix model loading for repo ID and improve GPU detection
Browse files- app.py +18 -1
- package/ai.py +8 -0
app.py
CHANGED
|
@@ -142,7 +142,24 @@ app = FastAPI(
|
|
| 142 |
)
|
| 143 |
|
| 144 |
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
@app.get("/")
|
| 147 |
def root() -> Dict[str, str]:
|
| 148 |
"""簡易案内"""
|
|
|
|
| 142 |
)
|
| 143 |
|
| 144 |
|
| 145 |
+
# ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義
|
| 146 |
+
@spaces.GPU
|
| 147 |
+
def _gpu_init_function():
|
| 148 |
+
"""GPU初期化用のダミー関数(Space起動時に検出される)"""
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@app.on_event("startup")
|
| 153 |
+
async def startup_event():
|
| 154 |
+
"""アプリ起動時の処理(GPU要求を確実に検出させる)"""
|
| 155 |
+
if SPACES_AVAILABLE:
|
| 156 |
+
try:
|
| 157 |
+
_gpu_init_function()
|
| 158 |
+
print("[SPACE] GPU要求をstartup eventで送信しました")
|
| 159 |
+
except Exception as e:
|
| 160 |
+
print(f"[SPACE] GPU要求エラー: {e}")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
@app.get("/")
|
| 164 |
def root() -> Dict[str, str]:
|
| 165 |
"""簡易案内"""
|
package/ai.py
CHANGED
|
@@ -51,6 +51,14 @@ class AI:
|
|
| 51 |
try:
|
| 52 |
if not model_path:
|
| 53 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
# transformersを使用してモデルをロード
|
| 56 |
try:
|
|
|
|
| 51 |
try:
|
| 52 |
if not model_path:
|
| 53 |
return None
|
| 54 |
+
|
| 55 |
+
# モデルパスがリポジトリID("user/repo"形式)か、ローカルパスかを判定
|
| 56 |
+
is_repo_id = "/" in model_path and not os.path.exists(model_path)
|
| 57 |
+
|
| 58 |
+
# リポジトリIDの場合は os.path.exists() チェックをスキップ
|
| 59 |
+
if not is_repo_id and not os.path.exists(model_path):
|
| 60 |
+
print(f"[AI] モデルパスが存在しません: {model_path}")
|
| 61 |
+
return None
|
| 62 |
|
| 63 |
# transformersを使用してモデルをロード
|
| 64 |
try:
|