ALVHB95 committed
Commit 03511a5 · 1 Parent(s): b87c428
Files changed (2):
  1. app.py +38 -39
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,10 +1,9 @@
 """
 =========================================================
-app.py — Green Greta (Gradio + TF/Keras 3 + LangChain v0.2)
+app.py — Green Greta (Gradio + TF/Keras 3 + Local HF + LangChain v0.2)
 =========================================================
 """
 
-# ========== Imports ==========
 import os
 import json
 import shutil
@@ -17,16 +16,16 @@ from PIL import Image
 import tenacity
 from fake_useragent import UserAgent
 
-# LangChain v0.2 family
-from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter
+# --- LangChain v0.2 family ---
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import PydanticOutputParser
-from langchain_community.document_loaders import WebBaseLoader
-from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
+from langchain_community.document_loaders import WebBaseLoader
+from langchain_community.vectorstores import Chroma
 
-# Embeddings (use community; switch to langchain-huggingface later if desired)
+# Embeddings (prefer langchain-huggingface if installed; fallback to community)
 try:
     from langchain_huggingface import HuggingFaceEmbeddings  # pip install -U langchain-huggingface
 except ImportError:
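
Note: the hunk elides the except branch; presumably it falls back to the community export of the same class. A minimal sketch of the pattern, not the commit's exact code:

    try:
        from langchain_huggingface import HuggingFaceEmbeddings  # newer dedicated package
    except ImportError:
        # assumed fallback: the older community export of the same class
        from langchain_community.embeddings import HuggingFaceEmbeddings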
@@ -36,45 +35,42 @@ except ImportError:
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import DocumentCompressorPipeline
 
-# Pydantic schema
 from pydantic.v1 import BaseModel, Field
 
-# HF Hub for SavedModel download
+# HF Hub for downloading the SavedModel once
 from huggingface_hub import snapshot_download
 
-# Transformers local pipeline (no token needed)
+# Local transformers pipeline (no API token required)
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from langchain_community.llms import HuggingFacePipeline
 
-# Local theme + URLs
+# Theming + URL list
 import theme
 from url_list import URLS
 
-
-# ========== Theme ==========
 theme = theme.Theme()
 
 
 # =========================================================
-# 1) IMAGE CLASSIFICATION (Keras 3-compatible SavedModel)
+# 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
 # =========================================================
 MODEL_REPO = "rocioadlc/efficientnetB0_trash"
-MODEL_SERVING_SIGNATURE = "serving_default"  # adjust if your model uses a different signature
+MODEL_SERVING_SIGNATURE = "serving_default"  # adjust if the model exposes a different endpoint
 
-# Download the SavedModel once and wrap with Keras TFSMLayer
+# Download the model snapshot and wrap it via TFSMLayer (Keras 3 compatible)
 model_dir = snapshot_download(MODEL_REPO)
-model1 = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
+image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
 
 class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
 
 def predict_image(input_image: Image.Image):
-    """Preprocess to 224x224 EfficientNet input and run inference."""
+    """Preprocess to EfficientNetB0 input (224x224) and run inference."""
     img = input_image.convert("RGB").resize((224, 224))
     x = tf.keras.preprocessing.image.img_to_array(img)
     x = tf.keras.applications.efficientnet.preprocess_input(x)
-    x = tf.expand_dims(x, 0)  # batch
+    x = tf.expand_dims(x, 0)
 
-    outputs = model1(x)
+    outputs = image_model(x)
     if isinstance(outputs, dict) and outputs:
         preds = outputs[next(iter(outputs))]
     else:
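
Note: TFSMLayer is the Keras 3 route for running a legacy TF SavedModel as an inference-only layer. A self-contained sketch of the pattern, with a hypothetical model path; the dict handling mirrors the hunk above:

    import keras
    import numpy as np

    # "path/to/saved_model" is illustrative; app.py gets a real directory from snapshot_download()
    layer = keras.layers.TFSMLayer("path/to/saved_model", call_endpoint="serving_default")

    x = np.zeros((1, 224, 224, 3), dtype="float32")  # one preprocessed RGB image
    outputs = layer(x)
    # SavedModel signatures commonly return a dict of named tensors; take the first entry
    preds = outputs[next(iter(outputs))] if isinstance(outputs, dict) else outputs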
@@ -116,13 +112,13 @@ def safe_load_all_urls(urls):
 
 all_loaded_docs = safe_load_all_urls(URLS)
 
-# Smaller base chunks to help keep prompts short
-text_splitter = RecursiveCharacterTextSplitter(
+# Smaller base chunks so downstream compression has less work
+base_splitter = RecursiveCharacterTextSplitter(
     chunk_size=700,
     chunk_overlap=80,
     length_function=len,
 )
-docs = text_splitter.split_documents(all_loaded_docs)
+docs = base_splitter.split_documents(all_loaded_docs)
 
 # Embeddings
 embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
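
Note: with length_function=len the cap is 700 characters, not tokens; a toy illustration:

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=80, length_function=len)
    chunks = splitter.split_text("recycling guidance " * 200)  # any long string
    assert all(len(c) <= 700 for c in chunks)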
@@ -138,20 +134,27 @@ vectordb = Chroma.from_documents(
 )
 
 # Base retriever
-base_retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
-
-# Hard-cap tokens in retrieved docs (~200 tokens per slice)
-token_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=30)
-compressor = DocumentCompressorPipeline(transformers=[token_splitter])
+retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
 
+# --- Context compression to keep inputs under FLAN-T5's 512-token limit ---
+# Prefer a token-aware splitter; fall back to a char splitter if `tiktoken` isn't installed.
+try:
+    from langchain_text_splitters import TokenTextSplitter
+    splitter_for_compression = TokenTextSplitter(chunk_size=200, chunk_overlap=30)  # needs `tiktoken`
+except Exception:
+    # Fallback that doesn't require tiktoken
+    from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
+    splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
+
+compressor = DocumentCompressorPipeline(transformers=[splitter_for_compression])
 compression_retriever = ContextualCompressionRetriever(
-    base_retriever=base_retriever,
+    base_retriever=retriever,
     base_compressor=compressor,
 )
 
 
 # ======================================
-# 3) PROMPT & SCHEMA OUTPUT PARSING
+# 3) PROMPT & Pydantic schema parsing
 # ======================================
 class FinalAnswer(BaseModel):
     question: str = Field(description="User question")
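
Note: a hypothetical query against the compressed retriever built above; every returned document is a short slice, so the k=2 stuffed chunks stay well under the LLM's input limit:

    docs = compression_retriever.invoke("How should glass bottles be recycled?")
    for d in docs:
        print(len(d.page_content), d.metadata.get("source"))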
@@ -174,7 +177,7 @@ qa_prompt = ChatPromptTemplate.from_template(
 
 
 # =============================
-# 4) LLM (token-free local model)
+# 4) LLM — local, token-free
 # =============================
 LOCAL_MODEL_ID = os.environ.get("LOCAL_LLM", "google/flan-t5-base")
 tok = AutoTokenizer.from_pretrained(LOCAL_MODEL_ID)
@@ -185,7 +188,7 @@ gen = pipeline(
     model=mdl,
     tokenizer=tok,
     max_new_tokens=512,
-    do_sample=False,  # deterministic; better for JSON adherence
+    do_sample=False,  # deterministic; helps JSON adherence
 )
 llm = HuggingFacePipeline(pipeline=gen)
 
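
Note: wrapped as a LangChain LLM, the local pipeline can be smoke-tested in one line (illustrative prompt):

    print(llm.invoke("Answer briefly: why should batteries never go in household trash?"))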
@@ -200,7 +203,7 @@ memory = ConversationBufferMemory(
 
 qa_chain = ConversationalRetrievalChain.from_llm(
     llm=llm,
-    retriever=compression_retriever,  # <= compressed retriever to avoid 512-token overflows
+    retriever=compression_retriever,
     memory=memory,
     verbose=True,
     combine_docs_chain_kwargs={"prompt": qa_prompt},
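
Note: a hypothetical invocation of the chain as wired above; chat history comes from the attached memory, and "answer" is the chain's default output key:

    result = qa_chain.invoke({"question": "Which bin does greasy cardboard go in?"})
    print(result["answer"])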
@@ -210,7 +213,7 @@ qa_chain = ConversationalRetrievalChain.from_llm(
 )
 
 def _safe_json_extract(raw: str, question: str) -> dict:
-    """Try strict JSON; otherwise extract first {...} block; fallback to plain text."""
+    """Try strict JSON; otherwise extract first {...}; fallback to plain text."""
     raw = (raw or "").strip()
     try:
         return json.loads(raw)
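
Note: the hunk truncates the helper; a sketch of how such a fallback typically continues (the regex and the fallback dict shape are assumptions, not the commit's code):

    import json
    import re

    def _safe_json_extract(raw: str, question: str) -> dict:
        raw = (raw or "").strip()
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            match = re.search(r"\{.*\}", raw, re.DOTALL)  # first {...} block
            if match:
                try:
                    return json.loads(match.group(0))
                except json.JSONDecodeError:
                    pass
        return {"question": question, "answer": raw}  # plain-text fallback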
@@ -279,8 +282,4 @@ app = gr.TabbedInterface(
 )
 
 app.queue()
-app.launch(
-    server_name="0.0.0.0",
-    server_port=7860,
-    share=os.environ.get("GRADIO_SHARE", "false").lower() == "true",
-)
+app.launch()
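
Note: a bare launch() is enough on Spaces; outside Spaces the removed flags stay reachable through Gradio's standard environment variables rather than code:

    import os
    # launch() reads these when no explicit arguments are passed
    os.environ.setdefault("GRADIO_SERVER_NAME", "0.0.0.0")  # was server_name="0.0.0.0"
    os.environ.setdefault("GRADIO_SERVER_PORT", "7860")     # was server_port=7860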
 
requirements.txt CHANGED
@@ -28,3 +28,5 @@ fastapi==0.115.0
 starlette==0.38.2
 pydantic==2.8.2
 pydantic-core==2.20.1
+
+tiktoken>=0.5.2
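
Note: the new tiktoken pin is what unlocks the token-aware splitter branch in app.py; a quick post-install sanity check (quantities illustrative):

    # TokenTextSplitter imports tiktoken at construction time, so this fails fast if it's missing
    from langchain_text_splitters import TokenTextSplitter

    ts = TokenTextSplitter(chunk_size=200, chunk_overlap=30)
    print(len(ts.split_text("recycle " * 500)))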