# app.py
import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
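
# Dependencies (an assumed minimal set; the original file does not list them):
#   fastapi, uvicorn, pydantic, torch, accelerate (for device_map="auto"),
#   and a transformers release recent enough to support Gemma 3.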

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_ID = "google/gemma-3-4b-it"
# If the model requires authentication, set the HF_TOKEN environment variable
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# -----------------------------------------------------------------------------
# Device setup (the free Spaces tier is CPU-only)
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------------------------------------------------------
# Load the tokenizer and model
# -----------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,  # `use_auth_token` is deprecated; `token` is the current name
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.float32,  # float32 on CPU
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    # With device_map="auto" the model is already placed on the GPU by
    # accelerate, so only move it manually in the CPU case.
    model.to(device)

# -----------------------------------------------------------------------------
# FastAPI app definition
# -----------------------------------------------------------------------------
app = FastAPI(title="Gemma3-4B-IT API")

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_p: float = 0.95

@app.post("/generate")  # the decorator was missing; the path is assumed from the function name
async def generate(req: GenerationRequest):
    if not req.prompt:
        raise HTTPException(status_code=400, detail="prompt is required.")
    # Tokenize the prompt
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    # Generate
    generation_output = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
    return {"generated_text": text}
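
# Note: gemma-3-4b-it is instruction-tuned, so wrapping the raw prompt in the
# tokenizer's chat template usually produces better answers than plain text.
# A hedged sketch (apply_chat_template is part of the standard transformers
# tokenizer API; the `messages` structure here is illustrative):
#
#   messages = [{"role": "user", "content": req.prompt}]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt"
#   ).to(device)
#   generation_output = model.generate(input_ids=input_ids, ...)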

# -----------------------------------------------------------------------------
# Local entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")
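
For reference, a minimal client sketch for the endpoint above, assuming the
server is running locally on the default port 8000 and that the `requests`
package is installed (neither is specified by the original file):

# client.py
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Explain what a tokenizer does.", "max_new_tokens": 64},
    timeout=600,  # CPU generation on a 4B model can take minutes
)
resp.raise_for_status()
print(resp.json()["generated_text"])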