ALVHB95 committed on
Commit
3a7b6f0
·
1 Parent(s): 9150cfd

new model

Browse files
Files changed (3) hide show
  1. Dockerfile.txt +20 -3
  2. app.py +47 -44
  3. requirements.txt +1 -1
Dockerfile.txt CHANGED
@@ -1,11 +1,28 @@
1
  FROM python:3.10
2
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  WORKDIR /code
4
 
5
  COPY ./requirements.txt /code/requirements.txt
6
-
7
- RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
 
9
  COPY . .
10
 
11
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
 
 
 
 
1
  FROM python:3.10
2
 
3
+ # Prevent Python from writing .pyc files / enable unbuffered logs
4
+ ENV PYTHONDONTWRITEBYTECODE=1
5
+ ENV PYTHONUNBUFFERED=1
6
+
7
+ # Make Gradio listen on all interfaces and on port 7860
8
+ ENV GRADIO_SERVER_NAME=0.0.0.0
9
+ ENV GRADIO_SERVER_PORT=7860
10
+
11
+ # Optional but recommended
12
+ # ENV HUGGINGFACEHUB_API_TOKEN=hf_xxx
13
+ # ENV USER_AGENT="green-greta/1.0 (+contact-or-repo)"
14
+
15
  WORKDIR /code
16
 
17
  COPY ./requirements.txt /code/requirements.txt
18
+ RUN pip install --no-cache-dir --upgrade pip && \
19
+ pip install --no-cache-dir --upgrade -r /code/requirements.txt
20
 
21
  COPY . .
22
 
23
+ # Expose Gradio port
24
+ EXPOSE 7860
25
+
26
+ # Your code calls app.launch(...) inside app.py, so just run Python.
27
+ # (Uvicorn is for FastAPI apps, which you are not using here.)
28
+ CMD ["python", "app.py"]
app.py CHANGED
@@ -1,15 +1,13 @@
1
  """
2
  =========================================================
3
- Fixed app.py — Green Greta (Gradio + HF + LangChain)
4
  Notes:
5
- - Replaced deprecated/404-prone HuggingFaceHub call with HuggingFaceEndpoint.
6
- Option A (default below): use a readily available public model (Zephyr) via the free Inference API.
7
- Option B (commented): keep Mixtral, but then you MUST provision an Inference Endpoint (paid) and set endpoint_url.
8
- - Made JSON parsing of the schema robust; no fragile string slicing.
9
- - Fixed EfficientNet input size bug (224x224, not 244x224).
10
- - Safer memory setup for ConversationalRetrievalChain; added return_messages.
11
- - Better error handling on web loads and QA call.
12
- - Minor cleanups of duplicate imports, warnings, and defaults.
13
  =========================================================
14
  """
15
 
@@ -19,21 +17,22 @@ import shutil
19
 
20
  import gradio as gr
21
  import tensorflow as tf
22
- from tensorflow import keras
23
  from PIL import Image
24
 
25
  import tenacity # for retrying failed requests
26
  from fake_useragent import UserAgent
27
 
28
- # LangChain
29
- from langchain.text_splitter import RecursiveCharacterTextSplitter
30
- from langchain.embeddings import HuggingFaceEmbeddings
31
- from langchain.prompts import ChatPromptTemplate
32
- from langchain.output_parsers import PydanticOutputParser
33
- from langchain.chains import ConversationalRetrievalChain
34
- from langchain.memory import ConversationBufferMemory
35
  from langchain_community.document_loaders import WebBaseLoader
36
  from langchain_community.llms import HuggingFaceEndpoint
 
 
 
 
37
  from pydantic.v1 import BaseModel, Field
38
 
39
  # Theming
@@ -54,7 +53,7 @@ from huggingface_hub import from_pretrained_keras
54
  model1 = from_pretrained_keras("rocioadlc/efficientnetB0_trash")
55
 
56
  # Define class labels for the trash classification
57
- class_labels = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
58
 
59
 
