Priyansh Saxena committed on
Commit
cb6c215
·
1 Parent(s): f3fd40f

feat: download Qwen2.5-Coder-0.5B + BART at build, add few-shot prompts

Browse files
Files changed (3) hide show
  1. Dockerfile +16 -3
  2. app.py +6 -6
  3. llm_agent.py +64 -52
Dockerfile CHANGED
@@ -2,15 +2,28 @@ FROM python:3.10-slim
2
 
3
  WORKDIR /app
4
 
5
- COPY . .
6
 
7
  RUN pip install --no-cache-dir --upgrade pip && \
8
  pip install --no-cache-dir -r requirements.txt
9
 
10
- RUN mkdir -p /app/data/uploads /app/static/images
11
-
12
  ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers
13
  ENV HF_HOME=/app/.cache/huggingface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  ENV HF_HUB_OFFLINE=1
15
  ENV TRANSFORMERS_OFFLINE=1
16
  ENV HF_HUB_DISABLE_TELEMETRY=1
 
2
 
3
  WORKDIR /app
4
 
5
+ COPY requirements.txt .
6
 
7
  RUN pip install --no-cache-dir --upgrade pip && \
8
  pip install --no-cache-dir -r requirements.txt
9
 
10
+ # Pre-download models during build so runtime stays offline
 
11
  ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers
12
  ENV HF_HOME=/app/.cache/huggingface
13
+
14
+ RUN python -c "\
15
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM; \
16
+ AutoTokenizer.from_pretrained('ArchCoder/fine-tuned-bart-large'); \
17
+ AutoModelForSeq2SeqLM.from_pretrained('ArchCoder/fine-tuned-bart-large'); \
18
+ AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-0.5B-Instruct'); \
19
+ AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-Coder-0.5B-Instruct'); \
20
+ print('Models downloaded successfully')"
21
+
22
+ COPY . .
23
+
24
+ RUN mkdir -p /app/data/uploads /app/static/images
25
+
26
+ # Lock to offline at runtime — all models are already cached
27
  ENV HF_HUB_OFFLINE=1
28
  ENV TRANSFORMERS_OFFLINE=1
29
  ENV HF_HUB_DISABLE_TELEMETRY=1
app.py CHANGED
@@ -54,12 +54,12 @@ def index():
54
  def models():
55
  return jsonify({
56
  "models": [
57
- {"id": "qwen", "name": "Qwen2.5-1.5B", "provider": "Local (optional path)", "free": True},
58
- {"id": "bart", "name": "BART (fine-tuned)", "provider": "Local (transformers)", "free": True},
59
- {"id": "gemini", "name": "Gemini 2.0 Flash", "provider": "Google AI (API key)", "free": False},
60
- {"id": "grok", "name": "Grok-3 Mini", "provider": "xAI (API key)", "free": False},
61
  ],
62
- "default": "bart"
63
  })
64
 
65
 
@@ -70,7 +70,7 @@ def plot():
70
  if not data or not data.get('query'):
71
  return jsonify({'error': 'Missing required field: query'}), 400
72
 
73
- logging.info(f"Plot request: model={data.get('model','bart')} query={data.get('query')[:80]}")
74
  result = agent.process_request(data)
75
  logging.info(f"Plot completed in {time.time() - t0:.2f}s")
76
  return jsonify(result)
 
54
  def models():
55
  return jsonify({
56
  "models": [
57
+ {"id": "qwen", "name": "Qwen2.5-Coder-0.5B", "provider": "Local (transformers)", "free": True},
58
+ {"id": "bart", "name": "BART (fine-tuned)", "provider": "Local (transformers)", "free": True},
59
+ {"id": "gemini", "name": "Gemini 2.0 Flash", "provider": "Google AI (API key)", "free": False},
60
+ {"id": "grok", "name": "Grok-3 Mini", "provider": "xAI (API key)", "free": False},
61
  ],
62
+ "default": "qwen"
63
  })
64
 
65
 
 
70
  if not data or not data.get('query'):
71
  return jsonify({'error': 'Missing required field: query'}), 400
72
 
73
+ logging.info(f"Plot request: model={data.get('model','qwen')} query={data.get('query')[:80]}")
74
  result = agent.process_request(data)
75
  logging.info(f"Plot completed in {time.time() - t0:.2f}s")
76
  return jsonify(result)
llm_agent.py CHANGED
@@ -15,50 +15,69 @@ load_dotenv()
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
 
18
 
19
- def _model_dir(dirname: str) -> str:
20
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), dirname)
 
21
 
 
 
 
22
 
23
- def _has_model_weights(model_dir: str) -> bool:
24
- weight_files = (
25
- "pytorch_model.bin",
26
- "model.safetensors",
27
- "tf_model.h5",
28
- "flax_model.msgpack",
29
- )
30
- return os.path.isdir(model_dir) and any(
31
- os.path.exists(os.path.join(model_dir, filename)) for filename in weight_files
32
- )
33
 
