ALVHB95 committed on
Commit
967e5a0
·
1 Parent(s): 990e0b1
Files changed (1) hide show
  1. app.py +87 -82
app.py CHANGED
@@ -1,28 +1,35 @@
1
  """
2
  =========================================================
3
- Fixed app.py — Green Greta (Gradio + HF + LangChain v0.2)
4
- Notes:
5
- - Uses HuggingFaceEndpoint with a public router model (Zephyr) by default.
6
- - Robust JSON parsing (no fragile string slicing).
7
- - EfficientNet input size fixed (224x224).
8
- - LangChain v0.2 import layout (core/community/text-splitters).
9
- - Safer memory for ConversationalRetrievalChain; better error handling.
10
- - Gradio binds to 0.0.0.0:7860 for Docker.
11
  =========================================================
12
  """
13
 
 
 
 
14
  import os
15
  import json
16
  import shutil
17
 
 
18
  import gradio as gr
 
 
19
  import tensorflow as tf
 
20
  from PIL import Image
21
 
22
- import tenacity # for retrying failed requests
 
23
  from fake_useragent import UserAgent
24
 
25
- # LangChain (v0.2+ layout)
26
  from langchain_text_splitters import RecursiveCharacterTextSplitter
27
  from langchain_core.prompts import ChatPromptTemplate
28
  from langchain_core.output_parsers import PydanticOutputParser
@@ -33,43 +40,60 @@ from langchain_community.vectorstores import Chroma
33
  from langchain.chains import ConversationalRetrievalChain
34
  from langchain.memory import ConversationBufferMemory
35
 
 
36
  from pydantic.v1 import BaseModel, Field
37
 
38
- # Theming
 
 
 
39
  import theme
 
 
 
 
 
 
40
  theme = theme.Theme()
41
 
42
- # Import URL list
43
- from url_list import URLS
44
 
45
- """
46
- =========================================================
47
- 1) IMAGE CLASSIFICATION MODEL SETUP
48
- =========================================================
49
- """
50
- from huggingface_hub import from_pretrained_keras
51
 
52
- # Load a Keras model from HuggingFace Hub
53
- model1 = from_pretrained_keras("rocioadlc/efficientnetB0_trash")
54
 
55
- # Define class labels for the trash classification
 
 
 
56
  class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
57
 
58
 
59
  def predict_image(input_image: Image.Image):
60
  """
61
  Resize the user-uploaded image and preprocess it for EfficientNetB0.
62
- Returns a dict of class probabilities.
63
  """
64
- # Correct size for EfficientNetB0 is 224x224
65
- image_array = tf.keras.preprocessing.image.img_to_array(
66
- input_image.resize((224, 224))
67
- )
68
  image_array = tf.keras.applications.efficientnet.preprocess_input(image_array)
69
- image_array = tf.expand_dims(image_array, 0) # batch dim
70
-
71
- predictions = model1.predict(image_array)
72
- probs = predictions[0].tolist()
 
 
 
 
 
 
 
 
73
  return {label: float(probs[i]) for i, label in enumerate(class_labels)}
74
 
75
 
@@ -81,12 +105,10 @@ image_gradio_app = gr.Interface(
81
  theme=theme,
82
  )
83
 
84
- """
85
- =========================================================
86
- 2) KNOWLEDGE LOADING (RAG)
87
- =========================================================
88
- """
89
- # 2.1) Define user agent to avoid blocking, etc.
90
  user_agent = UserAgent().random
91
  header_template = {"User-Agent": user_agent}
92
 
@@ -104,7 +126,6 @@ def safe_load_all_urls(urls):
104
  docs = load_url(link)
105
  all_docs.extend(docs)
106
  except Exception as e:
107
- # If load_url fails after all retries, skip that URL
108
  print(f"Skipping URL due to error: {link}\nError: {e}\n")
109
  return all_docs
110
 
