ALVHB95 commited on
Commit
b87c428
·
1 Parent(s): 2b15ba2
Files changed (1) hide show
  1. app.py +68 -101
app.py CHANGED
@@ -1,36 +1,24 @@
1
  """
2
  =========================================================
3
  app.py — Green Greta (Gradio + TF/Keras 3 + LangChain v0.2)
4
- - Image model: load TF SavedModel via keras.layers.TFSMLayer (Keras 3 safe)
5
- - LLM: local transformers pipeline (no HF API token required)
6
- - LangChain v0.2 imports (text_splitters/core/community)
7
- - Robust JSON parsing for schema-shaped output
8
- - EfficientNet input size fix (224x224)
9
- - Gradio binds to 0.0.0.0:7860 (Docker-friendly)
10
  =========================================================
11
  """
12
 
13
- # =========================
14
- # Imports (grouped together)
15
- # =========================
16
  import os
17
  import json
18
  import shutil
19
 
20
- # UI / web
21
  import gradio as gr
22
-
23
- # TensorFlow / Keras / image
24
  import tensorflow as tf
25
  from tensorflow import keras
26
  from PIL import Image
27
 
28
- # Networking / retry
29
  import tenacity
30
  from fake_useragent import UserAgent
31
 
32
  # LangChain v0.2 family
33
- from langchain_text_splitters import RecursiveCharacterTextSplitter
34
  from langchain_core.prompts import ChatPromptTemplate
35
  from langchain_core.output_parsers import PydanticOutputParser
36
  from langchain_community.document_loaders import WebBaseLoader
@@ -38,67 +26,64 @@ from langchain_community.vectorstores import Chroma
38
  from langchain.chains import ConversationalRetrievalChain
39
  from langchain.memory import ConversationBufferMemory
40
 
41
- # Embeddings (community version works; you can switch to langchain-huggingface later)
42
- from langchain_community.embeddings import HuggingFaceEmbeddings
43
- # If you prefer to silence deprecation warnings in the future:
44
- # from langchain_huggingface import HuggingFaceEmbeddings # pip install -U langchain-huggingface
 
 
 
 
 
45
 
46
- # Pydantic for schema in prompt
47
  from pydantic.v1 import BaseModel, Field
48
 
49
- # Hugging Face Hub helper for SavedModel
50
  from huggingface_hub import snapshot_download
51
 
52
- # Local theming + URLs list
 
 
 
 
53
  import theme
54
  from url_list import URLS
55
 
56
 
57
- # =========================
58
- # Theme instance
59
- # =========================
60
  theme = theme.Theme()
61
 
62
 
63
  # =========================================================
64
- # 1) IMAGE CLASSIFICATION MODEL SETUP (Keras 3-compatible)
65
  # =========================================================
66
- # The HF repo is a TensorFlow SavedModel; with Keras 3 we must use TFSMLayer.
67
  MODEL_REPO = "rocioadlc/efficientnetB0_trash"
68
- MODEL_SERVING_SIGNATURE = "serving_default" # adjust if your repo uses another signature
69
 
70
- # Download SavedModel locally
71
  model_dir = snapshot_download(MODEL_REPO)
72
-
73
- # Wrap SavedModel as a Keras layer
74
  model1 = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
75
 
76
- # Class labels
77
  class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
78
 
79
-
80
  def predict_image(input_image: Image.Image):
81
- """
82
- Resize the user-uploaded image and preprocess it for EfficientNetB0.
83
- Works with a TFSMLayer (SavedModel) that may return a dict of tensors.
84
- """
85
- img = input_image.convert("RGB").resize((224, 224)) # EfficientNetB0 expects 224x224
86
  x = tf.keras.preprocessing.image.img_to_array(img)
87
  x = tf.keras.applications.efficientnet.preprocess_input(x)
88
- x = tf.expand_dims(x, 0) # [1, 224, 224, 3]
89
 
90
  outputs = model1(x)
91
  if isinstance(outputs, dict) and outputs:
92
- key = next(iter(outputs))
93
- preds = outputs[key]
94
  else:
95
  preds = outputs
96
 
97
- preds_np = preds.numpy() if hasattr(preds, "numpy") else preds
98
- probs = preds_np[0].tolist()
99
  return {label: float(probs[i]) for i, label in enumerate(class_labels)}
100
 