60
  def predict_image(input_image: Image.Image):
@@ -71,16 +70,15 @@ def predict_image(input_image: Image.Image):
71
 
72
  predictions = model1.predict(image_array)
73
  probs = predictions[0].tolist()
74
-
75
  return {label: float(probs[i]) for i, label in enumerate(class_labels)}
76
 
77
 
78
  image_gradio_app = gr.Interface(
79
  fn=predict_image,
80
- inputs=gr.Image(label="Image", sources=['upload', 'webcam'], type="pil"),
81
  outputs=[gr.Label(label="Result")],
82
  title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
83
- theme=theme
84
  )
85
 
86
  """
@@ -106,6 +104,7 @@ def safe_load_all_urls(urls):
106
  docs = load_url(link)
107
  all_docs.extend(docs)
108
  except Exception as e:
 
109
  print(f"Skipping URL due to error: {link}\nError: {e}\n")
110
  return all_docs
111
 
@@ -121,16 +120,15 @@ text_splitter = RecursiveCharacterTextSplitter(
121
  docs = text_splitter.split_documents(all_loaded_docs)
122
 
123
  # Small + high-quality general embedding
124
- embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-small')
125
 
126
- persist_directory = 'docs/chroma/'
127
  shutil.rmtree(persist_directory, ignore_errors=True)
128
 
129
- from langchain.vectorstores import Chroma
130
  vectordb = Chroma.from_documents(
131
  documents=docs,
132
  embedding=embeddings,
133
- persist_directory=persist_directory
134
  )
135
 
136
  retriever = vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr")
@@ -145,6 +143,7 @@ class FinalAnswer(BaseModel):
145
  question: str = Field(description="User question")
146
  answer: str = Field(description="Direct answer")
147
 
 
148
  parser = PydanticOutputParser(pydantic_object=FinalAnswer)
149
 
150
  SYSTEM_TEMPLATE = (
@@ -169,14 +168,12 @@ qa_prompt = ChatPromptTemplate.from_template(
169
 
170
  """
171
  =========================================================
172
- 4) LLM SETUP (fixes the 404/deprecation issue)
173
  =========================================================
174
  """
175
  # IMPORTANT:
176
- # The previous code used `HuggingFaceHub` with repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1".
177
- # That route often 404s on the public Inference API router unless you deploy an Endpoint.
178
- # Fix: use `HuggingFaceEndpoint` with a public model that *is* available on the router,
179
- # or provision your own Inference Endpoint if you insist on Mixtral.
180
 
181
  # ---- Option A (DEFAULT): public, free router model that works out-of-the-box
182
  DEFAULT_REPO = os.environ.get("HF_REPO_ID", "HuggingFaceH4/zephyr-7b-beta")
@@ -189,11 +186,11 @@ llm = HuggingFaceEndpoint(
189
  top_k=50,
190
  repetition_penalty=1.05,
191
  do_sample=True,
192
- # Make sure your token is set in env: HUGGINGFACEHUB_API_TOKEN
193
  )
194
 
195
- # ---- Option B (MIXTRAL): requires a paid Inference Endpoint you own
196
- # MIXTRAL_ENDPOINT_URL = os.environ.get("HF_ENDPOINT_URL") # e.g. https://xyz.us-east-1.aws.endpoints.huggingface.cloud
197
  # if MIXTRAL_ENDPOINT_URL:
198
  # llm = HuggingFaceEndpoint(
199
  # endpoint_url=MIXTRAL_ENDPOINT_URL,
@@ -221,36 +218,38 @@ qa_chain = ConversationalRetrievalChain.from_llm(
221
  retriever=retriever,
222
  memory=memory,
223
  verbose=True,
224
- combine_docs_chain_kwargs={'prompt': qa_prompt},
225
  get_chat_history=lambda h: h, # memory already returns messages
226
  rephrase_question=False,
227
- output_key='output',
228
  )
229
 
230
 
231
  def chat_interface(question, history):
232
  """
