ALVHB95 committed on
Commit
cfa2249
·
1 Parent(s): eb3e5ca
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -25,7 +25,7 @@ from langchain.memory import ConversationBufferMemory
25
  from langchain_community.document_loaders import WebBaseLoader
26
  from langchain_community.vectorstores import Chroma
27
 
28
- # Embeddings (prefer langchain-huggingface if installed; fallback to community)
29
  try:
30
  from langchain_huggingface import HuggingFaceEmbeddings # pip install -U langchain-huggingface
31
  except ImportError:
@@ -37,12 +37,11 @@ from langchain.retrievers.document_compressors import DocumentCompressorPipeline
37
 
38
  from pydantic.v1 import BaseModel, Field
39
 
40
- # HF Hub for downloading the SavedModel once
41
  from huggingface_hub import snapshot_download
42
 
43
- # Local transformers pipeline (no API token required)
44
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline # (still imported; not used in Hub mode)
45
- from langchain_community.llms import HuggingFacePipeline, HuggingFaceHub # <-- ADDED: HuggingFaceHub import
46
 
47
  # Theming + URL list
48
  import theme
@@ -50,21 +49,27 @@ from url_list import URLS
50
 
51
  theme = theme.Theme()
52
 
 
 
 
 
 
 
53
 
54
  # =========================================================
55
  # 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
56
  # =========================================================
57
  MODEL_REPO = "rocioadlc/efficientnetB0_trash"
58
- MODEL_SERVING_SIGNATURE = "serving_default" # adjust if the model exposes a different endpoint
59
 
60
- # Download the model snapshot and wrap it via TFSMLayer (Keras 3 compatible)
61
  model_dir = snapshot_download(MODEL_REPO)
62
  image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
63
 
64
  class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
65
 
66
  def predict_image(input_image: Image.Image):
67
- """Preprocess to EfficientNetB0 input (224x224) and run inference."""
68
  img = input_image.convert("RGB").resize((224, 224))
69
  x = tf.keras.preprocessing.image.img_to_array(img)
70
  x = tf.keras.applications.efficientnet.preprocess_input(x)
@@ -112,7 +117,7 @@ def safe_load_all_urls(urls):
112
 
113
  all_loaded_docs = safe_load_all_urls(URLS)
114
 
