import logging
import os

import modal
from fastapi import Header

from models import MODEL_IDS

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

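# Model weights are cached on a persistent Modal volume mounted at this path.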
CACHE_DIR = "/cache"

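# Container image: slim Debian with the PyTorch/Transformers inference stack;
# the local "site" directory is shipped into the container as the static frontend.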
image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install("torch", "transformers", "accelerate", "fastapi", "bitsandbytes")
    .add_local_dir("site", "/root")
)

app = modal.App("posttraining-chat", image=image)
cache_vol = modal.Volume.from_name("hf-cache", create_if_missing=True)


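# GPU-backed inference service: scales to zero after 60s of idle time, reads
# the API key from a .env-backed Modal secret, and mounts the HF cache volume.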
@app.cls(
    gpu="T4",
    scaledown_window=60,
    secrets=[modal.Secret.from_dotenv()],
    volumes={CACHE_DIR: cache_vol},
)
class Inference:
    @modal.enter()
    def setup(self):
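        # Point Hugging Face at the mounted volume so downloads persist across
        # container restarts; models themselves are loaded lazily per request.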
        os.environ["HF_HOME"] = CACHE_DIR
        self.models = {}

    def load_model(self, model_id: str):
        if model_id in self.models:
            logger.info(f"Model already loaded: {model_id}")
            return

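        # Heavy imports are deferred so they only run inside the GPU container.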
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info(f"Loading model: {model_id}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            logger.info(f"Tokenizer loaded for {model_id}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            logger.info(f"Model loaded successfully: {model_id}")
            self.models[model_id] = {"model": model, "tokenizer": tokenizer}
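            # Persist any freshly downloaded weights to the shared volume.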
            cache_vol.commit()
        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            raise

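    # POST endpoint. Expects a JSON body with "model_id", "message", and
    # "history"; the x-api-key header must match MODEL_SITE_API_KEY. Errors
    # are returned as JSON {"error": ...} rather than HTTP status codes.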
    @modal.fastapi_endpoint(method="POST")
    def generate(self, request: dict, x_api_key: str | None = Header(None)) -> dict:
        import torch

        logger.info(
            f"Received request: model_id={request.get('model_id')}, "
            f"message_len={len(request.get('message', ''))}, "
            f"history_len={len(request.get('history', []))}, "
            f"message={request.get('message', '')[:80]}..."
        )

        expected_key = os.environ.get("MODEL_SITE_API_KEY")
        if not expected_key or x_api_key != expected_key:
            logger.warning("Auth failed: invalid or missing API key")
            return {"error": "Unauthorized - invalid API key"}

        model_id = request.get("model_id", MODEL_IDS[0])
        message = request.get("message", "")
        history = request.get("history", [])

        if model_id not in MODEL_IDS:
            logger.warning(f"Model not found: {model_id}")
            return {"error": f"Model {model_id} not found"}

        try:
            self.load_model(model_id)
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            return {"error": f"Failed to load model: {e}"}

        tokenizer = self.models[model_id]["tokenizer"]
        model = self.models[model_id]["model"]

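        # Rebuild the chat transcript in the role/content format expected by
        # the tokenizer's chat template.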
        messages = []
        for msg in history:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})

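        # Render the transcript with the model's chat template, appending the
        # generation prompt so the model continues as the assistant.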
        conversation = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        try:
            inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
            logger.info(f"Tokenized input shape: {inputs['input_ids'].shape}")

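            # Conservative sampling: low temperature and nucleus cutoff keep
            # answers focused; repetition_penalty discourages degenerate loops.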
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    do_sample=True,
                    temperature=0.4,
                    top_p=0.85,
                    repetition_penalty=1.15,
                    pad_token_id=tokenizer.eos_token_id,
                )
            logger.info(f"Generated output shape: {outputs.shape}")

            # Extract only the newly generated tokens (skip the input)
            new_tokens = outputs[0][inputs["input_ids"].shape[1] :]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            logger.info(f"Final response length: {len(response)}")
            logger.info(f"Response: {response[:200]}...")

            return {"response": response}
        except Exception as e:
            logger.error(f"Inference failed: {e}", exc_info=True)
            return {"error": f"Inference failed: {e}"}
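
# Example usage (sketch; the URL below is illustrative -- Modal prints the
# real endpoint URL on `modal deploy`):
#
#   curl -X POST "https://<workspace>--posttraining-chat-inference-generate.modal.run" \
#     -H "Content-Type: application/json" \
#     -H "x-api-key: $MODEL_SITE_API_KEY" \
#     -d '{"model_id": "...", "message": "Hello!", "history": []}'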