DanielRegaladoCardoso committed on
Commit
730b25d
·
verified ·
1 Parent(s): 7a057d8

Load LoRA via PeftModel on top of standard base models (fixes r=16 vs r=8 mismatch)

Browse files
Files changed (1) hide show
  1. src/models/sql_generator.py +17 -20
src/models/sql_generator.py CHANGED
@@ -1,8 +1,6 @@
1
  """
2
- SQL Generator: text-to-SQL via the trained Qwen2.5-Coder-7B LoRA.
3
-
4
- Loads at module import time (root level), as required by ZeroGPU best
5
- practices. Inference happens inside @spaces.GPU in the orchestrator.
6
  """
7
 
8
  import logging
@@ -11,6 +9,7 @@ from typing import Optional
11
 
12
  import torch
13
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -21,33 +20,31 @@ SYSTEM_PROMPT = (
21
  "Return only the SQL."
22
  )
23
 
24
- DEFAULT_MODEL = "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora"
 
25
 
26
 
27
  class SQLGenerator:
28
- """Text-to-SQL generator. Model loaded at construction time onto CUDA."""
29
 
30
- def __init__(
31
- self,
32
- hf_model: str = DEFAULT_MODEL,
33
- temperature: float = 0.0,
34
- max_new_tokens: int = 400,
35
- ) -> None:
36
- self.hf_model = hf_model
37
  self.temperature = temperature
38
  self.max_new_tokens = max_new_tokens
39
 
40
- logger.info(f"Loading SQL generator at module level: {self.hf_model}")
41
- self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
42
- # On ZeroGPU, device_map='cuda' uses emulation mode at module load and
43
- # real GPU inside @spaces.GPU calls.
44
- self.model = AutoModelForCausalLM.from_pretrained(
45
- self.hf_model,
46
  torch_dtype=torch.bfloat16,
47
  device_map="cuda",
48
  )
 
 
 
 
 
 
49
  self.model.eval()
50
- logger.info("SQL generator ready")
51
 
52
  def generate(self, question: str, schema: str) -> str:
53
  user_content = f"### Schema\n{schema}\n\n### Question\n{question}"
 
1
  """
2
+ SQL Generator: load the trained LoRA adapter on top of the standard Qwen
3
+ 2.5 Coder 7B base. Loaded at module level per ZeroGPU best practice.
 
 
4
  """
5
 
6
  import logging
 
9
 
10
  import torch
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from peft import PeftModel
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
20
  "Return only the SQL."
21
  )
22
 
23
+ BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"
24
+ ADAPTER_REPO = "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora"
25
 
26
 
27
  class SQLGenerator:
 
28
 
29
+ def __init__(self, temperature: float = 0.0, max_new_tokens: int = 400) -> None:
 
 
 
 
 
 
30
  self.temperature = temperature
31
  self.max_new_tokens = max_new_tokens
32
 
33
+ logger.info(f"Loading SQL base: {BASE_MODEL}")
34
+ self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
35
+ base = AutoModelForCausalLM.from_pretrained(
36
+ BASE_MODEL,
 
 
37
  torch_dtype=torch.bfloat16,
38
  device_map="cuda",
39
  )
40
+ logger.info(f"Applying LoRA adapter: {ADAPTER_REPO}")
41
+ self.model = PeftModel.from_pretrained(
42
+ base,
43
+ ADAPTER_REPO,
44
+ torch_dtype=torch.bfloat16,
45
+ )
46
  self.model.eval()
47
+ logger.info("SQL generator ready (LoRA applied on Qwen base)")
48
 
49
  def generate(self, question: str, schema: str) -> str:
50
  user_content = f"### Schema\n{schema}\n\n### Question\n{question}"