File size: 6,681 Bytes
6524169
40a90bb
6524169
 
 
 
 
40a90bb
6524169
 
 
 
 
 
 
 
 
 
40a90bb
6524169
40a90bb
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81b01a7
6524169
 
 
40a90bb
24492a8
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81b01a7
6524169
 
81b01a7
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a90bb
 
 
 
 
6524169
 
 
 
 
 
 
 
 
 
40a90bb
6524169
 
 
 
 
 
40a90bb
 
 
 
 
 
 
 
 
 
 
81b01a7
40a90bb
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a90bb
 
6524169
 
 
 
40a90bb
6524169
 
40a90bb
6524169
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""
Intent extraction API on Modal (Starlette ASGI app + QLoRA model loaded once per container).

Prerequisites:
    pip install modal
    pip install -r modal_apps/requirements-modal.txt
    modal setup
    modal run modal_apps/train_modal.py --dataset train_intent.jsonl

Deploy (persistent HTTPS endpoint):
    modal deploy modal_apps/predict_api.py

The deploy command prints the public URL, e.g.:
    https://<workspace>--android-skill-predict-api-skillpredictor-web.modal.run

Test:
    curl -X POST https://<workspace>--android-skill-predict-api-skillpredictor-web.modal.run/predict \\
      -H "Content-Type: application/json" \\
      -d '{"prompt": "text mom on whatsapp i am on my way"}'

    # -> {"skill":"whatsapp_send_message","parameters":{"contact":"mom","message":"i am on my way"}}

Develop locally (ephemeral URL, hot-reloads on file changes):
    modal serve modal_apps/predict_api.py

Stop a deployed app:
    modal app stop android-skill-predict-api
