DanielRegaladoCardoso committed on
Commit
c2ac226
·
verified ·
1 Parent(s): f75f54b

ZeroGPU best practice: load models at module level (cuda), inference only inside @spaces.GPU

Browse files
Files changed (1) hide show
  1. src/models/sql_generator.py +15 -30
src/models/sql_generator.py CHANGED
@@ -1,16 +1,16 @@
1
  """
2
  SQL Generator: text-to-SQL via the trained Qwen2.5-Coder-7B LoRA.
3
 
4
- Loads the merged 16-bit checkpoint from Hugging Face Hub by default.
5
- The same `DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora` repo
6
- contains both the LoRA adapter and the merged model.
7
  """
8
 
9
  import logging
10
  import re
11
  from typing import Optional
12
 
13
- from src.models.base import BaseModel
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -21,11 +21,11 @@ SYSTEM_PROMPT = (
21
  "Return only the SQL."
22
  )
23
 
 
24
 
25
- class SQLGenerator(BaseModel):
26
- """Text-to-SQL generator using the trained Qwen2.5-Coder-7B model."""
27
 
28
- DEFAULT_MODEL = "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora"
 
29
 
30
  def __init__(
31
  self,
@@ -33,33 +33,23 @@ class SQLGenerator(BaseModel):
33
  temperature: float = 0.0,
34
  max_new_tokens: int = 400,
35
  ) -> None:
36
- super().__init__(model_name="sql-generator")
37
  self.hf_model = hf_model
38
  self.temperature = temperature
39
  self.max_new_tokens = max_new_tokens
40
 
41
- def load(self) -> None:
42
- from transformers import AutoModelForCausalLM, AutoTokenizer
43
- import torch
44
-
45
- logger.info(f"Loading SQL generator: {self.hf_model}")
46
- device = "cuda" if torch.cuda.is_available() else "cpu"
47
- dtype = torch.bfloat16 if device == "cuda" else torch.float32
48
-
49
  self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
 
 
50
  self.model = AutoModelForCausalLM.from_pretrained(
51
  self.hf_model,
52
- torch_dtype=dtype,
53
- device_map=device,
54
  )
55
  self.model.eval()
56
- self.is_loaded = True
57
- logger.info(f"SQL generator loaded on {device}")
58
-
59
- def generate(self, question: str, schema: str) -> str: # type: ignore[override]
60
- self._validate_loaded()
61
- import torch
62
 
 
63
  user_content = f"### Schema\n{schema}\n\n### Question\n{question}"
64
  messages = [
65
  {"role": "system", "content": SYSTEM_PROMPT},
@@ -81,18 +71,13 @@ class SQLGenerator(BaseModel):
81
  text = self.tokenizer.decode(
82
  out[0][input_ids.shape[1]:], skip_special_tokens=True
83
  )
84
- sql = self._clean_sql(text)
85
- logger.info(f"Generated SQL: {sql[:100]}...")
86
- return sql
87
 
88
  @staticmethod
89
  def _clean_sql(text: str) -> str:
90
- """Strip code fences, trailing prose, ensure a single SQL statement."""
91
  text = text.strip()
92
- # Strip fences ```sql ... ```
93
  text = re.sub(r"^```(?:sql)?\s*", "", text, flags=re.IGNORECASE)
94
  text = re.sub(r"\s*```\s*$", "", text)
95
- # Cut at the first ; if followed by prose
96
  if ";" in text:
97
  stmt, _, _ = text.partition(";")
98
  text = stmt + ";"
 
1
  """
2
  SQL Generator: text-to-SQL via the trained Qwen2.5-Coder-7B LoRA.
3
 
4
+ Loads at module import time (root level), as required by ZeroGPU best
5
+ practices. Inference happens inside @spaces.GPU in the orchestrator.
 
6
  """
7
 
8
  import logging
9
  import re
10
  from typing import Optional
11
 
12
+ import torch
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
21
  "Return only the SQL."
22
  )
23
 
24
+ DEFAULT_MODEL = "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora"
25
 
 
 
26
 
27
+ class SQLGenerator:
28
+ """Text-to-SQL generator. Model loaded at construction time onto CUDA."""
29
 
30
  def __init__(
31
  self,
 
33
  temperature: float = 0.0,
34
  max_new_tokens: int = 400,
35
  ) -> None:
 
36
  self.hf_model = hf_model
37
  self.temperature = temperature
38
  self.max_new_tokens = max_new_tokens
39
 
40
+ logger.info(f"Loading SQL generator at module level: {self.hf_model}")
 
 
 
 
 
 
 
41
  self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
42
+ # On ZeroGPU, device_map='cuda' uses emulation mode at module load and
43
+ # real GPU inside @spaces.GPU calls.
44
  self.model = AutoModelForCausalLM.from_pretrained(
45
  self.hf_model,
46
+ torch_dtype=torch.bfloat16,
47
+ device_map="cuda",
48
  )
49
  self.model.eval()
50
+ logger.info("SQL generator ready")
 
 
 
 
 
51
 
52
+ def generate(self, question: str, schema: str) -> str:
53
  user_content = f"### Schema\n{schema}\n\n### Question\n{question}"
54
  messages = [
55
  {"role": "system", "content": SYSTEM_PROMPT},
 
71
  text = self.tokenizer.decode(
72
  out[0][input_ids.shape[1]:], skip_special_tokens=True
73
  )
74
+ return self._clean_sql(text)
 
 
75
 
76
  @staticmethod
77
  def _clean_sql(text: str) -> str:
 
78
  text = text.strip()
 
79
  text = re.sub(r"^```(?:sql)?\s*", "", text, flags=re.IGNORECASE)
80
  text = re.sub(r"\s*```\s*$", "", text)
 
81
  if ";" in text:
82
  stmt, _, _ = text.partition(";")
83
  text = stmt + ";"