apsora committed
Commit b6c29b8 · verified · 1 Parent(s): c5818d2

Update app.py

Files changed (1)
  1. app.py +85 -98
app.py CHANGED
@@ -1,56 +1,26 @@
 import json
 import re
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
-import torch
 from PIL import Image
 import gradio as gr
-from transformers import AutoProcessor, AutoModelForVision2Seq
-import spaces  # <-- needed for Stateless GPU / zeroGPU
-
-# ---------------------------------------------------------------------
-# Minimal GPU-decorated function so Stateless GPU doesn't error out
-# ---------------------------------------------------------------------
-@spaces.GPU
-def gpu_ping() -> str:
-    """
-    Dummy GPU endpoint so Hugging Face Stateless GPU / zeroGPU
-    detects at least one @spaces.GPU function.
-
-    We don't actually use this in the app logic. It just keeps
-    the Space from throwing:
-    'No @spaces.GPU function detected during startup'.
-    """
-    return "gpu_ready"
-

 # ============================================================
-# 0. Model + guidelines setup
 # ============================================================

-# NOTE: we keep everything on CPU here to avoid touching CUDA
-# in the main process (required for Stateless GPU).
-DEVICE = "cpu"
-DTYPE = torch.float32
-
 MODEL_NAME = "maryzhang/qwen3vl-guideline-lora-model"

-print(f"Loading unified vision+text model {MODEL_NAME} on {DEVICE}", flush=True)
-
-model_vlm = AutoModelForVision2Seq.from_pretrained(
-    MODEL_NAME,
-    dtype=DTYPE,
-    trust_remote_code=True,
-)
-model_vlm.to(DEVICE)
-model_vlm.eval()

-processor_vlm = AutoProcessor.from_pretrained(
-    MODEL_NAME,
-    trust_remote_code=True,
-)

 GUIDELINES_PATH = "guidelines_final.json"
@@ -119,41 +89,43 @@ print(f"Loaded {len(ALL_GUIDELINES)} guidelines", flush=True)


 # ============================================================
-# 1. Core LLM helpers (text-only + vision)
 # ============================================================

 def run_text_llm(system_prompt: str, user_prompt: str, max_new_tokens: int = 768) -> str:
     """
-    Use Qwen3-VL (LoRA) in text-only mode.
     """
     messages = [
-        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
-        {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
     ]
-    prompt_text = processor_vlm.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
     )

-    inputs = processor_vlm(
-        text=prompt_text,
-        return_tensors="pt",
-    ).to(DEVICE)
-
-    with torch.no_grad():
-        output_ids = model_vlm.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=0.0,
-            do_sample=False,
-        )

-    generated = processor_vlm.decode(
-        output_ids[0],
-        skip_special_tokens=True,
-    ).strip()
-    return generated


 def vlm_generate_json_from_images(
@@ -161,48 +133,65 @@ def vlm_generate_json_from_images(
     images: List[Image.Image],
 ) -> Dict[str, Any]:
     """
-    Call Qwen3-VL with images and ask it to return STRICT JSON.
     """
     if not images:
         images = [Image.new("RGB", (64, 64), "white")]

-    content = [{"type": "image"} for _ in images]
-    content.append({"type": "text", "text": prompt})
-
-    messages = [{"role": "user", "content": content}]

-    prompt_text = processor_vlm.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
     )

-    inputs = processor_vlm(
-        text=prompt_text,
-        images=images,
-        return_tensors="pt",
-    ).to(DEVICE)
-
-    with torch.no_grad():
-        output_ids = model_vlm.generate(
-            **inputs,
-            max_new_tokens=512,
-            temperature=0.0,
-            do_sample=False,
-        )

-    generated = processor_vlm.decode(
-        output_ids[0],
-        skip_special_tokens=True,
-    ).strip()

-    m = re.search(r"\{.*\}", generated, re.DOTALL)
     if m:
         try:
             return json.loads(m.group(0))
         except Exception:
             pass
-    return {"parse_error": True, "raw": generated}


 # ============================================================
@@ -287,9 +276,7 @@ def rag_retrieve(query: str, top_k: int = 6) -> List[Dict[str, Any]]:
     scored = []
     for g in ALL_GUIDELINES:
         pfl = g.get("pass_fail_logic") or {}
-        pfl_text = " ".join(
-            f"{k}: {v}" for k, v in pfl.items()
-        )
         blob = " ".join(
             [
                 g.get("topic", ""),
@@ -306,9 +293,7 @@ def rag_retrieve(query: str, top_k: int = 6) -> List[Dict[str, Any]]:
     hits = []
     for score, g in scored[:top_k]:
         pfl = g.get("pass_fail_logic") or {}
-        pfl_text = " ".join(
-            f"{k}: {v}" for k, v in pfl.items()
-        )
         text = (
             " ".join(g.get("evaluation_criteria", []) or [])
             or " ".join(g.get("expected_answers", []) or [])
@@ -1081,7 +1066,9 @@ with gr.Blocks(title="DFM / GD&T Manufacturability Tutor") as demo:
         2. *(Optional)* Add a short description of the part
         3. Click **Start review**
         4. Answer a few focused questions → get a guideline-by-guideline summary
-        This tool is meant to feel like a mini design review with a friendly TA.
         """
     )
 
 import json
 import re
+import base64
+import io
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 from PIL import Image
 import gradio as gr
+from huggingface_hub import InferenceClient

 # ============================================================
+# 0. Model + guidelines setup (Inference API version)
 # ============================================================

 MODEL_NAME = "maryzhang/qwen3vl-guideline-lora-model"

+print(f"Using hosted model via Inference API: {MODEL_NAME}", flush=True)

+# This uses the HF Inference API (no local weights, no GPU in the Space)
+# If the model is private, set HF_TOKEN as an environment variable in the Space.
+hf_client = InferenceClient(MODEL_NAME)
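If the model repo is private, the token has to reach the client either via the environment or explicitly; a minimal sketch (assuming an HF_TOKEN secret is configured in the Space settings) would be:

# Sketch only: explicit token for a private repo; None is fine for public models
import os
from huggingface_hub import InferenceClient

hf_client = InferenceClient(
    model="maryzhang/qwen3vl-guideline-lora-model",
    token=os.environ.get("HF_TOKEN"),
)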
 
 GUIDELINES_PATH = "guidelines_final.json"
 
 
 # ============================================================
+# 1. Core LLM helpers (text-only + vision via Inference API)
 # ============================================================

 def run_text_llm(system_prompt: str, user_prompt: str, max_new_tokens: int = 768) -> str:
     """
+    Use the hosted Qwen3-VL model in text-only mode via chat_completion.
+
+    We build a simple system+user messages list and ask for a deterministic
+    response (temperature=0).
     """
     messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt},
     ]
+
+    response = hf_client.chat_completion(
+        messages=messages,
+        max_tokens=max_new_tokens,
+        temperature=0.0,
+        stream=False,
     )
+    # HuggingFace InferenceClient returns a ChatCompletionOutput
+    text = response.choices[0].message.content
+    return (text or "").strip()
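A minimal usage sketch of the updated text-only helper (the prompts here are illustrative, not the app's real prompts):

# Illustrative call: deterministic, text-only chat completion through the hosted model
answer = run_text_llm(
    system_prompt="You are a DFM tutor. Answer concisely.",
    user_prompt="Why does uniform wall thickness matter for injection molding?",
)
print(answer)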

+def _pil_to_data_url(img: Image.Image, fmt: str = "PNG") -> str:
+    """
+    Convert a PIL image to a data URL (base64-encoded), which matches the
+    format expected by chat_completion with vision support:
+    type: "image_url", image_url: {"url": "data:image/png;base64,..."}
+    """
+    buf = io.BytesIO()
+    img.save(buf, format=fmt)
+    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+    mime = "image/png" if fmt.upper() == "PNG" else "image/jpeg"
+    return f"data:{mime};base64,{b64}"
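A quick sanity-check sketch for this helper (throwaway image; the assertion only illustrates the expected prefix):

# Sketch: the returned data URL should carry the PNG MIME type and a base64 payload
from PIL import Image

img = Image.new("RGB", (32, 32), "white")
url = _pil_to_data_url(img)
assert url.startswith("data:image/png;base64,")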
 
 def vlm_generate_json_from_images(
     images: List[Image.Image],
 ) -> Dict[str, Any]:
     """
+    Call the hosted Qwen3-VL model with images + text using chat_completion.
+    We ask it to return STRICT JSON and then parse the JSON out of the reply.
+
+    This assumes the model supports OpenAI-style multimodal messages where
+    each content item can be {"type": "image_url", "image_url": {"url": ...}}
+    plus a text chunk.
     """
     if not images:
         images = [Image.new("RGB", (64, 64), "white")]

+    # Build message content with multiple images + prompt text
+    content: List[Dict[str, Any]] = []
+    for img in images:
+        url = _pil_to_data_url(img)
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": url},
+            }
+        )

+    content.append(
+        {
+            "type": "text",
+            "text": prompt,
+        }
     )

+    messages = [
+        {
+            "role": "system",
+            "content": "You are a vision model that ONLY replies with strict JSON.",
+        },
+        {
+            "role": "user",
+            "content": content,
+        },
+    ]

+    # Ask for a deterministic, non-streaming, JSON-like answer
+    response = hf_client.chat_completion(
+        messages=messages,
+        max_tokens=512,
+        temperature=0.0,
+        stream=False,
+        # If your model supports response_format, you can uncomment:
+        # response_format={"type": "json_object"},
+    )
+    raw = response.choices[0].message.content or ""
+    raw = raw.strip()

+    # Try to extract JSON object from the raw string
+    m = re.search(r"\{.*\}", raw, re.DOTALL)
     if m:
         try:
             return json.loads(m.group(0))
         except Exception:
             pass
+    return {"parse_error": True, "raw": raw}
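A usage sketch for the vision helper; the first positional argument is assumed to be the prompt string (the parameter itself is elided from the hunk but referenced in the body), and the file name and prompt are purely illustrative:

# Sketch: call the vision helper and guard against non-JSON replies
from PIL import Image

part_img = Image.open("bracket.png")  # hypothetical uploaded drawing
result = vlm_generate_json_from_images(
    'Return {"holes": <int>, "thin_walls": <bool>} as strict JSON.',
    [part_img],
)
if result.get("parse_error"):
    print("Model reply was not valid JSON:", result["raw"][:200])
else:
    print(result)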
  # ============================================================
 
     scored = []
     for g in ALL_GUIDELINES:
         pfl = g.get("pass_fail_logic") or {}
+        pfl_text = " ".join(f"{k}: {v}" for k, v in pfl.items())
         blob = " ".join(
             [
                 g.get("topic", ""),
 
     hits = []
     for score, g in scored[:top_k]:
         pfl = g.get("pass_fail_logic") or {}
+        pfl_text = " ".join(f"{k}: {v}" for k, v in pfl.items())
         text = (
             " ".join(g.get("evaluation_criteria", []) or [])
             or " ".join(g.get("expected_answers", []) or [])
 
         2. *(Optional)* Add a short description of the part
         3. Click **Start review**
         4. Answer a few focused questions → get a guideline-by-guideline summary
+
+        This tool is powered by a hosted multimodal model via the Hugging Face Inference API,
+        so it runs on free CPU hardware without loading big weights in this Space.
         """
     )