File size: 2,467 Bytes
0e0974d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from huggingface_hub import hf_hub_download
from llama_cpp import Llama


class EndpointHandler:
    """Hugging Face Inference Endpoints handler serving two GGUF Qwen models.

    Requests are routed by the ``target_model`` field: ``"test_case"`` goes to
    the Qwen3-4B instruct model and ``"test_script"`` goes to the
    Qwen2.5-Coder-7B instruct model. Both models are loaded once, at
    construction time.
    """

    def __init__(self, path=""):
        """Download both GGUF files from the Hub and load them with llama.cpp.

        Args:
            path: Repository path supplied by the Inference Endpoints runtime.
                Unused here, because the weights are fetched from the Hub
                explicitly by repo id and filename.
        """
        # 1. Download GGUF files from the Hub
        self.qwen_4b_path = hf_hub_download(
            repo_id="AtomP/NewQwenTestCase", filename="qwen3-4b-instruct-2507.Q8_0.gguf"
        )
        self.qwen_7b_path = hf_hub_download(
            repo_id="unsloth/Qwen2.5-Coder-7B-Instruct-128K-GGUF",
            filename="Qwen2.5-Coder-7B-Instruct-Q8_0.gguf",
        )

        # 2. Load models into GPU memory (n_gpu_layers=-1 offloads all layers to GPU)
        # n_ctx limits the context window to save VRAM. Increase this if your GPU has capacity.
        self.model_4b = Llama(
            model_path=self.qwen_4b_path, n_gpu_layers=-1, n_ctx=8192, verbose=False
        )

        self.model_7b = Llama(
            model_path=self.qwen_7b_path, n_gpu_layers=-1, n_ctx=8192, verbose=False
        )

    def __call__(self, data):
        """Run one chat completion on the model selected by ``target_model``.

        Args:
            data: Request body. The payload is read from ``data["inputs"]``
                when present, otherwise ``data`` itself is the payload.
                A dict payload carries ``messages`` plus optional
                ``target_model``, ``max_tokens``, ``temperature``,
                ``response_format``, ``repeat_penalty`` and ``stop``.
                A plain-string payload is also accepted. It becomes a single
                user message, and generation options are then read from
                ``data["parameters"]`` when that is a dict.

        Returns:
            The value returned by ``create_chat_completion`` on the selected
            model, or ``{"error": ...}`` when the payload type or
            ``target_model`` is invalid.
        """
        # 1. Hugging Face puts our JSON inside the "inputs" key; falling back to
        # the raw body keeps local testing (payload passed directly) working.
        payload = data.get("inputs", data) if isinstance(data, dict) else data

        # 2. Normalise the payload into (messages, options).
        if isinstance(payload, str):
            # Standard HF text shape: {"inputs": "<text>", "parameters": {...}}.
            params = data.get("parameters") if isinstance(data, dict) else None
            options = params if isinstance(params, dict) else {}
            messages = [{"role": "user", "content": payload}]
        elif isinstance(payload, dict):
            options = payload
            messages = payload.get("messages", [{"role": "user", "content": "Hello"}])
        else:
            # Previously this crashed with AttributeError; report it like the
            # other request errors instead.
            return {
                "error": f"Invalid payload type: {type(payload).__name__}. "
                "Expected an object or a string."
            }

        # 3. Route request
        target_model = options.get("target_model", "test_case")
        if target_model == "test_case":
            active_model = self.model_4b
        elif target_model == "test_script":
            active_model = self.model_7b
        else:
            return {
                "error": f"Invalid target_model: '{target_model}'. Use 'test_case' or 'test_script'."
            }

        # 4. Generate and return response
        return active_model.create_chat_completion(
            messages=messages,
            max_tokens=options.get("max_tokens", 512),
            temperature=options.get("temperature", 0.7),
            response_format=options.get("response_format"),
            repeat_penalty=options.get("repeat_penalty", 1.05),
            stop=options.get("stop", ["<|im_end|>"]),
        )