Atharv Subhekar
commited on
Commit
·
592981e
1
Parent(s):
f76d03b
Application update
Browse files- .DS_Store +0 -0
- app.py +84 -13
- requirements.txt +2 -2
- sample_images/Screenshot 2024-06-28 at 1.35.57/342/200/257PM.png +0 -0
- ~$cumentation.docx +0 -0
.DS_Store
CHANGED
|
Binary files a/.DS_Store and b/.DS_Store differ
|
|
|
app.py
CHANGED
|
@@ -7,15 +7,10 @@ Original file is located at
|
|
| 7 |
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
|
| 8 |
"""
|
| 9 |
|
| 10 |
-
#!pip install gradio --quiet
|
| 11 |
-
#!pip install -Uq transformers datasets timm accelerate evaluate
|
| 12 |
-
|
| 13 |
-
import subprocess
|
| 14 |
-
# subprocess.run('pip3 install datasets timm cv2 huggingface_hub torch pillow matplotlib' ,shell=True)
|
| 15 |
-
|
| 16 |
import gradio as gr
|
| 17 |
-
from huggingface_hub import hf_hub_download
|
| 18 |
from safetensors.torch import load_model
|
|
|
|
|
|
|
| 19 |
from datasets import load_dataset
|
| 20 |
import torch
|
| 21 |
import torchvision.transforms as T
|
|
@@ -23,8 +18,17 @@ import cv2
|
|
| 23 |
import matplotlib.pyplot as plt
|
| 24 |
import numpy as np
|
| 25 |
from PIL import Image
|
| 26 |
-
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
|
|
@@ -52,8 +56,56 @@ def one_hot_decoding(labels):
|
|
| 52 |
true_labels.append(id2label[i])
|
| 53 |
return true_labels
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def model_output(image):
|
| 56 |
-
|
| 57 |
PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
|
| 58 |
|
| 59 |
img_size = (224,224)
|
|
@@ -72,8 +124,27 @@ def model_output(image):
|
|
| 72 |
pred_labels = one_hot_decoding(predictions)
|
| 73 |
output_text = " ".join(pred_labels)
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
|
|
|
| 7 |
https://colab.research.google.com/drive/1HCITtw0z2BJO0z9GmsspLM1NzrLdft27
|
| 8 |
"""
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
import gradio as gr
|
|
|
|
| 11 |
from safetensors.torch import load_model
|
| 12 |
+
from timm import create_model
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
from datasets import load_dataset
|
| 15 |
import torch
|
| 16 |
import torchvision.transforms as T
|
|
|
|
| 18 |
import matplotlib.pyplot as plt
|
| 19 |
import numpy as np
|
| 20 |
from PIL import Image
|
| 21 |
+
import os
|
| 22 |
|
| 23 |
+
from langchain_community.document_loaders import TextLoader
|
| 24 |
+
from langchain_community.vectorstores import FAISS
|
| 25 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 26 |
+
from langchain.text_splitter import CharacterTextSplitter
|
| 27 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 28 |
+
from langchain_core.runnables import RunnablePassthrough
|
| 29 |
+
from langchain_fireworks import ChatFireworks
|
| 30 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 31 |
+
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
| 32 |
|
| 33 |
|
| 34 |
safe_tensors = "model.safetensors" #hf_hub_download(repo_id="subhuatharva/swim-224-base-satellite-image-classification", filename="model.safetensors")
|
|
|
|
| 56 |
true_labels.append(id2label[i])
|
| 57 |
return true_labels
|
| 58 |
|
| 59 |
+
def ragChain():
|
| 60 |
+
"""
|
| 61 |
+
function: creates a rag chain
|
| 62 |
+
output: rag chain
|
| 63 |
+
"""
|
| 64 |
+
loader = TextLoader("document.txt")
|
| 65 |
+
docs = loader.load()
|
| 66 |
+
|
| 67 |
+
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
| 68 |
+
docs = text_splitter.split_documents(docs)
|
| 69 |
+
|
| 70 |
+
vectorstore = FAISS.load_local("faiss_index", embeddings = HuggingFaceEmbeddings(), allow_dangerous_deserialization = True)
|
| 71 |
+
retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5})
|
| 72 |
+
|
| 73 |
+
api_key = os.getenv("FIREWORKS_API_KEY")
|
| 74 |
+
llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key = api_key)
|
| 75 |
+
|
| 76 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 77 |
+
[
|
| 78 |
+
(
|
| 79 |
+
"system",
|
| 80 |
+
"""You are a knowledgeable landscape deforestation analyst.
|
| 81 |
+
"""
|
| 82 |
+
),
|
| 83 |
+
(
|
| 84 |
+
"human",
|
| 85 |
+
"""First mention the detected labels only with short description.
|
| 86 |
+
Provide not more than 4 precautionary measures which are related to the detected labels that can be taken to control deforestation.
|
| 87 |
+
Don't include conversational messages.
|
| 88 |
+
""",
|
| 89 |
+
),
|
| 90 |
+
|
| 91 |
+
("human", "{context}, {question}"),
|
| 92 |
+
]
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
rag_chain = (
|
| 96 |
+
{
|
| 97 |
+
"context": retriever,
|
| 98 |
+
"question": RunnablePassthrough()
|
| 99 |
+
}
|
| 100 |
+
| prompt
|
| 101 |
+
| llm
|
| 102 |
+
| StrOutputParser()
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
return rag_chain
|
| 106 |
+
|
| 107 |
def model_output(image):
|
| 108 |
+
|
| 109 |
PIL_image = Image.fromarray(image.astype('uint8'), 'RGB')
|
| 110 |
|
| 111 |
img_size = (224,224)
|
|
|
|
| 124 |
pred_labels = one_hot_decoding(predictions)
|
| 125 |
output_text = " ".join(pred_labels)
|
| 126 |
|
| 127 |
+
query = f"Detected labels in the provided satellite image are {output_text}. Give information on the labels."
|
| 128 |
+
|
| 129 |
+
return query
|
| 130 |
+
|
| 131 |
+
def generate_response(rag_chain, query):
|
| 132 |
+
"""
|
| 133 |
+
input: rag chain, query
|
| 134 |
+
function: generates response using llm and knowledge base
|
| 135 |
+
output: generated response by the llm
|
| 136 |
+
"""
|
| 137 |
+
return rag_chain.invoke(f"{query}")
|
| 138 |
+
|
| 139 |
+
def main(image):
|
| 140 |
+
query = model_output(image)
|
| 141 |
+
chain = ragChain()
|
| 142 |
+
output = generate_response(chain, query)
|
| 143 |
+
return output
|
| 144 |
+
title = "Satellite Image Landscape Analysis for Deforestation"
|
| 145 |
+
description = "This bot will take any satellite image and analyze the factors which lead to deforestation by identify the landscape based on forest areas, roads, habitation, water etc."
|
| 146 |
+
app = gr.Interface(fn=main, inputs="image", outputs="text", title=title,
|
| 147 |
+
description=description,
|
| 148 |
+
examples=[["sampleimages/train_142.jpg"], ["sampleimages/train_32.jpg"],["sampleimages/train_59.jpg"], ["sampleimages/train_67.jpg"],["sampleimages/train_75.jpg"],["sampleimages/train_92.jpg"],["sampleimages/random_satellite.jpg"]])
|
| 149 |
+
app.launch(share = True)
|
| 150 |
|
requirements.txt
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
transformers
|
| 2 |
datasets
|
| 3 |
-
|
| 4 |
langchain-fireworks
|
| 5 |
langchain_core
|
| 6 |
langchain_community
|
|
@@ -10,4 +10,4 @@ safetensors
|
|
| 10 |
torch
|
| 11 |
torchvision
|
| 12 |
opencv-python
|
| 13 |
-
pillow
|
|
|
|
| 1 |
transformers
|
| 2 |
datasets
|
| 3 |
+
Time
|
| 4 |
langchain-fireworks
|
| 5 |
langchain_core
|
| 6 |
langchain_community
|
|
|
|
| 10 |
torch
|
| 11 |
torchvision
|
| 12 |
opencv-python
|
| 13 |
+
pillow
|
sample_images/Screenshot 2024-06-28 at 1.35.57/342/200/257PM.png
ADDED
|
~$cumentation.docx
DELETED
|
Binary file (162 Bytes)
|
|
|