| | |
| | import torch |
| | import os |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| |
|
| | |
# Shared 4-bit quantization settings applied to every model loaded below.
# NF4 quant type with bfloat16 compute keeps memory low while doing the
# de-quantized matmuls in bf16.
QUANTIZATION_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
| |
|
| | |
# Maps each agent role to the Hugging Face checkpoint it should use.
# Every role currently points at the same Qwen3-0.6B checkpoint, so in
# practice only one model is loaded and shared via the cache below —
# but the indirection lets individual roles be upgraded independently.
MODEL_REGISTRY = {
    "ceo": "Qwen/Qwen3-0.6B",
    "planner": "Qwen/Qwen3-0.6B",
    "manager": "Qwen/Qwen3-0.6B",
    "debugger": "Qwen/Qwen3-0.6B",
    "business_analyst": "Qwen/Qwen3-0.6B",
    "ux_ui_designer": "Qwen/Qwen3-0.6B",
    "worker_backend_coder": "Qwen/Qwen3-0.6B",
    "worker_front_end_coder": "Qwen/Qwen3-0.6B",
    "worker_tester": "Qwen/Qwen3-0.6B",
}
# model_name -> {"model": ..., "tokenizer": ...}; populated lazily by
# get_model_and_tokenizer so each checkpoint is loaded at most once.
_MODEL_CACHE = {}
| |
|
def get_model_and_tokenizer(model_name="Qwen/Qwen3-0.6B"):
    """Return the (model, tokenizer) pair for *model_name*, loading lazily.

    The first call for a given checkpoint loads the 4-bit quantized model
    and its tokenizer and memoizes them in the module-level _MODEL_CACHE;
    every subsequent call returns the cached pair.
    """
    entry = _MODEL_CACHE.get(model_name)
    if entry is None:
        print(f"Loading model: {model_name}...")
        # NOTE(review): trust_remote_code=True executes code shipped with
        # the checkpoint repo — acceptable only for trusted model sources.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=QUANTIZATION_CONFIG,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        entry = {"model": model, "tokenizer": tokenizer}
        _MODEL_CACHE[model_name] = entry
    return entry["model"], entry["tokenizer"]
| |
|
def generate_with_model(agent_role, prompt):
    """Generate a text completion for *agent_role*'s registered model.

    Parameters
    ----------
    agent_role : str
        Key into MODEL_REGISTRY (and ROLE_PROMPTS); unknown roles fall
        back to the default Qwen3-0.6B checkpoint and an empty role prompt.
    prompt : str
        The user-supplied prompt text.

    Returns
    -------
    str
        Only the newly generated text (prompt stripped), whitespace-trimmed.
    """
    model_name = MODEL_REGISTRY.get(agent_role, "Qwen/Qwen3-0.6B")
    model, tokenizer = get_model_and_tokenizer(model_name)

    # NOTE(review): ROLE_PROMPTS is not defined in this chunk — confirm it
    # exists elsewhere in the module, otherwise this line raises NameError.
    full_prompt = f"You are a helpful assistant. {ROLE_PROMPTS.get(agent_role, '')}\n\nUser prompt: {prompt}"

    input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(model.device)

    # Sampling-based decode; no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )

    # Fix: strip the prompt by slicing off its token ids rather than
    # string-replacing the decoded prompt. Detokenization is not guaranteed
    # to reproduce the prompt byte-for-byte, so the old
    # `decoded.replace(full_prompt, "", 1)` could silently fail and leak
    # prompt text into the returned completion. Slicing by token count is
    # exact: generate() returns the prompt ids followed by new tokens.
    generated_ids = output[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()