File size: 2,278 Bytes
f3bb5cb
 
2be91bc
cda905d
876c070
8ecdc1c
2be91bc
 
 
 
 
 
f3bb5cb
 
2be91bc
f3bb5cb
2be91bc
 
 
 
 
f3bb5cb
cda905d
f3bb5cb
 
 
2be91bc
 
 
 
 
 
8ecdc1c
 
2be91bc
 
 
 
8ecdc1c
 
 
2be91bc
 
 
 
 
f3bb5cb
2be91bc
f3bb5cb
2be91bc
 
 
 
 
 
 
 
 
 
 
f3bb5cb
2be91bc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# models/loader.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from backend.agents import ROLE_PROMPTS
# Cap torch intra-op threads — presumably to limit CPU contention when several
# agent models run in one process; TODO confirm the intended budget.
torch.set_num_threads(2)
# The following configs are no longer used for CPU, but kept for future GPU use.
# NOTE(review): this is constructed at import time — confirm BitsAndBytesConfig
# instantiation is safe on CPU-only hosts without the bitsandbytes extra.
QUANTIZATION_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit weight quantization
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quant scheme
    bnb_4bit_compute_dtype=torch.bfloat16,  # bf16 compute dtype for GPU use
)

# Maps each agent role to the checkpoint it should use. Every role currently
# points at the same small Qwen3 model; unknown roles fall back to this same
# default via MODEL_REGISTRY.get(...) in generate_with_model().
MODEL_REGISTRY = {
    "ceo": "Qwen/Qwen3-0.6B",
    "planner": "Qwen/Qwen3-0.6B",
    "manager": "Qwen/Qwen3-0.6B",
    "debugger": "Qwen/Qwen3-0.6B",
    "business_analyst": "Qwen/Qwen3-0.6B",
    "ux_ui_designer": "Qwen/Qwen3-0.6B",
    "worker_backend_coder": "Qwen/Qwen3-0.6B",
    "worker_front_end_coder": "Qwen/Qwen3-0.6B",
    "worker_tester": "Qwen/Qwen3-0.6B",
    "code_analyst": "Qwen/Qwen3-0.6B"
}
# Process-wide cache keyed by model name; each entry is a dict with "model"
# and "tokenizer" keys so every checkpoint is loaded at most once.
_MODEL_CACHE = {}

def get_model_and_tokenizer(model_name="Qwen/Qwen3-0.6B"):
    """Return a ``(model, tokenizer)`` pair for *model_name*, loading lazily.

    Loaded pairs are memoized in the module-level ``_MODEL_CACHE`` so each
    checkpoint is fetched and instantiated at most once per process.
    """
    entry = _MODEL_CACHE.get(model_name)
    if entry is None:
        print(f"Loading model: {model_name}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=None,           # no automatic device placement
            quantization_config=None,  # full precision; CPU-only path
            trust_remote_code=True,
        )
        # Explicitly move the model to the CPU after loading
        model.to("cpu")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        entry = {"model": model, "tokenizer": tokenizer}
        _MODEL_CACHE[model_name] = entry
    return entry["model"], entry["tokenizer"]

def generate_with_model(agent_role, prompt):
    """Generate a text completion for *prompt* using the model assigned to *agent_role*.

    Args:
        agent_role: Key into MODEL_REGISTRY; unknown roles fall back to the
            default "Qwen/Qwen3-0.6B" checkpoint.
        prompt: The user prompt to answer.

    Returns:
        The newly generated text (prompt stripped), with surrounding
        whitespace removed.
    """
    model_name = MODEL_REGISTRY.get(agent_role, "Qwen/Qwen3-0.6B")
    model, tokenizer = get_model_and_tokenizer(model_name)

    full_prompt = f"You are a helpful assistant. {ROLE_PROMPTS.get(agent_role, '')}\n\nUser prompt: {prompt}"

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(model.device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )

    # Bug fix: previously the full sequence was decoded and the prompt removed
    # with str.replace(). Detokenization does not reliably round-trip the
    # prompt text byte-for-byte (whitespace/special-token normalization), so
    # the replace could silently fail and leak the entire prompt into the
    # returned response. Slicing off the prompt tokens and decoding only the
    # newly generated ones is exact.
    generated_tokens = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()