233
  Processes the user's question through the qa_chain,
234
- and robustly parses the JSON output.
235
  """
236
  try:
237
- result = qa_chain.invoke({'question': question})
238
- raw = result.get('output', '').strip()
239
 
240
  # Try strict JSON first
241
  try:
242
  payload = json.loads(raw)
243
  except json.JSONDecodeError:
244
  # If the model returned extra text around JSON, try to extract the first JSON object
245
- start = raw.find('{')
246
- end = raw.rfind('}')
247
  if start != -1 and end != -1 and end > start:
248
- payload = json.loads(raw[start:end+1])
 
 
 
249
  else:
250
  payload = {"question": question, "answer": raw}
251
 
252
  # Enforce schema
253
- question_out = payload.get("question", question)
254
  answer_out = payload.get("answer", raw)
255
  return answer_out
256
 
@@ -312,5 +311,9 @@ app = gr.TabbedInterface(
312
 
313
  # Enable queue() for concurrency and launch the Gradio app
314
  app.queue()
315
- # Tip: set share=True if you want a public link
316
- app.launch(share=os.environ.get("GRADIO_SHARE", "false").lower() == "true")
 
 
 
 
 
1
  """
2
  =========================================================
3
+ Fixed app.py — Green Greta (Gradio + HF + LangChain v0.2)
4
  Notes:
5
+ - Uses HuggingFaceEndpoint with a public router model (Zephyr) by default.
6
+ - Robust JSON parsing (no fragile string slicing).
7
+ - EfficientNet input size fixed (224x224).
8
+ - LangChain v0.2 import layout (core/community/text-splitters).
9
+ - Safer memory for ConversationalRetrievalChain; better error handling.
10
+ - Gradio binds to 0.0.0.0:7860 for Docker.
 
 
11
  =========================================================
12
  """
13
 
 
17
 
18
  import gradio as gr
19
  import tensorflow as tf
 
20
  from PIL import Image
21
 
22
  import tenacity # for retrying failed requests
23
  from fake_useragent import UserAgent
24
 
25
+ # LangChain (v0.2+ layout)
26
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
27
+ from langchain_core.prompts import ChatPromptTemplate
28
+ from langchain_core.output_parsers import PydanticOutputParser
29
+ from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
30
  from langchain_community.document_loaders import WebBaseLoader
31
  from langchain_community.llms import HuggingFaceEndpoint
32
+ from langchain_community.vectorstores import Chroma
33
+ from langchain.chains import ConversationalRetrievalChain
34
+ from langchain.memory import ConversationBufferMemory
35
+
36
  from pydantic.v1 import BaseModel, Field
37
 
38
  # Theming
 
53
  model1 = from_pretrained_keras("rocioadlc/efficientnetB0_trash")
54
 
55
  # Define class labels for the trash classification
56
+ class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
57
 
58
 
59
  def predict_image(input_image: Image.Image):
 
70
 
71
  predictions = model1.predict(image_array)
72
  probs = predictions[0].tolist()
 
73
  return {label: float(probs[i]) for i, label in enumerate(class_labels)}
74
 
75
 
76
  image_gradio_app = gr.Interface(
77
  fn=predict_image,
78
+ inputs=gr.Image(label="Image", sources=["upload", "webcam"], type="pil"),
79
  outputs=[gr.Label(label="Result")],
80
  title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
81
+ theme=theme,
82
  )
83
 
84
  """
 
104
  docs = load_url(link)
105
  all_docs.extend(docs)
106
  except Exception as e:
107
+ # If load_url fails after all retries, skip that URL
108
  print(f"Skipping URL due to error: {link}\nError: {e}\n")
109
  return all_docs
110
 
 
120
  docs = text_splitter.split_documents(all_loaded_docs)
121
 
122
  # Small + high-quality general embedding
123
+ embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
124
 
125
+ persist_directory = "docs/chroma/"
126
  shutil.rmtree(persist_directory, ignore_errors=True)
