android-skill-router / modal_apps /predict_api.py
kriyanshi's picture
Ship v2 intent extraction with API, demo UI, eval, and benchmark suite.
40a90bb
Raw
History Blame Contribute Delete
6.68 kB
"""
Intent extraction API on Modal (Starlette ASGI app + QLoRA model loaded once per container).
Prerequisites:
pip install modal
pip install -r modal_apps/requirements-modal.txt
modal setup
modal run modal_apps/train_modal.py --dataset train_intent.jsonl
Deploy (persistent HTTPS endpoint):
modal deploy modal_apps/predict_api.py
The deploy command prints the public URL, e.g.:
https://<workspace>--android-skill-predict-api-skillpredictor-web.modal.run
Test:
curl -X POST https://<workspace>--android-skill-predict-api-skillpredictor-web.modal.run/predict \\
-H "Content-Type: application/json" \\
-d '{"prompt": "text mom on whatsapp i am on my way"}'
# -> {"skill":"whatsapp_send_message","parameters":{"contact":"mom","message":"i am on my way"}}
Develop locally (ephemeral URL, hot-reloads on file changes):
modal serve modal_apps/predict_api.py
Stop a deployed app:
modal app stop android-skill-predict-api
"""
from __future__ import annotations
import pathlib
import modal
# Modal app handle; the name is the one used by `modal app stop`.
app = modal.App("android-skill-predict-api")

# ---------------------------------------------------------------------------
# Configuration (same volume + model paths as train_modal.py)
# ---------------------------------------------------------------------------

# Base checkpoint that the LoRA adapter is attached to.
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"

# Repository root: two directory levels above this file.
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent

# Mount point of the model volume and the adapter directory inside it.
MODEL_DIR = pathlib.Path("/model")
ADAPTER_DIR = MODEL_DIR.joinpath("adapter")

# Passed to FastLanguageModel.from_pretrained(max_seq_length=...).
MAX_SEQ_LENGTH = 2048
# Passed to model.generate(max_new_tokens=...).
MAX_NEW_TOKENS = 128

# Container settings forwarded to @app.cls (gpu / timeout / scaledown_window).
GPU_TYPE = "A10G"
TIMEOUT_SECONDS = 600  # 10 minutes
SCALEDOWN_WINDOW = 300  # 5 minutes
# Persistent volume mounted at /model; holds the trained LoRA adapter.
model_volume = modal.Volume.from_name("android-dataset-model", create_if_missing=True)

# Persistent volume mounted at /model_cache, which HF_HOME points at.
model_cache_volume = modal.Volume.from_name("android-dataset-hf-cache", create_if_missing=True)
# Container image: Debian slim + pinned requirements + FastAPI, with the
# project's `src` directory copied in and placed on PYTHONPATH.
api_image = (
    modal.Image.debian_slim(python_version="3.11")
    # requirements-modal.txt lives next to this file.
    .pip_install_from_requirements(
        str(pathlib.Path(__file__).with_name("requirements-modal.txt"))
    )
    .pip_install("fastapi")
    .env(
        dict(
            HF_HOME="/model_cache",
            HF_HUB_ENABLE_HF_TRANSFER="1",
            PYTHONPATH="/root/src",
        )
    )
    .add_local_dir(
        str(PROJECT_ROOT.joinpath("src")),
        remote_path="/root/src",
        copy=True,
    )
)
# Heavy ML dependencies are imported only inside the `api_image.imports()`
# context.
with api_image.imports():
    # unsloth goes first: it must be imported before trl/transformers/peft.
    import unsloth  # noqa: F401 — must import before trl/transformers/peft
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template

    import torch
    from peft import PeftModel
