"""Test hard queries with V2 adapter.""" import torch, time from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct" ADAPTER_DIR = "./adapter-model" print("Loading...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True) model = PeftModel.from_pretrained(base_model, ADAPTER_DIR) model.eval() hard_queries = [ # JOINs e agregacoes "quantos contratos cada turma tem", "usuarios sem contrato", "top 10 devedores com mais parcelas vencidas", "receita total por mes nos ultimos 6 meses", "parcelas vencidas com nome do aluno e valor", "quantos participantes cada turma tem", "planos com mais contratos ativos", "boletos pendentes com nome do usuario", # Financial "rescisoes pendentes com nome e valor", "saldo total das wallets", "renegociacoes do mes atual", "transferencias pendentes", # Discovery (MUST use information_schema) "quais colunas tem a tabela eventos", "mostra o schema completo do banco", "foreign keys entre as tabelas", "nao conheco essa tabela", # Novel queries (not in dataset) "quantos alunos inaptos financeiramente", "cartoes de credito ativos por bandeira", "valor medio dos planos por categoria", "campanhas de desconto vigentes", ] for q in hard_queries: prompt = f"<|im_start|>system\nYou are a command adapter. Output ONLY valid JSON. No explanation.<|im_end|>\n<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant\n" inputs = tokenizer(prompt, return_tensors="pt").to(model.device) t0 = time.time() with torch.no_grad(): out = model.generate(**inputs, max_new_tokens=200, temperature=0.1, do_sample=True, pad_token_id=tokenizer.eos_token_id) elapsed = time.time() - t0 response = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip() # Truncate to first valid JSON if response.count('}') > 0: response = response[:response.index('}') + 1] print(f"[{elapsed:.1f}s] {q}") print(f" > {response}") print()