127
 
 
128
  vectordb = Chroma.from_documents(
129
  documents=docs,
130
  embedding=embeddings,
131
+ persist_directory=persist_directory,
132
  )
133
 
134
  retriever = vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr")
 
143
  question: str = Field(description="User question")
144
  answer: str = Field(description="Direct answer")
145
 
146
+
147
  parser = PydanticOutputParser(pydantic_object=FinalAnswer)
148
 
149
  SYSTEM_TEMPLATE = (
 
168
 
169
  """
170
  =========================================================
171
+ 4) LLM SETUP (no router 404s)
172
  =========================================================
173
  """
174
  # IMPORTANT:
175
+ # The old route "mistralai/Mixtral-8x7B-Instruct-v0.1" often 404s on the public HF router.
176
+ # Use a router-available model OR your own paid Inference Endpoint.
 
 
177
 
178
  # ---- Option A (DEFAULT): public, free router model that works out-of-the-box
179
  DEFAULT_REPO = os.environ.get("HF_REPO_ID", "HuggingFaceH4/zephyr-7b-beta")
 
186
  top_k=50,
187
  repetition_penalty=1.05,
188
  do_sample=True,
189
+ # Set env: HUGGINGFACEHUB_API_TOKEN=hf_xxx
190
  )
191
 
192
+ # ---- Option B (MIXTRAL): your paid Inference Endpoint
193
+ # MIXTRAL_ENDPOINT_URL = os.environ.get("HF_ENDPOINT_URL") # e.g. https://xyz.aws.endpoints.huggingface.cloud
194
  # if MIXTRAL_ENDPOINT_URL:
195
  # llm = HuggingFaceEndpoint(
196
  # endpoint_url=MIXTRAL_ENDPOINT_URL,
 
218
  retriever=retriever,
219
  memory=memory,
220
  verbose=True,
221
+ combine_docs_chain_kwargs={"prompt": qa_prompt},
222
  get_chat_history=lambda h: h, # memory already returns messages
223
  rephrase_question=False,
224
+ output_key="output",
225
  )
226
 
227
 
228
  def chat_interface(question, history):
229
  """
230
  Processes the user's question through the qa_chain,
231
+ and robustly parses the JSON output per schema.
232
  """
233
  try:
234
+ result = qa_chain.invoke({"question": question})
235
+ raw = result.get("output", "").strip()
236
 
237
  # Try strict JSON first
238
  try:
239
  payload = json.loads(raw)
240
  except json.JSONDecodeError:
241
  # If the model returned extra text around JSON, try to extract the first JSON object
242
+ start = raw.find("{")
243
+ end = raw.rfind("}")
244
  if start != -1 and end != -1 and end > start:
245
+ try:
246
+ payload = json.loads(raw[start : end + 1])
247
+ except json.JSONDecodeError:
248
+ payload = {"question": question, "answer": raw}
249
  else:
250
  payload = {"question": question, "answer": raw}
251
 
252
  # Enforce schema
 
253
  answer_out = payload.get("answer", raw)
254
  return answer_out
255
 
 
311
 
312
  # Enable queue() for concurrency and launch the Gradio app
313
  app.queue()
314
+ # Tip: set GRADIO_SHARE=true in env if you want a public link
315
+ app.launch(
316
+ server_name="0.0.0.0",
317
+ server_port=7860,
318
+ share=os.environ.get("GRADIO_SHARE", "false").lower() == "true",
319
+ )
requirements.txt CHANGED
@@ -9,7 +9,7 @@ tensorflow==2.13.0
9
  langchain==0.2.12
10
  langchain-community==0.2.10
11
  langchain-text-splitters==0.2.2
12
- langchain-core==0.2.24
13
 
14
  # Vector store
15
  chromadb==0.5.3
 
9
  langchain==0.2.12
10
  langchain-community==0.2.10
11
  langchain-text-splitters==0.2.2
12
+ langchain-core==0.2.27
13
 
14
  # Vector store
15
  chromadb==0.5.3