jacopo22295 committed on
Commit 7014e99 · verified · 1 Parent(s): 5ad96e4

Update app.py

Files changed (1)
  1. app.py +22 -59
app.py CHANGED
@@ -39,7 +39,6 @@ ZONES = [
     "Other / Not sure"
 ]
 
-# Load a CPU-resident model (ZeroGPU does not keep GPU state between invocations)
 def load_model_cpu():
     m = models.resnet50(weights=None)
     num_ftrs = m.fc.in_features
@@ -78,17 +77,16 @@ APP_FORCE_LANG = os.environ.get("APP_FORCE_LANG", "").strip()
 # ======================
 
 def predict_on_cpu(img_pil: Image.Image):
-    x = transform(img_pil.convert("RGB")).unsqueeze(0)  # CPU tensor
+    x = transform(img_pil.convert("RGB")).unsqueeze(0)
     with torch.no_grad():
         logits = model_cpu(x)
     probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
     idx = int(probs.argmax())
     return IDX2LABEL.get(idx, f"class_{idx}"), float(probs[idx])
 
-@spaces.GPU(duration=60)  # required by ZeroGPU
+@spaces.GPU(duration=60)
 def predict_on_gpu(img_pil: Image.Image):
     device = "cuda"
-    # ZeroGPU keeps no state: the model is recreated on the GPU on every call
     m = models.resnet50(weights=None)
     num_ftrs = m.fc.in_features
     m.fc = torch.nn.Linear(num_ftrs, len(IDX2LABEL))
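
The comments removed above describe the ZeroGPU constraint: no GPU state survives between calls, so the model must be rebuilt inside each `@spaces.GPU`-decorated call. A minimal sketch of that pattern follows; the class count and checkpoint path are placeholders, not values taken from this repository.

```python
# Sketch of the ZeroGPU pattern described by the removed comments.
# Assumptions: NUM_CLASSES and "weights.pth" are hypothetical, not from app.py.
import spaces
import torch
from torchvision import models

NUM_CLASSES = 4  # hypothetical number of corrosion classes

@spaces.GPU(duration=60)  # GPU is attached only for the duration of this call
def predict_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # Rebuild and load the model on every call, since ZeroGPU keeps no GPU state.
    m = models.resnet50(weights=None)
    m.fc = torch.nn.Linear(m.fc.in_features, NUM_CLASSES)
    m.load_state_dict(torch.load("weights.pth", map_location="cuda"))
    m.to("cuda").eval()
    with torch.no_grad():
        return torch.softmax(m(x.to("cuda")), dim=1).cpu()
```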
@@ -108,7 +106,6 @@ def predict_image(image: Image.Image):
         if torch.cuda.is_available():
             return predict_on_gpu(image)
     except Exception:
-        # if ZeroGPU or CUDA fails, fall back to CPU
         pass
     return predict_on_cpu(image)
 
@@ -117,11 +114,6 @@ def predict_image(image: Image.Image):
 # ======================
 
 def call_assistant(label, confidence, zone, note, user_question, thread_id=None):
-    """
-    Calls the OpenAI Assistant. If VECTOR_STORE_ID is set, attaches File Search to the thread.
-    Returns (reply, thread_id).
-    """
-    # create or reuse the thread, attaching the vector store if present
     if not thread_id:
         if VECTOR_STORE_ID:
             thread = client.beta.threads.create(
@@ -136,8 +128,7 @@ Classification: {label} ({round(confidence*100,2)}%).
 Zone: {zone or "Not specified"}.
 User note: {note or "(none)"}.
 """
-
-    user_payload = core_context + "\nUser question:\n" + (user_question or "Provide initial advisory based on classification and note.")
+    user_payload = core_context + "\nUser question:\n" + (user_question or "Provide initial advisory.")
 
     client.beta.threads.messages.create(
         thread_id=thread_id,
@@ -150,12 +141,10 @@ User note: {note or "(none)"}.
 
     extra_instructions = (
         "Act as a PPG marine coatings technical specialist for ships (marine environments only). "
-        "Answer ONLY using information found in the attached docs via File Search (TDS/SDS, standard cycles). "
-        "If the docs lack details, reply 'Not in docs' and ask a targeted follow-up. "
-        "ALWAYS ask for the area/zone if it's missing before prescribing a cycle. "
-        "Structure: Diagnosis; Surface Preparation (ISO 8501-1, profile); System (primer/build/top or AF; DFT per coat; coats; recoat windows); Notes; Short disclaimer: "
-        "'Research use only; verify with official PPG specs.' "
-        "Cite file name and section/page when relevant. "
+        "Answer ONLY using information found in the attached docs via File Search. "
+        "If docs lack details, reply 'Not in docs'. "
+        "ALWAYS ask for the zone if missing before prescribing. "
+        "Structure: Diagnosis; Surface Preparation; System; Notes; Disclaimer. "
         "Provide first in English. " + second_lang_clause
     )
 
@@ -163,17 +152,14 @@ User note: {note or "(none)"}.
         thread_id=thread_id,
         assistant_id=ASSISTANT_ID,
         instructions=extra_instructions,
-        # some versions support "tool_choice": "file_search"
     )
 
-    # simple polling
     while True:
         r = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
         if r.status in ["completed", "failed", "cancelled", "expired"]:
             break
         time.sleep(0.6)
 
-    # extract the last assistant message (text)
     msgs = client.beta.threads.messages.list(thread_id=thread_id)
     reply = None
     for m in msgs.data:
@@ -187,52 +173,38 @@ User note: {note or "(none)"}.
     return reply or "No reply from Assistant.", thread_id
 
 # ======================
-# Pipelines (generator for streaming)
+# Pipelines (generator)
 # ======================
 
 def run_analysis(image, note, zone, chat_history, thread_state):
-    """
-    Generator: yields immediately so 'Analyzing...' is shown before the real work.
-    Yields twice:
-    1) intermediate state (progress message)
-    2) final result
-    """
     if image is None:
         yield "No image received.", chat_history, thread_state
         return
-
     if not zone or zone == "Other / Not sure":
-        msg = "**Please select the area/zone first.** The assistant needs the zone to propose a correct cycle."
+        msg = "**Please select the area/zone first.**"
         yield msg, chat_history, thread_state
         return
 
-    # intermediate state shown right away
-    yield "**Analyzing image...** Running classification and querying the PPG Assistant.", chat_history, thread_state
+    # intermediate message
+    yield "**Analyzing image...** Please wait.", chat_history, thread_state
 
-    # classification + initial assistant reply
     label, conf = predict_image(image)
-    reply, thread_id = call_assistant(label, conf, zone, note, "Provide initial advisory based on classification and note.")
+    reply, thread_id = call_assistant(label, conf, zone, note, "Provide initial advisory.")
 
     header = f"**Model result:** `{label}` — confidence **{round(conf*100,2)}%**\n\n"
     out_text = header + (reply or "")
-
     new_history = chat_history[:] if chat_history else []
     new_history.append(("", reply))
-
-    # final state
     yield out_text, new_history, {"thread_id": thread_id, "label": label, "confidence": conf, "zone": zone or ""}
 
 def continue_chat(user_msg, chat_history, thread_state, note, zone):
-    if not user_msg or not user_msg.strip():
+    if not user_msg.strip():
         return chat_history, ""
-
     label = (thread_state or {}).get("label") or "unknown"
     conf = (thread_state or {}).get("confidence") or 0.0
     current_zone = zone or (thread_state or {}).get("zone") or "Not specified"
     thread_id = (thread_state or {}).get("thread_id")
-
     reply, thread_id = call_assistant(label, conf, current_zone, note, user_msg, thread_id)
-
     chat_history.append((user_msg, reply))
     thread_state["thread_id"] = thread_id
     return chat_history, ""
@@ -244,24 +216,19 @@ def continue_chat(user_msg, chat_history, thread_state, note, zone):
 WELCOME = """
 # Corrosion Assistant — Beta
 
-**Welcome!** This demo runs a custom **ResNet50 corrosion classifier** and connects to a dedicated **PPG Assistant** on OpenAI.
-- **Model**: ResNet50 classifier, **trained locally** on ~**9,000 images**
-- **Data collection**: a public link for contributing images will open **soon**
-- **Disclaimer**: research & experimental only. No professional advice, no warranty.
-
-After image analysis you can continue chatting with the assistant.
+**Welcome!**
+ResNet50 classifier trained locally on ~9,000 images.
+Data collection link coming soon.
+**Disclaimer**: research & experimental only.
 """
 
 with gr.Blocks(title="Corrosion Assistant", theme=gr.themes.Soft()) as demo:
     gr.Markdown(WELCOME)
 
-    # Automatically shows the run queue/status (Gradio 5)
-    status = gr.Status()
-
     with gr.Row():
         with gr.Column(scale=2):
             img = gr.Image(type="pil", sources=["upload","webcam"], label="Upload or webcam")
-            note = gr.Textbox(label="Notes / Context (optional)", placeholder="Write in your language.")
+            note = gr.Textbox(label="Notes / Context (optional)")
             zone = gr.Dropdown(choices=ZONES, label="Zone (indicative)", value="Other / Not sure")
             analyze_btn = gr.Button("Analyze image", variant="primary")
         with gr.Column(scale=3):
@@ -271,22 +238,19 @@ with gr.Blocks(title="Corrosion Assistant", theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         with gr.Column(scale=3):
-            # fix warning: specify type="tuples"
             chat = gr.Chatbot(height=320, label="Advisor chat", type="tuples")
-            chat_in = gr.Textbox(label="Your message", placeholder="Ask about prep, products, recoats, etc.")
+            chat_in = gr.Textbox(label="Your message")
             send_btn = gr.Button("Send")
             clear_btn = gr.Button("Clear chat")
         with gr.Column(scale=2):
             gr.Markdown(
-                "> **Disclaimer:** Research & experimental use only. Validate with certified PPG specs and local regulations. "
-                "No professional advice is provided and no responsibility is assumed."
+                "> **Disclaimer:** Research & experimental use only. Validate with official PPG specs. "
+                "No professional advice or responsibility assumed."
             )
 
-    # States
     chat_state = gr.State([])
     thread_state = gr.State({"thread_id": None, "label": None, "confidence": 0.0, "zone": ""})
 
-    # Analyze: GENERATOR function with streaming (yield)
     analyze_btn.click(
         fn=run_analysis,
         inputs=[img, note, zone, chat_state, thread_state],
@@ -297,7 +261,6 @@ with gr.Blocks(title="Corrosion Assistant", theme=gr.themes.Soft()) as demo:
         outputs=[chat]
     )
 
-    # Chat
     send_btn.click(
         fn=continue_chat,
         inputs=[chat_in, chat_state, thread_state, note, zone],
@@ -313,5 +276,5 @@ with gr.Blocks(title="Corrosion Assistant", theme=gr.themes.Soft()) as demo:
 demo.api_mode = "enabled"
 
 if __name__ == "__main__":
-    # SSR is on by default in Gradio 5; no share link here
     demo.launch()
+
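
The docstring removed from `call_assistant` described attaching File Search to the thread whenever `VECTOR_STORE_ID` is set, then polling the run until it finishes. A minimal sketch of that Assistants API flow is shown below, with placeholder ids rather than values from this repository.

```python
# Sketch of the thread-with-File-Search pattern the removed docstring described.
# VECTOR_STORE_ID and ASSISTANT_ID are placeholders, not values from app.py.
import time
from openai import OpenAI

client = OpenAI()
VECTOR_STORE_ID = "vs_..."   # hypothetical vector store id
ASSISTANT_ID = "asst_..."    # hypothetical assistant id

# Create the thread with the vector store attached so File Search can use it.
thread = client.beta.threads.create(
    tool_resources={"file_search": {"vector_store_ids": [VECTOR_STORE_ID]}}
)
client.beta.threads.messages.create(
    thread_id=thread.id,
    role="user",
    content="Classification: corrosion (92%). Zone: hull. Provide initial advisory.",
)
run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=ASSISTANT_ID)

# Poll until the run reaches a terminal state, as app.py does.
while run.status not in ("completed", "failed", "cancelled", "expired"):
    time.sleep(0.6)
    run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)

# The newest message in the thread holds the assistant's reply.
reply = client.beta.threads.messages.list(thread_id=thread.id).data[0].content[0].text.value
```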
 
 