101
-
102
  image_gradio_app = gr.Interface(
103
  fn=predict_image,
104
  inputs=gr.Image(label="Image", sources=["upload", "webcam"], type="pil"),
@@ -114,13 +99,11 @@ image_gradio_app = gr.Interface(
114
  user_agent = UserAgent().random
115
  header_template = {"User-Agent": user_agent}
116
 
117
-
118
  @tenacity.retry(wait=tenacity.wait_fixed(3), stop=tenacity.stop_after_attempt(3), reraise=True)
119
  def load_url(url: str):
120
  loader = WebBaseLoader(web_paths=[url], header_template=header_template)
121
  return loader.load()
122
 
123
-
124
  def safe_load_all_urls(urls):
125
  all_docs = []
126
  for link in urls:
@@ -131,21 +114,20 @@ def safe_load_all_urls(urls):
131
  print(f"Skipping URL due to error: {link}\nError: {e}\n")
132
  return all_docs
133
 
134
-
135
  all_loaded_docs = safe_load_all_urls(URLS)
136
 
 
137
  text_splitter = RecursiveCharacterTextSplitter(
138
- chunk_size=1024,
139
- chunk_overlap=150,
140
  length_function=len,
141
  )
142
-
143
  docs = text_splitter.split_documents(all_loaded_docs)
144
 
145
  # Embeddings
146
  embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
147
 
148
- # Vector store (Chroma)
149
  persist_directory = "docs/chroma/"
150
  shutil.rmtree(persist_directory, ignore_errors=True)
151
 
@@ -155,7 +137,17 @@ vectordb = Chroma.from_documents(
155
  persist_directory=persist_directory,
156
  )
157
 
158
- retriever = vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr")
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
  # ======================================
@@ -165,23 +157,15 @@ class FinalAnswer(BaseModel):
165
  question: str = Field(description="User question")
166
  answer: str = Field(description="Direct answer")
167
 
168
-
169
  parser = PydanticOutputParser(pydantic_object=FinalAnswer)
170
 
171
  SYSTEM_TEMPLATE = (
172
- """
173
- Your name is Greta and you are a recycling chatbot with the objective to answer questions from user in English or Spanish /
174
- Has sido diseñado y creado por el Grupo 1 del Máster en Data Science & Big Data de la promoción 2023/2024 de la Universidad Complutense de Madrid. Este grupo está formado por Rocío, María Guillermo, Alejandra, Paloma y Álvaro /
175
- Use the following pieces of context to answer the question /
176
- If the question is English answer in English /
177
- If the question is Spanish answer in Spanish /
178
- Do not mention the word context when you answer a question /
179
- Answer the question fully and provide as much relevant detail as possible. Do not cut your response short /
180
- Context: {context}
181
- User: {question}
182
- {format_instructions}
183
- """
184
- ).strip()
185
 
186
  qa_prompt = ChatPromptTemplate.from_template(
187
  SYSTEM_TEMPLATE,
@@ -192,12 +176,7 @@ qa_prompt = ChatPromptTemplate.from_template(
192
  # =============================
193
  # 4) LLM (token-free local model)
194
  # =============================
195
- # Avoids HF Endpoint auth + deprecated .post path. Good defaults for CPU.
196
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
197
- from langchain_community.llms import HuggingFacePipeline
198
-
199
  LOCAL_MODEL_ID = os.environ.get("LOCAL_LLM", "google/flan-t5-base")
200
-
201
  tok = AutoTokenizer.from_pretrained(LOCAL_MODEL_ID)
202
  mdl = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_ID)
203
 
@@ -206,9 +185,8 @@ gen = pipeline(
206
  model=mdl,
207
  tokenizer=tok,
208
  max_new_tokens=512,
209
- do_sample=False,
210
  )
211
-
212
  llm = HuggingFacePipeline(pipeline=gen)
213
 
214
 
@@ -222,7 +200,7 @@ memory = ConversationBufferMemory(
222
 
223
  qa_chain = ConversationalRetrievalChain.from_llm(
224
  llm=llm,
225
- retriever=retriever,
226
  memory=memory,
227
  verbose=True,
228
  combine_docs_chain_kwargs={"prompt": qa_prompt},
@@ -231,34 +209,26 @@ qa_chain = ConversationalRetrievalChain.from_llm(
231
  output_key="output",
232
  )
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  def chat_interface(question, history):
236
- """
237
- Run the QA chain and return the 'answer' field from a JSON payload.
238
- Falls back safely if the LLM returns non-JSON text.
239
- """
240
  try:
241
  result = qa_chain.invoke({"question": question})
242
- raw = result.get("output", "").strip()
243
-
244
- # Strict JSON first
245
- try:
246
- payload = json.loads(raw)
247
- except json.JSONDecodeError:
248
- # Try extracting first {...} block
249
- start = raw.find("{")
250
- end = raw.rfind("}")
251
- if start != -1 and end != -1 and end > start:
252
- try:
253
- payload = json.loads(raw[start : end + 1])
254
- except json.JSONDecodeError:
255
- payload = {"question": question, "answer": raw}
256
- else:
257
- payload = {"question": question, "answer": raw}
258
-
259
- # Return the schema field
260
- return payload.get("answer", raw)
261
-
262
  except Exception as e:
263
  return (
264
  "Lo siento, tuve un problema procesando tu pregunta. "
@@ -266,7 +236,6 @@ def chat_interface(question, history):
266
  f"Detalle técnico: {e}"
267
  )
268
 
269
-
270
  chatbot_gradio_app = gr.ChatInterface(
271
  fn=chat_interface,
272
  title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
@@ -284,7 +253,7 @@ banner_tab_content = """
284
  <p style="font-size: 16px; color: #4e6339; text-align: justify;">Nuestra plataforma combina la potencia de la inteligencia artificial con la comodidad de un chatbot para brindarte respuestas rápidas y precisas sobre qué objetos son reciclables y cómo hacerlo de la manera más eficiente.</p>
285
  <p style="font-size: 16px; text-align:center;"><strong><span style="color: #4e6339;">¿Cómo usarlo?</span></strong></p>
286
  <ul style="list-style-type: disc; text-align: justify; margin-top: 20px; padding-left: 20px;">
287
- <li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Image Classification:</span></strong> Ve a la pestaña Greta Image Classification y simplemente carga una foto del objeto que quieras reciclar, y nuestro modelo de identificará de qué se trata🕵️‍♂️ para que puedas desecharlo adecuadamente.</li>
288
  <li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Chat:</span></strong> ¿Tienes preguntas sobre reciclaje, materiales específicos o prácticas sostenibles? ¡Pregunta a nuestro chatbot en la pestaña Green Greta Chat!📝 Está aquí para responder todas tus preguntas y ayudarte a tomar decisiones más informadas sobre tu reciclaje.</li>
289
  </ul>
290
  <h1 style="font-size: 24px; color: #4e6339; margin-top: 20px;">Welcome to our image classifier and chatbot for smarter recycling!♻️</h1>
@@ -297,7 +266,6 @@ banner_tab_content = """
297
  </ul>
298
  </div>
299
  """
300
-
301
  banner_tab = gr.Markdown(banner_tab_content)
302
 
303
 
@@ -310,7 +278,6 @@ app = gr.TabbedInterface(
310
  theme=theme,
311
  )
312
 
313
- # Concurrency queue + launch (Docker-friendly binding)
314
  app.queue()
315
  app.launch(
316
  server_name="0.0.0.0",
 
1
  """
2
  =========================================================
3
  app.py — Green Greta (Gradio + TF/Keras 3 + LangChain v0.2)
 
 
 
 
 
 
4
  =========================================================
5
  """
6
 
7
+ # ========== Imports ==========
 
 
8
  import os
9
  import json
10
  import shutil
11
 
 
12
  import gradio as gr
 
 
13
  import tensorflow as tf
14
  from tensorflow import keras
15
  from PIL import Image
16
 
 
17
  import tenacity
18
  from fake_useragent import UserAgent
19
 
20
  # LangChain v0.2 family
21
+ from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter
22
  from langchain_core.prompts import ChatPromptTemplate
23
  from langchain_core.output_parsers import PydanticOutputParser
24
  from langchain_community.document_loaders import WebBaseLoader
 
26
  from langchain.chains import ConversationalRetrievalChain
27
  from langchain.memory import ConversationBufferMemory
28
 
29
+ # Embeddings (use community; switch to langchain-huggingface later if desired)
30
+ try:
31
+ from langchain_huggingface import HuggingFaceEmbeddings # pip install -U langchain-huggingface
32
+ except ImportError:
33
+ from langchain_community.embeddings import HuggingFaceEmbeddings
34
+
35
+ # Context compression (keeps inputs ≤ model limit)
36
+ from langchain.retrievers import ContextualCompressionRetriever
37
+ from langchain.retrievers.document_compressors import DocumentCompressorPipeline
38
 
39
+ # Pydantic schema
40
  from pydantic.v1 import BaseModel, Field
41
 
42
+ # HF Hub for SavedModel download
43
  from huggingface_hub import snapshot_download
44
 
45
+ # Transformers local pipeline (no token needed)
46
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
47
+ from langchain_community.llms import HuggingFacePipeline
48
+
49
+ # Local theme + URLs
50
  import theme
51
  from url_list import URLS
52
 
53
 
54
+ # ========== Theme ==========
 
 
55
  theme = theme.Theme()
56
 
57
 
58
  # =========================================================
59
+ # 1) IMAGE CLASSIFICATION (Keras 3-compatible SavedModel)
60
  # =========================================================
 
61
  MODEL_REPO = "rocioadlc/efficientnetB0_trash"
62
+ MODEL_SERVING_SIGNATURE = "serving_default" # adjust if your model uses a different signature
63
 
64
+ # Download the SavedModel once and wrap with Keras TFSMLayer
65
  model_dir = snapshot_download(MODEL_REPO)
 
 
66
  model1 = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
67
 
 
68
  class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
69
 
 
70
  def predict_image(input_image: Image.Image):
71
+ """Preprocess to 224x224 EfficientNet input and run inference."""
72
+ img = input_image.convert("RGB").resize((224, 224))
 
 
 
73
  x = tf.keras.preprocessing.image.img_to_array(img)
74
  x = tf.keras.applications.efficientnet.preprocess_input(x)
75
+ x = tf.expand_dims(x, 0) # batch
76
 
77
  outputs = model1(x)
78
  if isinstance(outputs, dict) and outputs:
79
+ preds = outputs[next(iter(outputs))]
 
80
  else:
81
  preds = outputs
82
 
83
+ arr = preds.numpy() if hasattr(preds, "numpy") else preds
84
+ probs = arr[0].tolist()
85
  return {label: float(probs[i]) for i, label in enumerate(class_labels)}
86
 
 
87
  image_gradio_app = gr.Interface(
88
  fn=predict_image,
89
  inputs=gr.Image(label="Image", sources=["upload", "webcam"], type="pil"),
 
99
  user_agent = UserAgent().random
100
  header_template = {"User-Agent": user_agent}
101
 
 
102
  @tenacity.retry(wait=tenacity.wait_fixed(3), stop=tenacity.stop_after_attempt(3), reraise=True)
103
  def load_url(url: str):
104
  loader = WebBaseLoader(web_paths=[url], header_template=header_template)
105
  return loader.load()
106
 
 
107
  def safe_load_all_urls(urls):
108
  all_docs = []
109
  for link in urls:
 
114
  print(f"Skipping URL due to error: {link}\nError: {e}\n")
115
  return all_docs
116
 
 
117
  all_loaded_docs = safe_load_all_urls(URLS)
118
 
119
+ # Smaller base chunks to help keep prompts short
120
  text_splitter = RecursiveCharacterTextSplitter(
121
+ chunk_size=700,
122
+ chunk_overlap=80,
123
  length_function=len,
124
  )
 
125
  docs = text_splitter.split_documents(all_loaded_docs)
126
 
127
  # Embeddings
128
  embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
129
 
130
+ # Vector store
131
  persist_directory = "docs/chroma/"
132
  shutil.rmtree(persist_directory, ignore_errors=True)
133
 
 
137
  persist_directory=persist_directory,
138
  )
139
 
140
+ # Base retriever
141
+ base_retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
142
+
143
+ # Hard-cap tokens in retrieved docs (~200 tokens per slice)
144
+ token_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=30)
145
+ compressor = DocumentCompressorPipeline(transformers=[token_splitter])
146
+
147
+ compression_retriever = ContextualCompressionRetriever(
148
+ base_retriever=base_retriever,
149
+ base_compressor=compressor,
150
+ )
151
 
152
 
153
  # ======================================
 
157
  question: str = Field(description="User question")
158
  answer: str = Field(description="Direct answer")
159
 
 
160
  parser = PydanticOutputParser(pydantic_object=FinalAnswer)
161
 
162
  SYSTEM_TEMPLATE = (
163
+ "You are Greta, a bilingual (EN/ES) recycling assistant. "
164
+ "Answer fully using the snippets below. Do not mention 'context'.\n\n"
165
+ "Context:\n{context}\n\n"
166
+ "User: {question}\n"
167
+ "{format_instructions}"
168
+ )
 
 
 
 
 
 
 
169
 
170
  qa_prompt = ChatPromptTemplate.from_template(
171
  SYSTEM_TEMPLATE,
 
176
  # =============================
177
  # 4) LLM (token-free local model)
178
  # =============================
 
 
 
 
179
  LOCAL_MODEL_ID = os.environ.get("LOCAL_LLM", "google/flan-t5-base")
 
180
  tok = AutoTokenizer.from_pretrained(LOCAL_MODEL_ID)
181
  mdl = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_ID)
182
 
 
185
  model=mdl,
186
  tokenizer=tok,
187
  max_new_tokens=512,
188
+ do_sample=False, # deterministic; better for JSON adherence
189
  )
 
190
  llm = HuggingFacePipeline(pipeline=gen)
191
 
192
 
 
200
 
201
  qa_chain = ConversationalRetrievalChain.from_llm(
202
  llm=llm,
203
+ retriever=compression_retriever, # <= compressed retriever to avoid 512-token overflows
204
  memory=memory,
205
  verbose=True,
206
  combine_docs_chain_kwargs={"prompt": qa_prompt},
 
209
  output_key="output",
210
  )
211
 
212
+ def _safe_json_extract(raw: str, question: str) -> dict:
213
+ """Try strict JSON; otherwise extract first {...} block; fallback to plain text."""
214
+ raw = (raw or "").strip()
215
+ try:
216
+ return json.loads(raw)
217
+ except json.JSONDecodeError:
218
+ start = raw.find("{")
219
+ end = raw.rfind("}")
220
+ if start != -1 and end != -1 and end > start:
221
+ try:
222
+ return json.loads(raw[start : end + 1])
223
+ except json.JSONDecodeError:
224
+ pass
225
+ return {"question": question, "answer": raw or "No answer produced."}
226
 
227
  def chat_interface(question, history):
 
 
 
 
228
  try:
229
  result = qa_chain.invoke({"question": question})
230
+ payload = _safe_json_extract(result.get("output", ""), question)
231
+ return payload.get("answer", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  except Exception as e:
233
  return (
234
  "Lo siento, tuve un problema procesando tu pregunta. "
 
236
  f"Detalle técnico: {e}"
237
  )
238
 
 
239
  chatbot_gradio_app = gr.ChatInterface(
240
  fn=chat_interface,
241
  title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
 
253
  <p style="font-size: 16px; color: #4e6339; text-align: justify;">Nuestra plataforma combina la potencia de la inteligencia artificial con la comodidad de un chatbot para brindarte respuestas rápidas y precisas sobre qué objetos son reciclables y cómo hacerlo de la manera más eficiente.</p>
254
  <p style="font-size: 16px; text-align:center;"><strong><span style="color: #4e6339;">¿Cómo usarlo?</span></strong></p>
255
  <ul style="list-style-type: disc; text-align: justify; margin-top: 20px; padding-left: 20px;">
256
+ <li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Image Classification:</span></strong> Ve a la pestaña Greta Image Classification y simplemente carga una foto del objeto que quieras reciclar, y nuestro modelo identificará de qué se trata🕵️‍♂️ para que puedas desecharlo adecuadamente.</li>
257
  <li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Chat:</span></strong> ¿Tienes preguntas sobre reciclaje, materiales específicos o prácticas sostenibles? ¡Pregunta a nuestro chatbot en la pestaña Green Greta Chat!📝 Está aquí para responder todas tus preguntas y ayudarte a tomar decisiones más informadas sobre tu reciclaje.</li>
258
  </ul>
259
  <h1 style="font-size: 24px; color: #4e6339; margin-top: 20px;">Welcome to our image classifier and chatbot for smarter recycling!♻️</h1>
 
266
  </ul>
267
  </div>
268
  """
 
269
  banner_tab = gr.Markdown(banner_tab_content)
270
 
271
 
 
278
  theme=theme,
279
  )
280
 
 
281
  app.queue()
282
  app.launch(
283
  server_name="0.0.0.0",