File size: 3,565 Bytes
f3bb5cb
 
dcd1081
 
 
 
 
 
cda905d
814e2eb
dcd1081
2be91bc
 
 
 
 
 
f3bb5cb
 
2be91bc
f3bb5cb
2be91bc
 
 
 
 
f3bb5cb
dcd1081
f3bb5cb
 
 
fa87e26
dcd1081
 
fa87e26
dcd1081
 
fa87e26
dcd1081
 
 
 
814e2eb
 
 
 
 
2be91bc
 
dcd1081
814e2eb
dcd1081
fa87e26
 
 
dcd1081
 
 
814e2eb
 
dcd1081
 
814e2eb
 
dcd1081
fa87e26
dcd1081
 
 
 
 
 
 
 
2be91bc
 
dcd1081
 
814e2eb
 
 
2be91bc
 
dcd1081
2be91bc
dcd1081
fa87e26
 
 
 
 
 
 
 
 
dcd1081
2be91bc
 
 
fa87e26
dcd1081
 
2be91bc
dcd1081
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# models/loader.py
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)
from backend.agents import ROLE_PROMPTS

# Optional quantization config (used only if GPU is available)
QUANTIZATION_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Every agent role currently shares the same small Qwen checkpoint; the
# mapping stays per-role so individual agents can be re-pointed later.
_DEFAULT_CHECKPOINT = "Qwen/Qwen3-0.6B"
MODEL_REGISTRY = {
    role: _DEFAULT_CHECKPOINT
    for role in (
        "ceo",
        "planner",
        "manager",
        "debugger",
        "business_analyst",
        "ux_ui_designer",
        "worker_backend_coder",
        "worker_front_end_coder",
        "worker_tester",
        "code_analyst",
    )
}

# model_name -> {"model": ..., "tokenizer": ...}; filled lazily by the loader.
_MODEL_CACHE = {}

# Explicit generation config (avoids model-specific overrides)
GENERATION_CONFIG = GenerationConfig(
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
)


def get_model_and_tokenizer(model_name):
    """
    Load a causal-LM model and its tokenizer from the Hugging Face Hub.

    Results are cached in ``_MODEL_CACHE`` keyed by ``model_name``, so
    repeated calls (e.g. several agent roles sharing one checkpoint) do
    not reload the weights.

    Args:
        model_name: Hub repo id, e.g. ``"Qwen/Qwen3-0.6B"``.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    if model_name not in _MODEL_CACHE:
        print(f"Loading model: {model_name}...")

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Ensure a dedicated pad token exists (not EOS). Track how many
        # tokens were actually added so embeddings are resized only when
        # the vocabulary really grew.
        num_added = 0
        if tokenizer.pad_token is None or tokenizer.pad_token == tokenizer.eos_token:
            num_added = tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

        # Load model with GPU/CPU awareness; 4-bit quantization only on GPU.
        use_gpu = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto" if use_gpu else None,
            quantization_config=QUANTIZATION_CONFIG if use_gpu else None,
            trust_remote_code=True,
        )

        # Resize only when tokens were added: an unconditional resize is
        # needless work and a risky operation on 4-bit quantized weights.
        if num_added > 0:
            model.resize_token_embeddings(len(tokenizer))

        # Keep the model config in sync with the tokenizer so generate()
        # calls that omit pad_token_id do not fall back to EOS.
        model.config.pad_token_id = tokenizer.pad_token_id

        # Explicitly move to CPU if no GPU
        if not use_gpu:
            model.to("cpu")

        _MODEL_CACHE[model_name] = {"model": model, "tokenizer": tokenizer}

    return _MODEL_CACHE[model_name]["model"], _MODEL_CACHE[model_name]["tokenizer"]


def generate_with_model(agent_role, prompt, generation_config: GenerationConfig = GENERATION_CONFIG):
    """
    Generate text for *prompt* with the model assigned to *agent_role*.

    Args:
        agent_role: Key into ``MODEL_REGISTRY`` / ``ROLE_PROMPTS``; unknown
            roles fall back to the default Qwen checkpoint and an empty
            role prompt.
        prompt: The user's request text.
        generation_config: Sampling parameters; defaults to the module-level
            ``GENERATION_CONFIG``.

    Returns:
        The decoded completion with prompt tokens stripped and surrounding
        whitespace trimmed.
    """
    model, tokenizer = get_model_and_tokenizer(
        MODEL_REGISTRY.get(agent_role, "Qwen/Qwen3-0.6B")
    )

    full_prompt = f"You are a helpful assistant. {ROLE_PROMPTS.get(agent_role, '')}\n\nUser prompt: {prompt}"

    # Tokenize via tokenizer(...) so we get an attention mask alongside ids.
    encoded = tokenizer(
        full_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    prompt_ids = encoded["input_ids"].to(model.device)
    mask = encoded["attention_mask"].to(model.device)

    with torch.no_grad():
        sequences = model.generate(
            prompt_ids,
            attention_mask=mask,  # padding positions are ignored
            generation_config=generation_config,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated portion to avoid echoing the prompt.
    completion_ids = sequences[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()