| """Test hard queries with V2 adapter.""" |
| import torch, time |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel |
|
|
| MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct" |
| ADAPTER_DIR = "./adapter-model" |
|
|
| print("Loading...") |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) |
| base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True) |
| model = PeftModel.from_pretrained(base_model, ADAPTER_DIR) |
| model.eval() |
|
|
| hard_queries = [ |
| |
| "quantos contratos cada turma tem", |
| "usuarios sem contrato", |
| "top 10 devedores com mais parcelas vencidas", |
| "receita total por mes nos ultimos 6 meses", |
| "parcelas vencidas com nome do aluno e valor", |
| "quantos participantes cada turma tem", |
| "planos com mais contratos ativos", |
| "boletos pendentes com nome do usuario", |
| |
| "rescisoes pendentes com nome e valor", |
| "saldo total das wallets", |
| "renegociacoes do mes atual", |
| "transferencias pendentes", |
| |
| "quais colunas tem a tabela eventos", |
| "mostra o schema completo do banco", |
| "foreign keys entre as tabelas", |
| "nao conheco essa tabela", |
| |
| "quantos alunos inaptos financeiramente", |
| "cartoes de credito ativos por bandeira", |
| "valor medio dos planos por categoria", |
| "campanhas de desconto vigentes", |
| ] |
|
|
| for q in hard_queries: |
| prompt = f"<|im_start|>system\nYou are a command adapter. Output ONLY valid JSON. No explanation.<|im_end|>\n<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant\n" |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) |
| t0 = time.time() |
| with torch.no_grad(): |
| out = model.generate(**inputs, max_new_tokens=200, temperature=0.1, do_sample=True, pad_token_id=tokenizer.eos_token_id) |
| elapsed = time.time() - t0 |
| response = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip() |
| |
| if response.count('}') > 0: |
| response = response[:response.index('}') + 1] |
| print(f"[{elapsed:.1f}s] {q}") |
| print(f" > {response}") |
| print() |
|
|