34
- # ---------------------------------------------------------------------------
35
- # Prompt templates
36
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- _SYSTEM_PROMPT = (
39
- "You are a data visualization expert. "
40
- "Given the user request and the dataset schema provided, output ONLY a valid JSON "
41
- "object — no explanation, no markdown fences, no extra text.\n\n"
42
- "Required keys:\n"
43
- ' "x" : string — exact column name for the x-axis\n'
44
- ' "y" : array — one or more exact column names for the y-axis\n'
45
- ' "chart_type" : string — one of: line, bar, scatter, pie, histogram, box, area\n'
46
- ' "color" : string — optional CSS color, e.g. "red", "#4f8cff"\n\n'
47
- "Rules:\n"
48
- "- Use only column names that appear in the schema. Never invent names.\n"
49
- "- For pie: y must contain exactly one column.\n"
50
- "- For histogram/box: x may equal the first element of y.\n"
51
- "- Default to line if chart type is ambiguous."
52
- )
53
 
54
 
55
  def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str:
56
  schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns)
57
  samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3])
58
  return (
59
- f"Dataset columns:\n{schema}\n\n"
60
- f"Sample rows (first 3):\n{samples}\n"
61
- f"User request: {query}"
 
62
  )
63
 
64
 
@@ -198,22 +217,16 @@ class LLM_Agent:
198
  self._bart_model = None
199
  self._qwen_tokenizer = None
200
  self._qwen_model = None
201
- self._bart_model_dir = os.getenv("BART_LOCAL_PATH", _model_dir("fine-tuned-bart-large"))
202
- self._qwen_model_dir = os.getenv("QWEN_LOCAL_PATH", "")
203
 
204
  # -- model runners -------------------------------------------------------
205
 
206
  def _run_qwen(self, user_msg: str) -> str:
 
207
  if self._qwen_model is None:
208
  from transformers import AutoModelForCausalLM, AutoTokenizer
209
- model_id = self._qwen_model_dir
210
- if not model_id:
211
- raise ValueError("Qwen local model is not configured in this Space")
212
- if not _has_model_weights(model_id):
213
- raise ValueError(f"Qwen model weights not found in {model_id}")
214
- logger.info("Loading Qwen model (first request)...")
215
- self._qwen_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
216
- self._qwen_model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True)
217
  logger.info("Qwen model loaded.")
218
  messages = [
219
  {"role": "system", "content": _SYSTEM_PROMPT},
@@ -260,14 +273,12 @@ class LLM_Agent:
260
  return resp.choices[0].message.content
261
 
262
  def _run_bart(self, query: str) -> str:
 
263
  if self._bart_model is None:
264
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
265
- model_id = self._bart_model_dir
266
- if not _has_model_weights(model_id):
267
- raise ValueError(f"BART model weights not found in {model_id}")
268
- logger.info("Loading BART model (first request)...")
269
- self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
270
- self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, local_files_only=True)
271
  logger.info("BART model loaded.")