115
- # Smaller base chunks so downstream compression has less work
116
  base_splitter = RecursiveCharacterTextSplitter(
117
  chunk_size=700,
118
  chunk_overlap=80,
@@ -136,13 +141,11 @@ vectordb = Chroma.from_documents(
136
  # Base retriever
137
  retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
138
 
139
- # --- Context compression to keep inputs under FLAN-T5 512-token limit ---
140
- # Prefer token-aware splitter; fall back to char splitter if `tiktoken` isn't installed.
141
  try:
142
  from langchain_text_splitters import TokenTextSplitter
143
- splitter_for_compression = TokenTextSplitter(chunk_size=200, chunk_overlap=30) # needs `tiktoken`
144
  except Exception:
145
- # Fallback that doesn't require tiktoken
146
  from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
147
  splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
148
 
@@ -170,8 +173,6 @@ SYSTEM_TEMPLATE = (
170
  "{format_instructions}"
171
  )
172
 
173
- # NOTE: Your original pattern kept; if you prefer, you can also do:
174
- # ChatPromptTemplate.from_template(SYSTEM_TEMPLATE).partial(format_instructions=parser.get_format_instructions())
175
  qa_prompt = ChatPromptTemplate.from_template(
176
  SYSTEM_TEMPLATE,
177
  partial_variables={"format_instructions": parser.get_format_instructions()},
@@ -179,11 +180,10 @@ qa_prompt = ChatPromptTemplate.from_template(
179
 
180
 
181
  # =============================
182
- # 4) LLM — HuggingFace Hub (Mixtral)
183
  # =============================
184
- # REQUIREMENT: set env var HUGGINGFACEHUB_API_TOKEN
185
- # (Settings → Variables & secrets in your Space)
186
- llm = HuggingFaceHub(
187
  repo_id="mistralai/Mixtral-8x7B-v0.1",
188
  task="text-generation",
189
  model_kwargs={
@@ -191,8 +191,10 @@ llm = HuggingFaceHub(
191
  "top_k": 30,
192
  "temperature": 0.1,
193
  "repetition_penalty": 1.03,
194
- # You may also pass: "return_full_text": False
195
  },
 
 
196
  )
197
 
198
 
@@ -216,7 +218,7 @@ qa_chain = ConversationalRetrievalChain.from_llm(
216
  )
217
 
218
  def _safe_json_extract(raw: str, question: str) -> dict:
219
- """Try strict JSON; otherwise extract first {...}; fallback to plain text."""
220
  raw = (raw or "").strip()
221
  try:
222
  return json.loads(raw)
 
25
  from langchain_community.document_loaders import WebBaseLoader
26
  from langchain_community.vectorstores import Chroma
27
 
28
+ # Embeddings (prefer langchain-huggingface if installed; fallback to community)
29
  try:
30
  from langchain_huggingface import HuggingFaceEmbeddings # pip install -U langchain-huggingface
31
  except ImportError:
 
37
 
38
  from pydantic.v1 import BaseModel, Field
39
 
40
+ # HF Hub for downloading the SavedModel once (image classifier)
41
  from huggingface_hub import snapshot_download
42
 
43
+ # === LLM endpoint moderno (compatible con huggingface_hub>=0.23) ===
44
+ from langchain_huggingface import HuggingFaceEndpoint # Opción 1
 
45
 
46
  # Theming + URL list
47
  import theme
 
49
 
50
  theme = theme.Theme()
51
 
52
+ # (Opcional) reducir telemetría/ruido en logs de Space
53
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
54
+ os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
55
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
56
+ os.environ.setdefault("ANONYMIZED_TELEMETRY", "false")
57
+
58
 
59
  # =========================================================
60
  # 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
61
  # =========================================================
62
  MODEL_REPO = "rocioadlc/efficientnetB0_trash"
63
+ MODEL_SERVING_SIGNATURE = "serving_default" # ajusta si el modelo expone otra firma
64
 
65
+ # Descarga el snapshot y envuélvelo con TFSMLayer (compatible Keras 3)
66
  model_dir = snapshot_download(MODEL_REPO)
67
  image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
68
 
69
  class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
70
 
71
  def predict_image(input_image: Image.Image):
72
+ """Preprocess a EfficientNetB0 (224x224) y ejecuta inferencia."""
73
  img = input_image.convert("RGB").resize((224, 224))
74
  x = tf.keras.preprocessing.image.img_to_array(img)
75
  x = tf.keras.applications.efficientnet.preprocess_input(x)
 
117
 
118
  all_loaded_docs = safe_load_all_urls(URLS)
119
 
120
+ # Chunks base pequeños para que el compresor downstream trabaje menos
121
  base_splitter = RecursiveCharacterTextSplitter(
122
  chunk_size=700,
123
  chunk_overlap=80,
 
141
  # Base retriever
142
  retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
143
 
144
+ # --- Compresión de contexto para entradas ~512 tokens (t5/…); útil igual con Mixtral ---
 
145
  try:
146
  from langchain_text_splitters import TokenTextSplitter
147
+ splitter_for_compression = TokenTextSplitter(chunk_size=200, chunk_overlap=30) # requiere tiktoken
148
  except Exception:
 
149
  from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
150
  splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
151
 
 
173
  "{format_instructions}"
174
  )
175
 
 
 
176
  qa_prompt = ChatPromptTemplate.from_template(
177
  SYSTEM_TEMPLATE,
178
  partial_variables={"format_instructions": parser.get_format_instructions()},
 
180
 
181
 
182
  # =============================
183
+ # 4) LLM — Hugging Face Inference API (Mixtral)
184
  # =============================
185
+ # Requiere el secreto HUGGINGFACEHUB_API_TOKEN en el Space
186
+ llm = HuggingFaceEndpoint(
 
187
  repo_id="mistralai/Mixtral-8x7B-v0.1",
188
  task="text-generation",
189
  model_kwargs={
 
191
  "top_k": 30,
192
  "temperature": 0.1,
193
  "repetition_penalty": 1.03,
194
+ "return_full_text": False,
195
  },
196
+ huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
197
+ timeout=120, # opcional
198
  )
199
 
200
 
 
218
  )
219
 
220
  def _safe_json_extract(raw: str, question: str) -> dict:
221
+ """Intenta JSON estricto; si falla, extrae el primer {...}; si no, texto plano."""
222
  raw = (raw or "").strip()
223
  try:
224
  return json.loads(raw)