# ---------------------------------------------------------------------------
# API (model loaded once per container via @modal.enter)
# ---------------------------------------------------------------------------
@app.cls(
    image=api_image,
    gpu=GPU_TYPE,
    timeout=TIMEOUT_SECONDS,
    scaledown_window=SCALEDOWN_WINDOW,
    volumes={
        "/model": model_volume,
        "/model_cache": model_cache_volume,
    },
)
@modal.concurrent(max_inputs=1)
class SkillPredictor:
    """Serve intent predictions from the base model plus LoRA adapter.

    The model and tokenizer are loaded once per container in ``load_model``
    and stored on the instance; ``web`` exposes a ``POST /predict`` endpoint
    that runs ``_generate_intent`` on the submitted prompt.
    """

    @modal.enter()
    def load_model(self) -> None:
        """Load the 4-bit base model, attach the LoRA adapter, and keep both.

        Raises:
            FileNotFoundError: if ``adapter_config.json`` is missing from
                ``ADAPTER_DIR`` after reloading the model volume.
        """
        # Reload the volume before checking for the adapter files.
        model_volume.reload()
        if not (ADAPTER_DIR / "adapter_config.json").exists():
            raise FileNotFoundError(
                f"LoRA adapter not found at {ADAPTER_DIR}. "
                "Run `modal run modal_apps/train_modal.py` first."
            )

        print(f"Loading base model: {MODEL_NAME}")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=MODEL_NAME,
            max_seq_length=MAX_SEQ_LENGTH,
            dtype=None,
            load_in_4bit=True,
        )

        print(f"Loading LoRA adapter from {ADAPTER_DIR}")
        model = PeftModel.from_pretrained(model, str(ADAPTER_DIR))
        tokenizer = get_chat_template(
            tokenizer,
            chat_template="qwen-2.5",
        )
        FastLanguageModel.for_inference(model)

        self.model = model
        self.tokenizer = tokenizer
        print("Model ready.")

    def _generate_intent(self, prompt: str) -> dict | None:
        """Run the model on ``prompt`` and return the resolved intent.

        Args:
            prompt: User utterance to classify.

        Returns:
            ``{"skill": ..., "parameters": {...}}`` on success, or ``None``
            when ``extract_intent`` yields nothing usable or ``resolve_skill``
            cannot resolve a skill.
        """
        # Project helpers live in /root/src (on PYTHONPATH inside the image).
        from classifier_prompt import build_intent_messages
        from skill_utils import extract_intent, resolve_skill

        messages = build_intent_messages(prompt)

        # Ask for a dict so the attention mask comes back alongside the ids;
        # passing it explicitly avoids relying on generate() to infer one.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to("cuda")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        with torch.inference_mode():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=MAX_NEW_TOKENS,
                use_cache=True,
                do_sample=False,
            )

        # Keep only the newly generated tokens (drop the echoed prompt).
        generated = outputs[0][input_ids.shape[1]:]
        raw_output = self.tokenizer.decode(generated, skip_special_tokens=True).strip()

        intent = extract_intent(raw_output)
        if not intent:
            return None

        skill = resolve_skill(intent.get("skill"), prompt)
        if not skill:
            return None

        # Model output is untrusted: fall back to an empty mapping when
        # "parameters" is missing or not a dict.
        parameters = intent.get("parameters", {})
        if not isinstance(parameters, dict):
            parameters = {}
        return {"skill": skill, "parameters": parameters}

    @modal.asgi_app()
    def web(self):
        """Build the Starlette app exposing ``POST /predict``.

        The endpoint expects a JSON object with a non-empty string
        ``prompt``. It answers with the intent dict, or with HTTP 422 and an
        ``{"error", "message"}`` body when the request is invalid or the
        model output could not be turned into an intent.
        """
        from starlette.applications import Starlette
        from starlette.responses import JSONResponse
        from starlette.routing import Route

        def error_response(error: str, message: str) -> JSONResponse:
            # Every failure path of this endpoint uses status 422.
            return JSONResponse(
                status_code=422,
                content={"error": error, "message": message},
            )

        async def predict(request):
            try:
                data = await request.json()
            except Exception:
                return error_response(
                    "invalid_request",
                    "Request body must be JSON with a 'prompt' field.",
                )

            prompt = data.get("prompt") if isinstance(data, dict) else None
            if not isinstance(prompt, str) or not prompt.strip():
                return error_response(
                    "invalid_request",
                    "Field 'prompt' is required.",
                )

            # NOTE(review): generation is a synchronous call inside an async
            # handler; the class is configured with max_inputs=1 — confirm
            # that blocking the event loop here is acceptable.
            intent = self._generate_intent(prompt.strip())
            if intent is None:
                return error_response(
                    "invalid_model_output",
                    "Model did not return valid intent JSON.",
                )
            return JSONResponse(content=intent)

        return Starlette(
            routes=[
                Route("/predict", predict, methods=["POST"]),
            ]
        )