272
  inputs = self._bart_tokenizer(
273
  query, return_tensors="pt", max_length=512, truncation=True
@@ -281,7 +292,7 @@ class LLM_Agent:
281
  t0 = time.time()
282
  query = data.get("query", "")
283
  data_path = data.get("file_path")
284
- model = data.get("model", "bart")
285
 
286
  if data_path and os.path.exists(data_path):
287
  self.data_processor = DataProcessor(data_path)
@@ -303,14 +314,15 @@ class LLM_Agent:
303
  user_msg = _user_message(query, columns, dtypes, sample_rows)
304
  if model == "gemini": raw_text = self._run_gemini(user_msg)
305
  elif model == "grok": raw_text = self._run_grok(user_msg)
 
306
  elif model == "qwen":
307
  try:
308
  raw_text = self._run_qwen(user_msg)
309
  except Exception as qwen_exc:
310
- logger.warning(f"Qwen unavailable, falling back to BART: {qwen_exc}")
311
  raw_text = self._run_bart(query)
312
  else:
313
- raw_text = self._run_bart(query)
314
 
315
  logger.info(f"LLM [{model}] output: {raw_text}")
316
  parsed = _parse_output(raw_text)
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
+ # ---------------------------------------------------------------------------
19
+ # Model IDs (downloaded at Docker build, cached in HF_HOME)
20
+ # ---------------------------------------------------------------------------
21
+ QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-Coder-0.5B-Instruct")
22
+ BART_MODEL_ID = os.getenv("BART_MODEL_ID", "ArchCoder/fine-tuned-bart-large")
23
 
24
+ # ---------------------------------------------------------------------------
25
+ # Prompt templates with few-shot examples
26
+ # ---------------------------------------------------------------------------
27
 
28
+ _SYSTEM_PROMPT = """\
29
+ You are a data visualization expert. Given the user request and dataset schema, \
30
+ output ONLY a valid JSON object. No explanation, no markdown fences, no extra text.
31
 
32
+ Required JSON keys:
33
+ "x" : string — exact column name for the x-axis
34
+ "y" : array — one or more exact column names for the y-axis
35
+ "chart_type" : string — one of: line, bar, scatter, pie, histogram, box, area
36
+ "color" : string or null — optional CSS color like "red", "#4f8cff"
 
 
 
 
 
37
 
38
+ Rules:
39
+ - Use ONLY column names from the schema. Never invent names.
40
+ - For pie charts: y must contain exactly one column.
41
+ - For histogram/box: x may equal the first element of y.
42
+ - Default to "line" if chart type is ambiguous.
43
+
44
+ ### Examples
45
+
46
+ Example 1:
47
+ Schema: Year (integer), Sales (float), Profit (float)
48
+ User: "plot sales over the years with a red line"
49
+ Output: {"x": "Year", "y": ["Sales"], "chart_type": "line", "color": "red"}
50
+
51
+ Example 2:
52
+ Schema: Month (string), Revenue (float), Expenses (float)
53
+ User: "bar chart comparing revenue and expenses by month"
54
+ Output: {"x": "Month", "y": ["Revenue", "Expenses"], "chart_type": "bar", "color": null}
55
+
56
+ Example 3:
57
+ Schema: Category (string), Count (integer)
58
+ User: "pie chart of count by category"
59
+ Output: {"x": "Category", "y": ["Count"], "chart_type": "pie", "color": null}
60
+
61
+ Example 4:
62
+ Schema: Date (string), Temperature (float), Humidity (float)
63
+ User: "scatter plot of temperature vs humidity in blue"
64
+ Output: {"x": "Temperature", "y": ["Humidity"], "chart_type": "scatter", "color": "blue"}
65
 
66
+ Example 5:
67
+ Schema: Year (integer), Sales (float), Employee expense (float), Marketing expense (float)
68
+ User: "show me an area chart of sales and marketing expense over years"
69
+ Output: {"x": "Year", "y": ["Sales", "Marketing expense"], "chart_type": "area", "color": null}
70
+ """
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str:
74
  schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns)
75
  samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3])
76
  return (
77
+ f"Schema:\n{schema}\n\n"
78
+ f"Sample rows:\n{samples}\n"
79
+ f"User: \"{query}\"\n"
80
+ f"Output:"
81
  )
82
 
83
 
 
217
  self._bart_model = None
218
  self._qwen_tokenizer = None
219
  self._qwen_model = None
 
 
220
 
221
  # -- model runners -------------------------------------------------------
222
 
223
  def _run_qwen(self, user_msg: str) -> str:
224
+ """Qwen2.5-Coder-0.5B-Instruct — fast structured-JSON generation."""
225
  if self._qwen_model is None:
226
  from transformers import AutoModelForCausalLM, AutoTokenizer
227
+ logger.info(f"Loading Qwen model: {QWEN_MODEL_ID}")
228
+ self._qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
229
+ self._qwen_model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID)
 
 
 
 
 
230
  logger.info("Qwen model loaded.")
231
  messages = [
232
  {"role": "system", "content": _SYSTEM_PROMPT},
 
273
  return resp.choices[0].message.content
274
 
275
  def _run_bart(self, query: str) -> str:
276
+ """ArchCoder/fine-tuned-bart-large — lightweight Seq2Seq fallback."""
277
  if self._bart_model is None:
278
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
279
+ logger.info(f"Loading BART model: {BART_MODEL_ID}")
280
+ self._bart_tokenizer = AutoTokenizer.from_pretrained(BART_MODEL_ID)
281
+ self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_ID)
 
 
 
282
  logger.info("BART model loaded.")
283
  inputs = self._bart_tokenizer(
284
  query, return_tensors="pt", max_length=512, truncation=True
 
292
  t0 = time.time()
293
  query = data.get("query", "")
294
  data_path = data.get("file_path")
295
+ model = data.get("model", "qwen")
296
 
297
  if data_path and os.path.exists(data_path):
298
  self.data_processor = DataProcessor(data_path)
 
314
  user_msg = _user_message(query, columns, dtypes, sample_rows)
315
  if model == "gemini": raw_text = self._run_gemini(user_msg)
316
  elif model == "grok": raw_text = self._run_grok(user_msg)
317
+ elif model == "bart": raw_text = self._run_bart(query)
318
  elif model == "qwen":
319
  try:
320
  raw_text = self._run_qwen(user_msg)
321
  except Exception as qwen_exc:
322
+ logger.warning(f"Qwen failed, falling back to BART: {qwen_exc}")
323
  raw_text = self._run_bart(query)
324
  else:
325
+ raw_text = self._run_qwen(user_msg)
326
 
327
  logger.info(f"LLM [{model}] output: {raw_text}")
328
  parsed = _parse_output(raw_text)