"""

from __future__ import annotations

import pathlib

import modal

# Modal application object; SkillPredictor is registered on it via @app.cls.
app = modal.App("android-skill-predict-api")

# ---------------------------------------------------------------------------
# Configuration (same volume + model paths as train_modal.py)
# ---------------------------------------------------------------------------

# Base model identifier passed to FastLanguageModel.from_pretrained.
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
# Parent of the directory that contains this file; its "src" subdirectory is
# copied into the image.
PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
# Path at which `model_volume` is mounted (see `volumes=` on SkillPredictor).
MODEL_DIR = pathlib.Path("/model")
# Directory that must contain adapter_config.json for the API to start.
ADAPTER_DIR = MODEL_DIR / "adapter"
# Passed as `max_seq_length` when the base model is loaded.
MAX_SEQ_LENGTH = 2048
# Passed as `max_new_tokens` to every generate() call.
MAX_NEW_TOKENS = 128

# Passed as `gpu=` to @app.cls.
GPU_TYPE = "A10G"
# Passed as `timeout=` to @app.cls (10 minutes, expressed in seconds).
TIMEOUT_SECONDS = 10 * 60
# Passed as `scaledown_window=` to @app.cls (5 minutes, expressed in seconds).
# NOTE(review): presumably the idle time before a container is scaled down —
# confirm against the Modal documentation.
SCALEDOWN_WINDOW = 5 * 60

# Named volume mounted at /model; the LoRA adapter is read from it.
model_volume = modal.Volume.from_name(
    "android-dataset-model",
    create_if_missing=True,
)
# Named volume mounted at /model_cache, which HF_HOME points to in the image.
model_cache_volume = modal.Volume.from_name(
    "android-dataset-hf-cache",
    create_if_missing=True,
)

# Container image: Debian slim with Python 3.11, the packages listed in the
# requirements-modal.txt next to this file plus fastapi, and the project's
# "src" directory copied to /root/src (which PYTHONPATH points to so the
# modules in it can be imported by name).
api_image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install_from_requirements(
        str(pathlib.Path(__file__).parent / "requirements-modal.txt")
    )
    .pip_install("fastapi")
    .env(
        {
            "HF_HOME": "/model_cache",
            "HF_HUB_ENABLE_HF_TRANSFER": "1",
            "PYTHONPATH": "/root/src",
        }
    )
    .add_local_dir(str(PROJECT_ROOT / "src"), remote_path="/root/src", copy=True)
)

with api_image.imports():
    import unsloth  # noqa: F401 — must import before trl/transformers/peft

    import torch
    from peft import PeftModel
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template


# ---------------------------------------------------------------------------
# API (model loaded once per container via @modal.enter)
# ---------------------------------------------------------------------------


@app.cls(
    image=api_image,
    gpu=GPU_TYPE,
    timeout=TIMEOUT_SECONDS,
    scaledown_window=SCALEDOWN_WINDOW,
    volumes={"/model": model_volume, "/model_cache": model_cache_volume},
)
@modal.concurrent(max_inputs=1)
class SkillPredictor:
    """Serve intent predictions from the base model plus its LoRA adapter.

    The model and tokenizer are loaded in the ``@modal.enter`` hook and kept
    on the instance; the ASGI app returned by ``web`` exposes a single
    ``POST /predict`` route that uses them.
    """

    @modal.enter()
    def load_model(self) -> None:
        """Load the 4-bit base model, attach the adapter, and store both.

        Sets ``self.model`` and ``self.tokenizer``.

        Raises:
            FileNotFoundError: if ``adapter_config.json`` is missing from
                ``ADAPTER_DIR`` after the volume has been reloaded.
        """
        # Refresh the mounted volume before looking for the adapter files.
        model_volume.reload()

        adapter_config = ADAPTER_DIR / "adapter_config.json"
        if not adapter_config.exists():
            raise FileNotFoundError(
                f"LoRA adapter not found at {ADAPTER_DIR}. "
                "Run `modal run modal_apps/train_modal.py` first."
            )

        print(f"Loading base model: {MODEL_NAME}")
        base_model, base_tokenizer = FastLanguageModel.from_pretrained(
            model_name=MODEL_NAME,
            max_seq_length=MAX_SEQ_LENGTH,
            dtype=None,
            load_in_4bit=True,
        )

        print(f"Loading LoRA adapter from {ADAPTER_DIR}")
        adapted_model = PeftModel.from_pretrained(base_model, str(ADAPTER_DIR))
        chat_tokenizer = get_chat_template(base_tokenizer, chat_template="qwen-2.5")
        FastLanguageModel.for_inference(adapted_model)

        self.model, self.tokenizer = adapted_model, chat_tokenizer
        print("Model ready.")

    def _generate_intent(self, prompt: str) -> dict | None:
        """Run one greedy generation for *prompt* and parse the intent.

        Args:
            prompt: User text to classify.

        Returns:
            ``{"skill": ..., "parameters": {...}}`` on success, or ``None``
            when ``extract_intent`` yields nothing usable or
            ``resolve_skill`` yields no skill. A non-dict ``parameters``
            value is replaced by an empty dict.
        """
        # Imported here because these modules live in the project's src
        # directory, which is only on the path inside the container image.
        from classifier_prompt import build_intent_messages
        from skill_utils import extract_intent, resolve_skill

        chat = build_intent_messages(prompt)
        prompt_ids = self.tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        prompt_ids = prompt_ids.to("cuda")
        prompt_length = prompt_ids.shape[1]

        with torch.inference_mode():
            sequences = self.model.generate(
                input_ids=prompt_ids,
                max_new_tokens=MAX_NEW_TOKENS,
                use_cache=True,
                do_sample=False,
            )

        # generate() returns prompt + completion; keep only the new tokens.
        completion_ids = sequences[0][prompt_length:]
        decoded = self.tokenizer.decode(completion_ids, skip_special_tokens=True)

        intent = extract_intent(decoded.strip())
        if not intent:
            return None

        skill = resolve_skill(intent.get("skill"), prompt)
        if not skill:
            return None

        params = intent.get("parameters", {})
        return {
            "skill": skill,
            "parameters": params if isinstance(params, dict) else {},
        }

    @modal.asgi_app()
    def web(self):
        """Build the Starlette app exposing ``POST /predict``.

        The handler answers with the intent JSON on success and with a 422
        JSON error body (``error`` + ``message``) when the request body is
        not JSON, the ``prompt`` field is missing/blank/not a string, or the
        model output could not be turned into an intent.
        """
        from starlette.applications import Starlette
        from starlette.responses import JSONResponse
        from starlette.routing import Route

        def reject(error, message):
            # Every failure path shares the same status code and body shape.
            return JSONResponse(
                status_code=422,
                content={"error": error, "message": message},
            )

        async def predict(request):
            try:
                payload = await request.json()
            except Exception:
                return reject(
                    "invalid_request",
                    "Request body must be JSON with a 'prompt' field.",
                )

            candidate = payload.get("prompt") if isinstance(payload, dict) else None
            # Non-string prompts collapse to "" so one check covers both the
            # wrong-type and the blank-string cases.
            text = candidate.strip() if isinstance(candidate, str) else ""
            if not text:
                return reject("invalid_request", "Field 'prompt' is required.")

            result = self._generate_intent(text)
            if result is None:
                return reject(
                    "invalid_model_output",
                    "Model did not return valid intent JSON.",
                )
            return JSONResponse(content=result)

        return Starlette(routes=[Route("/predict", predict, methods=["POST"])])