import logging
import os

import modal
from fastapi import Header

from models import MODEL_IDS

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

CACHE_DIR = "/cache"

image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install("torch", "transformers", "accelerate", "fastapi", "bitsandbytes")
    .add_local_dir("site", "/root")
)

app = modal.App("posttraining-chat", image=image)
cache_vol = modal.Volume.from_name("hf-cache", create_if_missing=True)

@app.cls(
    gpu="T4",
    scaledown_window=60,
    secrets=[modal.Secret.from_dotenv()],
    volumes={CACHE_DIR: cache_vol},
)
class Inference:
    @modal.enter()
    def setup(self):
        os.environ["HF_HOME"] = CACHE_DIR
        self.models = {}

    def load_model(self, model_id: str):
        if model_id in self.models:
            logger.info(f"Model already loaded: {model_id}")
            return

        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info(f"Loading model: {model_id}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            logger.info(f"Tokenizer loaded for {model_id}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            logger.info(f"Model loaded successfully: {model_id}")
            self.models[model_id] = {"model": model, "tokenizer": tokenizer}
            # Persist any newly downloaded weights to the shared HF cache volume.
            cache_vol.commit()
        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            raise

    @modal.fastapi_endpoint(method="POST")
    def generate(self, request: dict, x_api_key: str | None = Header(None)) -> dict:
        import torch

        logger.info(
            f"Received request: model_id={request.get('model_id')}, "
            f"message_len={len(request.get('message', ''))}, "
            f"history_len={len(request.get('history', []))}, "
            f"message: {request.get('message', '')}"
        )

        # Simple shared-secret auth: reject requests without the expected API key.
        expected_key = os.environ.get("MODEL_SITE_API_KEY")
        if not expected_key or x_api_key != expected_key:
            logger.warning("Auth failed: invalid or missing API key")
            return {"error": "Unauthorized - invalid API key"}

        model_id = request.get("model_id", MODEL_IDS[0])
        message = request.get("message", "")
        history = request.get("history", [])

        if model_id not in MODEL_IDS:
            logger.warning(f"Model not found: {model_id}")
            return {"error": f"Model {model_id} not found"}

        try:
            self.load_model(model_id)
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return {"error": f"Failed to load model: {e}"}

        tokenizer = self.models[model_id]["tokenizer"]
        model = self.models[model_id]["model"]

        # Rebuild the conversation in the chat-template format expected by the tokenizer.
        messages = []
        for msg in history:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})

        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        try:
            inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
            logger.info(f"Tokenized input shape: {inputs['input_ids'].shape}")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    do_sample=True,
                    temperature=0.4,
                    top_p=0.85,
                    repetition_penalty=1.15,
                    pad_token_id=tokenizer.eos_token_id,
                )
            logger.info(f"Generated output shape: {outputs.shape}")

            # Extract only the newly generated tokens (skip the input)
            new_tokens = outputs[0][inputs["input_ids"].shape[1] :]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            logger.info(f"Final response length: {len(response)}")
            logger.info(f"Response: {response}")
            return {"response": response}
        except Exception as e:
            logger.error(f"Inference failed: {e}", exc_info=True)
            return {"error": f"Inference failed: {e}"}