# Source: Hugging Face Hub — Keeby-smilyai, "Update models/loader.py" (commit 2be91bc, 2.09 kB)
# models/loader.py
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Shared bitsandbytes setup: every agent model is loaded in 4-bit NF4
# quantization with bfloat16 compute to keep memory usage low.
QUANTIZATION_CONFIG = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)
# Use a registry to map agent roles to specific models
MODEL_REGISTRY = {
"ceo": "Qwen/Qwen3-0.6B",
"planner": "Qwen/Qwen3-0.6B",
"manager": "Qwen/Qwen3-0.6B",
"debugger": "Qwen/Qwen3-0.6B",
"business_analyst": "Qwen/Qwen3-0.6B",
"ux_ui_designer": "Qwen/Qwen3-0.6B",
"worker_backend_coder": "Qwen/Qwen3-0.6B",
"worker_front_end_coder": "Qwen/Qwen3-0.6B",
"worker_tester": "Qwen/Qwen3-0.6B",
}
_MODEL_CACHE = {}
def get_model_and_tokenizer(model_name="Qwen/Qwen3-0.6B"):
    """Return the (model, tokenizer) pair for *model_name*, loading it once.

    The pair is memoized in the module-level ``_MODEL_CACHE`` so each
    checkpoint is downloaded/quantized at most once per process.

    Parameters
    ----------
    model_name : str
        Hugging Face model identifier (defaults to the shared Qwen model).

    Returns
    -------
    tuple
        ``(model, tokenizer)`` — a 4-bit quantized causal-LM and its tokenizer.
    """
    if model_name not in _MODEL_CACHE:
        print(f"Loading model: {model_name}...")
        _MODEL_CACHE[model_name] = {
            "model": AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                quantization_config=QUANTIZATION_CONFIG,
                trust_remote_code=True,
            ),
            # Fix: the tokenizer must also pass trust_remote_code=True —
            # the model load already does, and checkpoints that ship custom
            # tokenizer code would otherwise fail to load.
            "tokenizer": AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True
            ),
        }
    entry = _MODEL_CACHE[model_name]
    return entry["model"], entry["tokenizer"]
def generate_with_model(agent_role, prompt):
    """Generate a text completion for *agent_role* using its registered model.

    Parameters
    ----------
    agent_role : str
        Key into ``MODEL_REGISTRY``; unknown roles fall back to the default
        Qwen checkpoint.
    prompt : str
        The user prompt appended to the role's system preamble.

    Returns
    -------
    str
        Only the newly generated text (prompt tokens are excluded).
    """
    model_name = MODEL_REGISTRY.get(agent_role, "Qwen/Qwen3-0.6B")
    model, tokenizer = get_model_and_tokenizer(model_name)
    # Fix: ROLE_PROMPTS is not defined anywhere in this module, so the
    # original line raised NameError at runtime. Guard the lookup so a
    # missing registry degrades to an empty preamble; define or import
    # ROLE_PROMPTS to restore per-role system prompts.
    role_prompt = globals().get("ROLE_PROMPTS", {}).get(agent_role, "")
    full_prompt = f"You are a helpful assistant. {role_prompt}\n\nUser prompt: {prompt}"
    # Use the tokenizer __call__ so an attention_mask is produced alongside
    # input_ids; encode() alone omits it and generate() may warn/misbehave
    # when the pad and eos tokens coincide.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )
    # Fix: slice off the prompt *tokens* rather than str.replace() on the
    # decoded text — decoding can normalize whitespace, so an exact textual
    # match of the prompt is not guaranteed and the replace silently fails.
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = output[0][prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()