Spaces:
Sleeping
Sleeping
new model
Browse files
- app.py +38 -39
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
"""
|
| 2 |
=========================================================
|
| 3 |
-
app.py — Green Greta (Gradio + TF/Keras 3 + LangChain v0.2)
|
| 4 |
=========================================================
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
# ========== Imports ==========
|
| 8 |
import os
|
| 9 |
import json
|
| 10 |
import shutil
|
|
@@ -17,16 +16,16 @@ from PIL import Image
|
|
| 17 |
import tenacity
|
| 18 |
from fake_useragent import UserAgent
|
| 19 |
|
| 20 |
-
# LangChain v0.2 family
|
| 21 |
-
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 22 |
from langchain_core.prompts import ChatPromptTemplate
|
| 23 |
from langchain_core.output_parsers import PydanticOutputParser
|
| 24 |
-
from langchain_community.document_loaders import WebBaseLoader
|
| 25 |
-
from langchain_community.vectorstores import Chroma
|
| 26 |
from langchain.chains import ConversationalRetrievalChain
|
| 27 |
from langchain.memory import ConversationBufferMemory
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
# Embeddings (
|
| 30 |
try:
|
| 31 |
from langchain_huggingface import HuggingFaceEmbeddings # pip install -U langchain-huggingface
|
| 32 |
except ImportError:
|
|
@@ -36,45 +35,42 @@ except ImportError:
|
|
| 36 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 37 |
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
|
| 38 |
|
| 39 |
-
# Pydantic schema
|
| 40 |
from pydantic.v1 import BaseModel, Field
|
| 41 |
|
| 42 |
-
# HF Hub for SavedModel
|
| 43 |
from huggingface_hub import snapshot_download
|
| 44 |
|
| 45 |
-
#
|
| 46 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 47 |
from langchain_community.llms import HuggingFacePipeline
|
| 48 |
|
| 49 |
-
#
|
| 50 |
import theme
|
| 51 |
from url_list import URLS
|
| 52 |
|
| 53 |
-
|
| 54 |
-
# ========== Theme ==========
|
| 55 |
theme = theme.Theme()
|
| 56 |
|
| 57 |
|
| 58 |
# =========================================================
|
| 59 |
-
# 1) IMAGE CLASSIFICATION
|
| 60 |
# =========================================================
|
| 61 |
MODEL_REPO = "rocioadlc/efficientnetB0_trash"
|
| 62 |
-
MODEL_SERVING_SIGNATURE = "serving_default" # adjust if
|
| 63 |
|
| 64 |
-
# Download the
|
| 65 |
model_dir = snapshot_download(MODEL_REPO)
|
| 66 |
-
|
| 67 |
|
| 68 |
class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
|
| 69 |
|
| 70 |
def predict_image(input_image: Image.Image):
|
| 71 |
-
"""Preprocess to
|
| 72 |
img = input_image.convert("RGB").resize((224, 224))
|
| 73 |
x = tf.keras.preprocessing.image.img_to_array(img)
|
| 74 |
x = tf.keras.applications.efficientnet.preprocess_input(x)
|
| 75 |
-
x = tf.expand_dims(x, 0)
|
| 76 |
|
| 77 |
-
outputs =
|
| 78 |
if isinstance(outputs, dict) and outputs:
|
| 79 |
preds = outputs[next(iter(outputs))]
|
| 80 |
else:
|
|
@@ -116,13 +112,13 @@ def safe_load_all_urls(urls):
|
|
| 116 |
|
| 117 |
all_loaded_docs = safe_load_all_urls(URLS)
|
| 118 |
|
| 119 |
-
# Smaller base chunks
|
| 120 |
-
|
| 121 |
chunk_size=700,
|
| 122 |
chunk_overlap=80,
|
| 123 |
length_function=len,
|
| 124 |
)
|
| 125 |
-
docs =
|
| 126 |
|
| 127 |
# Embeddings
|
| 128 |
embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
|
|
@@ -138,20 +134,27 @@ vectordb = Chroma.from_documents(
|
|
| 138 |
)
|
| 139 |
|
| 140 |
# Base retriever
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
# Hard-cap tokens in retrieved docs (~200 tokens per slice)
|
| 144 |
-
token_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=30)
|
| 145 |
-
compressor = DocumentCompressorPipeline(transformers=[token_splitter])
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
compression_retriever = ContextualCompressionRetriever(
|
| 148 |
-
base_retriever=
|
| 149 |
base_compressor=compressor,
|
| 150 |
)
|
| 151 |
|
| 152 |
|
| 153 |
# ======================================
|
| 154 |
-
# 3) PROMPT &
|
| 155 |
# ======================================
|
| 156 |
class FinalAnswer(BaseModel):
|
| 157 |
question: str = Field(description="User question")
|
|
@@ -174,7 +177,7 @@ qa_prompt = ChatPromptTemplate.from_template(
|
|
| 174 |
|
| 175 |
|
| 176 |
# =============================
|
| 177 |
-
# 4) LLM
|
| 178 |
# =============================
|
| 179 |
LOCAL_MODEL_ID = os.environ.get("LOCAL_LLM", "google/flan-t5-base")
|
| 180 |
tok = AutoTokenizer.from_pretrained(LOCAL_MODEL_ID)
|
|
@@ -185,7 +188,7 @@ gen = pipeline(
|
|
| 185 |
model=mdl,
|
| 186 |
tokenizer=tok,
|
| 187 |
max_new_tokens=512,
|
| 188 |
-
do_sample=False, # deterministic;
|
| 189 |
)
|
| 190 |
llm = HuggingFacePipeline(pipeline=gen)
|
| 191 |
|
|
@@ -200,7 +203,7 @@ memory = ConversationBufferMemory(
|
|
| 200 |
|
| 201 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 202 |
llm=llm,
|
| 203 |
-
retriever=compression_retriever,
|
| 204 |
memory=memory,
|
| 205 |
verbose=True,
|
| 206 |
combine_docs_chain_kwargs={"prompt": qa_prompt},
|
|
@@ -210,7 +213,7 @@ qa_chain = ConversationalRetrievalChain.from_llm(
|
|
| 210 |
)
|
| 211 |
|
| 212 |
def _safe_json_extract(raw: str, question: str) -> dict:
|
| 213 |
-
"""Try strict JSON; otherwise extract first {...}
|
| 214 |
raw = (raw or "").strip()
|
| 215 |
try:
|
| 216 |
return json.loads(raw)
|
|
@@ -279,8 +282,4 @@ app = gr.TabbedInterface(
|
|
| 279 |
)
|
| 280 |
|
| 281 |
app.queue()
|
| 282 |
-
app.launch(
|
| 283 |
-
server_name="0.0.0.0",
|
| 284 |
-
server_port=7860,
|
| 285 |
-
share=os.environ.get("GRADIO_SHARE", "false").lower() == "true",
|
| 286 |
-
)
|
|
|
|
| 1 |
"""
|
| 2 |
=========================================================
|
| 3 |
+
app.py — Green Greta (Gradio + TF/Keras 3 + Local HF + LangChain v0.2)
|
| 4 |
=========================================================
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 7 |
import os
|
| 8 |
import json
|
| 9 |
import shutil
|
|
|
|
| 16 |
import tenacity
|
| 17 |
from fake_useragent import UserAgent
|
| 18 |
|
| 19 |
+
# --- LangChain v0.2 family ---
|
| 20 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 21 |
from langchain_core.prompts import ChatPromptTemplate
|
| 22 |
from langchain_core.output_parsers import PydanticOutputParser
|
|
|
|
|
|
|
| 23 |
from langchain.chains import ConversationalRetrievalChain
|
| 24 |
from langchain.memory import ConversationBufferMemory
|
| 25 |
+
from langchain_community.document_loaders import WebBaseLoader
|
| 26 |
+
from langchain_community.vectorstores import Chroma
|
| 27 |
|
| 28 |
+
# Embeddings (prefer langchain-huggingface if installed; fall back to community)
|
| 29 |
try:
|
| 30 |
from langchain_huggingface import HuggingFaceEmbeddings # pip install -U langchain-huggingface
|
| 31 |
except ImportError:
|
|
|
|
| 35 |
from langchain.retrievers import ContextualCompressionRetriever
|
| 36 |
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
|
| 37 |
|
|
|
|
| 38 |
from pydantic.v1 import BaseModel, Field
|
| 39 |
|
| 40 |
+
# HF Hub for downloading the SavedModel once
|
| 41 |
from huggingface_hub import snapshot_download
|
| 42 |
|
| 43 |
+
# Local transformers pipeline (no API token required)
|
| 44 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 45 |
from langchain_community.llms import HuggingFacePipeline
|
| 46 |
|
| 47 |
+
# Theming + URL list
|
| 48 |
import theme
|
| 49 |
from url_list import URLS
|
| 50 |
|
|
|
|
|
|
|
| 51 |
theme = theme.Theme()
|
| 52 |
|
| 53 |
|
| 54 |
# =========================================================
|
| 55 |
+
# 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
|
| 56 |
# =========================================================
|
| 57 |
MODEL_REPO = "rocioadlc/efficientnetB0_trash"
|
| 58 |
+
MODEL_SERVING_SIGNATURE = "serving_default" # adjust if the model exposes a different endpoint
|
| 59 |
|
| 60 |
+
# Download the model snapshot and wrap it via TFSMLayer (Keras 3 compatible)
|
| 61 |
model_dir = snapshot_download(MODEL_REPO)
|
| 62 |
+
image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
|
| 63 |
|
| 64 |
class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
|
| 65 |
|
| 66 |
def predict_image(input_image: Image.Image):
|
| 67 |
+
"""Preprocess to EfficientNetB0 input (224x224) and run inference."""
|
| 68 |
img = input_image.convert("RGB").resize((224, 224))
|
| 69 |
x = tf.keras.preprocessing.image.img_to_array(img)
|
| 70 |
x = tf.keras.applications.efficientnet.preprocess_input(x)
|
| 71 |
+
x = tf.expand_dims(x, 0)
|
| 72 |
|
| 73 |
+
outputs = image_model(x)
|
| 74 |
if isinstance(outputs, dict) and outputs:
|
| 75 |
preds = outputs[next(iter(outputs))]
|
| 76 |
else:
|
|
|
|
| 112 |
|
| 113 |
all_loaded_docs = safe_load_all_urls(URLS)
|
| 114 |
|
| 115 |
+
# Smaller base chunks so downstream compression has less work
|
| 116 |
+
base_splitter = RecursiveCharacterTextSplitter(
|
| 117 |
chunk_size=700,
|
| 118 |
chunk_overlap=80,
|
| 119 |
length_function=len,
|
| 120 |
)
|
| 121 |
+
docs = base_splitter.split_documents(all_loaded_docs)
|
| 122 |
|
| 123 |
# Embeddings
|
| 124 |
embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
|
|
|
|
| 134 |
)
|
| 135 |
|
| 136 |
# Base retriever
|
| 137 |
+
retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
+
# --- Context compression to keep inputs under FLAN-T5 512-token limit ---
|
| 140 |
+
# Prefer token-aware splitter; fall back to char splitter if `tiktoken` isn't installed.
|
| 141 |
+
try:
|
| 142 |
+
from langchain_text_splitters import TokenTextSplitter
|
| 143 |
+
splitter_for_compression = TokenTextSplitter(chunk_size=200, chunk_overlap=30) # needs `tiktoken`
|
| 144 |
+
except Exception:
|
| 145 |
+
# Fallback that doesn't require tiktoken
|
| 146 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
|
| 147 |
+
splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
|
| 148 |
+
|
| 149 |
+
compressor = DocumentCompressorPipeline(transformers=[splitter_for_compression])
|
| 150 |
compression_retriever = ContextualCompressionRetriever(
|
| 151 |
+
base_retriever=retriever,
|
| 152 |
base_compressor=compressor,
|
| 153 |
)
|
| 154 |
|
| 155 |
|
| 156 |
# ======================================
|
| 157 |
+
# 3) PROMPT & Pydantic schema parsing
|
| 158 |
# ======================================
|
| 159 |
class FinalAnswer(BaseModel):
|
| 160 |
question: str = Field(description="User question")
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
# =============================
|
| 180 |
+
# 4) LLM — local, token-free
|
| 181 |
# =============================
|
| 182 |
LOCAL_MODEL_ID = os.environ.get("LOCAL_LLM", "google/flan-t5-base")
|
| 183 |
tok = AutoTokenizer.from_pretrained(LOCAL_MODEL_ID)
|
|
|
|
| 188 |
model=mdl,
|
| 189 |
tokenizer=tok,
|
| 190 |
max_new_tokens=512,
|
| 191 |
+
do_sample=False, # deterministic; helps JSON adherence
|
| 192 |
)
|
| 193 |
llm = HuggingFacePipeline(pipeline=gen)
|
| 194 |
|
|
|
|
| 203 |
|
| 204 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 205 |
llm=llm,
|
| 206 |
+
retriever=compression_retriever,
|
| 207 |
memory=memory,
|
| 208 |
verbose=True,
|
| 209 |
combine_docs_chain_kwargs={"prompt": qa_prompt},
|
|
|
|
| 213 |
)
|
| 214 |
|
| 215 |
def _safe_json_extract(raw: str, question: str) -> dict:
|
| 216 |
+
"""Try strict JSON; otherwise extract first {...}; fallback to plain text."""
|
| 217 |
raw = (raw or "").strip()
|
| 218 |
try:
|
| 219 |
return json.loads(raw)
|
|
|
|
| 282 |
)
|
| 283 |
|
| 284 |
app.queue()
|
| 285 |
+
app.launch()
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -28,3 +28,5 @@ fastapi==0.115.0
|
|
| 28 |
starlette==0.38.2
|
| 29 |
pydantic==2.8.2
|
| 30 |
pydantic-core==2.20.1
|
|
|
|
|
|
|
|
|
| 28 |
starlette==0.38.2
|
| 29 |
pydantic==2.8.2
|
| 30 |
pydantic-core==2.20.1
|
| 31 |
+
|
| 32 |
+
tiktoken>=0.5.2
|