# posttraining-practice / backend.py
import logging
import os
import modal
from fastapi import Header
from models import MODEL_IDS
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
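# Hugging Face model weights are cached here; the path is backed by a persistent Modal Volume.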
CACHE_DIR = "/cache"
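# Container image: Debian slim with the PyTorch/Transformers inference stack, plus the local `site/` directory copied into /root.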
image = (
modal.Image.debian_slim(python_version="3.12")
.pip_install("torch", "transformers", "accelerate", "fastapi", "bitsandbytes")
.add_local_dir("site", "/root")
)
app = modal.App("posttraining-chat", image=image)
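# Named Volume so downloaded weights persist across container restarts.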
cache_vol = modal.Volume.from_name("hf-cache", create_if_missing=True)
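# Inference service: one T4 GPU per container, scaled down after 60s idle; the dotenv secret supplies MODEL_SITE_API_KEY.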
@app.cls(
gpu="T4",
scaledown_window=60,
secrets=[modal.Secret.from_dotenv()],
volumes={CACHE_DIR: cache_vol},
)
class Inference:
@modal.enter()
def setup(self):
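        # Point the Hugging Face cache at the mounted Volume; models are loaded lazily per request.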
os.environ["HF_HOME"] = CACHE_DIR
self.models = {}
def load_model(self, model_id: str):
if model_id in self.models:
logger.info(f"Model already loaded: {model_id}")
return
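        # Heavy ML imports are deferred so they only run inside the GPU container.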
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
logger.info(f"Loading model: {model_id}")
try:
tokenizer = AutoTokenizer.from_pretrained(model_id)
logger.info(f"Tokenizer loaded for {model_id}")
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
)
logger.info(f"Model loaded successfully: {model_id}")
self.models[model_id] = {"model": model, "tokenizer": tokenizer}
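            # Persist any newly downloaded weights to the shared Volume for future containers.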
cache_vol.commit()
except Exception as e:
logger.error(f"Failed to load model {model_id}: {e}")
raise
@modal.fastapi_endpoint(method="POST")
def generate(self, request: dict, x_api_key: str | None = Header(None)) -> dict:
import torch
        logger.info(
            f"Received request: model_id={request.get('model_id')}, "
            f"message_len={len(request.get('message', ''))}, "
            f"history_len={len(request.get('history', []))}, "
            f"message_preview={request.get('message', '')[:80]}..."
        )
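        # Shared-secret auth: callers must send MODEL_SITE_API_KEY in the X-Api-Key header.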
expected_key = os.environ.get("MODEL_SITE_API_KEY")
if not expected_key or x_api_key != expected_key:
logger.warning("Auth failed: invalid or missing API key")
return {"error": "Unauthorized - invalid API key"}
model_id = request.get("model_id", MODEL_IDS[0])
message = request.get("message", "")
history = request.get("history", [])
if model_id not in MODEL_IDS:
logger.warning(f"Model not found: {model_id}")
return {"error": f"Model {model_id} not found"}
try:
self.load_model(model_id)
except Exception as e:
logger.error(f"Model loading failed: {e}")
return {"error": f"Failed to load model: {e}"}
tokenizer = self.models[model_id]["tokenizer"]
model = self.models[model_id]["model"]
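        # Rebuild the prior conversation as role/content messages and append the new user turn.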
messages = []
for msg in history:
role = msg.get("role", "user")
content = msg.get("content", "")
messages.append({"role": role, "content": content})
messages.append({"role": "user", "content": message})
conversation = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
try:
inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
logger.info(f"Tokenized input shape: {inputs['input_ids'].shape}")
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=1024,
do_sample=True,
temperature=0.4,
top_p=0.85,
repetition_penalty=1.15,
pad_token_id=tokenizer.eos_token_id,
)
logger.info(f"Generated output shape: {outputs.shape}")
# Extract only the newly generated tokens (skip the input)
new_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
logger.info(f"Final response length: {len(response)}")
logger.info(f"Response: {response}")
return {"response": response}
except Exception as e:
logger.error(f"Inference failed: {e}", exc_info=True)
return {"error": f"Inference failed: {e}"}