@@ -119,9 +140,10 @@ text_splitter = RecursiveCharacterTextSplitter(
119
 
120
  docs = text_splitter.split_documents(all_loaded_docs)
121
 
122
- # Small + high-quality general embedding
123
  embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
124
 
 
125
  persist_directory = "docs/chroma/"
126
  shutil.rmtree(persist_directory, ignore_errors=True)
127
 
@@ -133,12 +155,10 @@ vectordb = Chroma.from_documents(
133
 
134
  retriever = vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr")
135
 
136
- """
137
- =========================================================
138
- 3) PROMPT & PARSER
139
- =========================================================
140
- """
141
 
 
 
 
142
  class FinalAnswer(BaseModel):
143
  question: str = Field(description="User question")
144
  answer: str = Field(description="Direct answer")
@@ -166,16 +186,10 @@ qa_prompt = ChatPromptTemplate.from_template(
166
  partial_variables={"format_instructions": parser.get_format_instructions()},
167
  )
168
 
169
- """
170
- =========================================================
171
- 4) LLM SETUP (no router 404s)
172
- =========================================================
173
- """
174
- # IMPORTANT:
175
- # The old route "mistralai/Mixtral-8x7B-Instruct-v0.1" often 404s on the public HF router.
176
- # Use a router-available model OR your own paid Inference Endpoint.
177
 
178
- # ---- Option A (DEFAULT): public, free router model that works out-of-the-box
 
 
179
  DEFAULT_REPO = os.environ.get("HF_REPO_ID", "HuggingFaceH4/zephyr-7b-beta")
180
 
181
  llm = HuggingFaceEndpoint(
@@ -186,11 +200,11 @@ llm = HuggingFaceEndpoint(
186
  top_k=50,
187
  repetition_penalty=1.05,
188
  do_sample=True,
189
- # Set env: HUGGINGFACEHUB_API_TOKEN=hf_xxx
190
  )
191
 
192
- # ---- Option B (MIXTRAL): your paid Inference Endpoint
193
- # MIXTRAL_ENDPOINT_URL = os.environ.get("HF_ENDPOINT_URL") # e.g. https://xyz.aws.endpoints.huggingface.cloud
194
  # if MIXTRAL_ENDPOINT_URL:
195
  # llm = HuggingFaceEndpoint(
196
  # endpoint_url=MIXTRAL_ENDPOINT_URL,
@@ -202,12 +216,10 @@ llm = HuggingFaceEndpoint(
202
  # do_sample=True,
203
  # )
204
 
205
- """
206
- =========================================================
207
- 5) CHAIN (with safer memory + error handling)
208
- =========================================================
209
- """
210
 
 
 
 
211
  memory = ConversationBufferMemory(
212
  memory_key="chat_history",
213
  return_messages=True,
@@ -219,7 +231,7 @@ qa_chain = ConversationalRetrievalChain.from_llm(
219
  memory=memory,
220
  verbose=True,
221
  combine_docs_chain_kwargs={"prompt": qa_prompt},
222
- get_chat_history=lambda h: h, # memory already returns messages
223
  rephrase_question=False,
224
  output_key="output",
225
  )
@@ -227,18 +239,18 @@ qa_chain = ConversationalRetrievalChain.from_llm(
227
 
228
  def chat_interface(question, history):
229
  """
230
- Processes the user's question through the qa_chain,
231
- and robustly parses the JSON output per schema.
232
  """
233
  try:
234
  result = qa_chain.invoke({"question": question})
235
  raw = result.get("output", "").strip()
236
 
237
- # Try strict JSON first
238
  try:
239
  payload = json.loads(raw)
240
  except json.JSONDecodeError:
241
- # If the model returned extra text around JSON, try to extract the first JSON object
242
  start = raw.find("{")
243
  end = raw.rfind("}")
244
  if start != -1 and end != -1 and end > start:
@@ -249,12 +261,10 @@ def chat_interface(question, history):
249
  else:
250
  payload = {"question": question, "answer": raw}
251
 
252
- # Enforce schema
253
- answer_out = payload.get("answer", raw)
254
- return answer_out
255
 
256
  except Exception as e:
257
- # Fallback: return a friendly error + no crash
258
  return (
259
  "Lo siento, tuve un problema procesando tu pregunta. "
260
  "Intenta de nuevo en un momento o formula la consulta de otra manera.\n\n"
@@ -267,12 +277,10 @@ chatbot_gradio_app = gr.ChatInterface(
267
  title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
268
  )
269
 
270
- """
271
- =========================================================
272
- 6) BANNER / WELCOME TAB
273
- =========================================================
274
- """
275
 
 
 
 
276
  banner_tab_content = """
277
  <div style="background-color: #d3e3c3; text-align: center; padding: 20px; display: flex; flex-direction: column; align-items: center;">
278
  <img src="https://huggingface.co/spaces/ALVHB95/TFM_DataScience_APP/resolve/main/front_4.jpg" alt="Banner Image" style="width: 50%; max-width: 500px; margin: 0 auto;">
@@ -297,21 +305,18 @@ banner_tab_content = """
297
 
298
  banner_tab = gr.Markdown(banner_tab_content)
299
 
300
- """
301
- =========================================================
302
- 7) GRADIO FINAL APP: TABS
303
- =========================================================
304
- """
305
 
 
 
 
306
  app = gr.TabbedInterface(
307
  [banner_tab, image_gradio_app, chatbot_gradio_app],
308
  tab_names=["Welcome to Green Greta", "Green Greta Image Classification", "Green Greta Chat"],
309
  theme=theme,
310
  )
311
 
312
- # Enable queue() for concurrency and launch the Gradio app
313
  app.queue()
314
- # Tip: set GRADIO_SHARE=true in env if you want a public link
315
  app.launch(
316
  server_name="0.0.0.0",
317
  server_port=7860,
 
1
  """
2
  =========================================================
3
+ app.py — Green Greta (Gradio + HF + LangChain v0.2 + Keras 3)
4
+ - Keras 3: load SavedModel via keras.layers.TFSMLayer (not load_model)
5
+ - LLM: HuggingFaceEndpoint with router-friendly Zephyr by default
6
+ - LangChain v0.2 import layout (core/community/text-splitters)
7
+ - Robust JSON parsing for schema-shaped output
8
+ - EfficientNet input size fix (224x224)
9
+ - Gradio binds to 0.0.0.0:7860 (Docker-friendly)
 
10
  =========================================================
11
  """
12
 
13
+ # =========================
14
+ # Imports (grouped together)
15
+ # =========================
16
  import os
17
  import json
18
  import shutil
19
 
20
+ # UI / web
21
  import gradio as gr
22
+
23
+ # TensorFlow / Keras / image
24
  import tensorflow as tf
25
+ from tensorflow import keras
26
  from PIL import Image
27
 
28
+ # Networking / retry
29
+ import tenacity
30
  from fake_useragent import UserAgent
31
 
32
+ # LangChain v0.2 family
33
  from langchain_text_splitters import RecursiveCharacterTextSplitter
34
  from langchain_core.prompts import ChatPromptTemplate
35
  from langchain_core.output_parsers import PydanticOutputParser
 
40
  from langchain.chains import ConversationalRetrievalChain
41
  from langchain.memory import ConversationBufferMemory
42
 
43
+ # Pydantic (for typed schema in prompt)
44
  from pydantic.v1 import BaseModel, Field
45
 
46
+ # Hugging Face Hub helpers
47
+ from huggingface_hub import snapshot_download
48
+
49
+ # Local theming + URLs list
50
  import theme
51
+ from url_list import URLS
52
+
53
+
54
+ # =========================
55
+ # Theme instance
56
+ # =========================
57
  theme = theme.Theme()
58
 
 
 
59
 
60
# =========================================================
# 1) IMAGE CLASSIFICATION MODEL SETUP (Keras 3-compatible)
# =========================================================
# The HF repo ships a TensorFlow SavedModel; under Keras 3 a SavedModel can
# no longer be opened with load_model, so it is wrapped in a TFSMLayer.
MODEL_REPO = "rocioadlc/efficientnetB0_trash"
MODEL_SERVING_SIGNATURE = "serving_default"  # adjust if your repo uses another signature

# Fetch the SavedModel files locally, then expose them as a callable Keras layer.
model_dir = snapshot_download(MODEL_REPO)
model1 = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)

# Output classes, in the order the classifier emits its probabilities.
class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
75
 
76
 
77
def predict_image(input_image: Image.Image):
    """
    Classify a user-uploaded image with the trash SavedModel.

    The image is converted to RGB, resized to the 224x224 input that
    EfficientNetB0 expects, preprocessed, and batched before being fed to the
    TFSMLayer-wrapped model. Returns a mapping of class label -> probability.
    """
    resized = input_image.convert("RGB").resize((224, 224))
    batch = tf.expand_dims(
        tf.keras.applications.efficientnet.preprocess_input(
            tf.keras.preprocessing.image.img_to_array(resized)
        ),
        0,
    )  # shape [1, 224, 224, 3]

    raw_out = model1(batch)
    # A SavedModel wrapped in TFSMLayer may return a dict of named output
    # tensors; take the first entry in that case, otherwise use it as-is.
    if isinstance(raw_out, dict) and raw_out:
        raw_out = raw_out[next(iter(raw_out))]

    scores = raw_out.numpy() if hasattr(raw_out, "numpy") else raw_out
    first_row = scores[0].tolist()
    return {label: float(first_row[i]) for i, label in enumerate(class_labels)}
98
 
99
 
 
105
  theme=theme,
106
  )
107
 
108
+
109
+ # ============================================
110
+ # 2) KNOWLEDGE LOADING (RAG: loader + splitter)
111
+ # ============================================
 
 
112
  user_agent = UserAgent().random
113
  header_template = {"User-Agent": user_agent}
114
 
 
126
  docs = load_url(link)
127
  all_docs.extend(docs)
128
  except Exception as e:
 
129
  print(f"Skipping URL due to error: {link}\nError: {e}\n")
130
  return all_docs
131
 
 
140
 
141
  docs = text_splitter.split_documents(all_loaded_docs)
142
 
143
+ # Embeddings
144
  embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
145
 
146
+ # Vector store (Chroma)
147
  persist_directory = "docs/chroma/"
148
  shutil.rmtree(persist_directory, ignore_errors=True)
149
 
 
155
 
156
  retriever = vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr")
157
 
 
 
 
 
 
158
 
159
+ # ======================================
160
+ # 3) PROMPT & SCHEMA OUTPUT PARSING
161
+ # ======================================
162
  class FinalAnswer(BaseModel):
163
  question: str = Field(description="User question")
164
  answer: str = Field(description="Direct answer")
 
186
  partial_variables={"format_instructions": parser.get_format_instructions()},
187
  )
188
 
 
 
 
 
 
 
 
 
189
 
190
+ # =============================
191
+ # 4) LLM (router-friendly HF)
192
+ # =============================
193
  DEFAULT_REPO = os.environ.get("HF_REPO_ID", "HuggingFaceH4/zephyr-7b-beta")
194
 
195
  llm = HuggingFaceEndpoint(
 
200
  top_k=50,
201
  repetition_penalty=1.05,
202
  do_sample=True,
203
+ # Requires env: HUGGINGFACEHUB_API_TOKEN=hf_xxx
204
  )
205
 
206
+ # If you deploy a paid Inference Endpoint (e.g., for Mixtral), use:
207
+ # MIXTRAL_ENDPOINT_URL = os.environ.get("HF_ENDPOINT_URL")
208
  # if MIXTRAL_ENDPOINT_URL:
209
  # llm = HuggingFaceEndpoint(
210
  # endpoint_url=MIXTRAL_ENDPOINT_URL,
 
216
  # do_sample=True,
217
  # )
218
 
 
 
 
 
 
219
 
220
+ # ===========================================
221
+ # 5) Chain (memory + robust JSON extraction)
222
+ # ===========================================
223
  memory = ConversationBufferMemory(
224
  memory_key="chat_history",
225
  return_messages=True,
 
231
  memory=memory,
232
  verbose=True,
233
  combine_docs_chain_kwargs={"prompt": qa_prompt},
234
+ get_chat_history=lambda h: h,
235
  rephrase_question=False,
236
  output_key="output",
237
  )
 
239
 
240
  def chat_interface(question, history):
241
  """
242
+ Run the QA chain and return the 'answer' field from a JSON payload.
243
+ Falls back safely if the LLM returns non-JSON text.
244
  """
245
  try:
246
  result = qa_chain.invoke({"question": question})
247
  raw = result.get("output", "").strip()
248
 
249
+ # Strict JSON first
250
  try:
251
  payload = json.loads(raw)
252
  except json.JSONDecodeError:
253
+ # Try extracting first {...} block
254
  start = raw.find("{")
255
  end = raw.rfind("}")
256
  if start != -1 and end != -1 and end > start:
 
261
  else:
262
  payload = {"question": question, "answer": raw}
263
 
264
+ # Return the schema field
265
+ return payload.get("answer", raw)
 
266
 
267
  except Exception as e:
 
268
  return (
269
  "Lo siento, tuve un problema procesando tu pregunta. "
270
  "Intenta de nuevo en un momento o formula la consulta de otra manera.\n\n"
 
277
  title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
278
  )
279
 
 
 
 
 
 
280
 
281
+ # ============================
282
+ # 6) Banner / Welcome content
283
+ # ============================
284
  banner_tab_content = """
285
  <div style="background-color: #d3e3c3; text-align: center; padding: 20px; display: flex; flex-direction: column; align-items: center;">
286
  <img src="https://huggingface.co/spaces/ALVHB95/TFM_DataScience_APP/resolve/main/front_4.jpg" alt="Banner Image" style="width: 50%; max-width: 500px; margin: 0 auto;">
 
305
 
306
  banner_tab = gr.Markdown(banner_tab_content)
307
 
 
 
 
 
 
308
 
309
+ # ============================
310
+ # 7) Gradio app (tabs + run)
311
+ # ============================
312
  app = gr.TabbedInterface(
313
  [banner_tab, image_gradio_app, chatbot_gradio_app],
314
  tab_names=["Welcome to Green Greta", "Green Greta Image Classification", "Green Greta Chat"],
315
  theme=theme,
316
  )
317
 
318
+ # Concurrency queue + launch (Docker-friendly binding)
319
  app.queue()
 
320
  app.launch(
321
  server_name="0.0.0.0",
322
  server_port=7860,