# SDGsBot / app.py
# Author: adzee17 — last commit 0783fbf ("Update app.py")
import os
import gradio as gr
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import openai
# --- SETUP ---
# The OpenAI key is taken from the OPENAI_API_KEY environment variable (e.g. a
# Space secret) so it never has to be committed to the source.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    print("Warning: OPENAI_API_KEY is not set; the chat assistant will return errors.")

# Load pretrained ResNet50 model (ImageNet-1k weights).
try:
    # torchvision >= 0.13: explicit weights enum. IMAGENET1K_V1 is the checkpoint
    # the deprecated ``pretrained=True`` flag loads, so predictions are unchanged.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
except AttributeError:
    # Older torchvision has no weights enums; fall back to the legacy flag.
    model = models.resnet50(pretrained=True)
# Inference mode: disables dropout / uses running batch-norm statistics.
model.eval()
# Image preprocessing applied before the ResNet50 forward pass: resize every
# image to 224x224, convert it to a tensor, then normalize each channel with
# the standard ImageNet mean/std values.
_preprocess_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
# Fallback label list for environments where downloading fails
fallback_labels = [
    "tench", "goldfish", "great white shark", "tiger shark", "hammerhead",  # ... truncated
    "fire engine", "garbage truck", "pickup", "tow truck", "trailer truck"
]
# NOTE(review): this list is truncated. Past its first five entries, its
# positions do not line up with the model's class indices. It is therefore only
# a last resort after the two complete label sources tried below.

LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
LABELS_PATH = "imagenet_classes.txt"
NUM_IMAGENET_CLASSES = 1000  # output size of the ResNet50 ImageNet-1k classifier


def _load_imagenet_labels():
    """Return the ImageNet class names, positioned like the model's output indices.

    Sources are tried in order:
    1. ``LABELS_PATH`` — reused if already on disk, otherwise downloaded from
       ``LABELS_URL``.
    2. The category list shipped with torchvision's weight metadata (no network).
    3. The truncated ``fallback_labels``.

    Returns:
        list[str]: class names indexed by model output position.
    """
    try:
        if not os.path.exists(LABELS_PATH):
            # download_url_to_file returns None; read back the destination path.
            torch.hub.download_url_to_file(LABELS_URL, LABELS_PATH)
        with open(LABELS_PATH, encoding="utf-8") as f:
            labels = [line.strip() for line in f if line.strip()]
        # A partial or corrupt file would silently misalign every prediction.
        if len(labels) == NUM_IMAGENET_CLASSES:
            return labels
        print(f"Warning: {LABELS_PATH} has {len(labels)} labels, expected {NUM_IMAGENET_CLASSES}.")
    except Exception as e:
        print(f"Warning: Failed to download imagenet labels ({e}).")
    try:
        # Bundled with torchvision >= 0.13; works offline.
        return list(models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"])
    except AttributeError:
        print("Warning: Failed to download imagenet labels. Using fallback list.")
        return fallback_labels


imagenet_labels = _load_imagenet_labels()
# Mock SDG mapping (you can expand this).
# Keys are keywords looked up in the predicted ImageNet label by classify_image;
# values are the SDG reported for that keyword. Insertion order matters: the
# first keyword that matches wins.
sdg_mapping = dict([
    ("garbage truck", "SDG 11 - Sustainable Cities"),
    ("water", "SDG 6 - Clean Water"),
    ("tree", "SDG 13 - Climate Action"),
    ("smoke", "SDG 3 - Good Health"),
])
# --- FUNCTIONS ---
def classify_image(img):
    """Classify *img* with the ResNet50 model and map the result to an SDG.

    Args:
        img: PIL image, as delivered by ``gr.Image(type="pil")``.

    Returns:
        tuple[str, str]: ``(label, sdg)``. ``label`` is the top-1 label from
        ``imagenet_labels``. ``sdg`` is the value of the first ``sdg_mapping``
        key that appears in the label as a whole word or phrase, or
        "SDG 12 - Responsible Consumption" if none does.

    Raises:
        ValueError: if *img* is None (nothing was uploaded).
    """
    import re

    if img is None:
        raise ValueError("No image provided")
    # Normalize expects exactly 3 channels; PNG/GIF uploads may be RGBA,
    # grayscale or palette images, which would otherwise raise here.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        outputs = model(batch)
    index = outputs.argmax(dim=1).item()
    if index < len(imagenet_labels):
        label = imagenet_labels[index]
    else:
        # Only a truncated label list can be shorter than the model output.
        # Report the raw class instead of wrapping the index onto an unrelated label.
        label = f"class {index}"
    lowered = label.lower()
    # Whole-word matching: plain substring search would map e.g. "street sign"
    # to the "tree" keyword.
    matched_sdg = next(
        (
            sdg
            for keyword, sdg in sdg_mapping.items()
            if re.search(rf"\b{re.escape(keyword)}\b", lowered)
        ),
        "SDG 12 - Responsible Consumption",
    )
    return label, matched_sdg
def get_indicators(sdg):
    """Return the monitoring indicators for the SDG named *sdg*.

    Unrecognized SDG names yield ``["No indicators found"]``. A fresh list is
    built on every call, so callers may mutate the result.
    """
    table = dict([
        ("SDG 11 - Sustainable Cities", ["Access to transport", "Waste collection"]),
        ("SDG 6 - Clean Water", ["Water quality", "Wastewater treatment"]),
        ("SDG 13 - Climate Action", ["Carbon emissions", "Reforestation"]),
        ("SDG 3 - Good Health", ["Air quality", "Pollution levels"]),
        ("SDG 12 - Responsible Consumption", ["Recycling", "Resource usage"]),
    ])
    try:
        return table[sdg]
    except KeyError:
        return ["No indicators found"]
def get_environmental_impact(sdg):
    """Return a one-sentence environmental-impact summary for *sdg*.

    Falls back to a generic sentence when *sdg* is not one of the known goals.
    """
    default = "Helps achieve sustainable development."
    summaries = {
        "SDG 11 - Sustainable Cities": "Reduces urban waste and improves living conditions.",
        "SDG 6 - Clean Water": "Improves access to clean and safe water.",
        "SDG 13 - Climate Action": "Combats climate change through ecosystem health.",
        "SDG 3 - Good Health": "Promotes healthier communities by reducing pollutants.",
        "SDG 12 - Responsible Consumption": "Encourages sustainable use of resources.",
    }
    return summaries[sdg] if sdg in summaries else default
def problems_resolved(sdg):
    """Return the list of local problems that work on *sdg* helps resolve.

    Unknown goals map to ``["General sustainability issues"]``.
    """
    goals = (
        "SDG 11 - Sustainable Cities",
        "SDG 6 - Clean Water",
        "SDG 13 - Climate Action",
        "SDG 3 - Good Health",
        "SDG 12 - Responsible Consumption",
    )
    issue_lists = (
        ["Garbage overflow", "Unplanned development"],
        ["Dirty water", "No filtration"],
        ["Deforestation", "Pollution"],
        ["Air pollution", "Toxic waste"],
        ["Overuse of materials", "Lack of recycling"],
    )
    by_goal = dict(zip(goals, issue_lists))
    if sdg in by_goal:
        return by_goal[sdg]
    return ["General sustainability issues"]
def chat_with_user(message, history=None):
    """Answer a reporting question via the OpenAI chat API.

    ``gr.ChatInterface`` invokes its ``fn`` as ``fn(message, history)``, so the
    second parameter is required for the chat tab to work. It defaults to None
    so one-argument callers still work.

    Args:
        message: the user's latest message.
        history: previous turns as supplied by Gradio. Either ``[user, bot]``
            pairs ("tuples" format) or ``{"role", "content"}`` dicts
            ("messages" format). None or empty means no prior context.

    Returns:
        str: the assistant reply, or a bracketed error string if the API call fails.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant for environmental reporting."}
    ]
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            # Skip non-text entries (e.g. uploaded files) and unknown roles.
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text, bot_text = turn
            if isinstance(user_text, str):
                messages.append({"role": "user", "content": user_text})
            if isinstance(bot_text, str):
                messages.append({"role": "assistant", "content": bot_text})
    messages.append({"role": "user", "content": message})

    try:
        if hasattr(openai, "OpenAI"):
            # openai>=1.0 removed ChatCompletion; use the module-level client,
            # which picks up ``openai.api_key``.
            response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
            return response.choices[0].message.content
        # Legacy openai<1.0 interface.
        response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
        return response.choices[0].message["content"]
    except Exception as e:
        return f"[Error contacting OpenAI API: {str(e)}]"
# --- UI COMPONENTS ---
countries = ["Pakistan", "India", "Bangladesh", "Nepal"]
cities = ["Islamabad", "Lahore", "Karachi", "Peshawar", "Quetta", "Delhi", "Dhaka", "Kathmandu"]
# The cities above grouped by country, so the city dropdown can be narrowed to
# the selected country.
cities_by_country = {
    "Pakistan": ["Islamabad", "Lahore", "Karachi", "Peshawar", "Quetta"],
    "India": ["Delhi"],
    "Bangladesh": ["Dhaka"],
    "Nepal": ["Kathmandu"],
}

with gr.Blocks() as demo:
    with gr.Tab("Public Reporter"):
        country = gr.Dropdown(choices=countries, label="Select Country")
        city = gr.Dropdown(choices=cities, label="Select City")
        area = gr.Textbox(label="Area")
        desc = gr.Textbox(label="Issue Description")
        image = gr.Image(type="pil")
        submit_btn = gr.Button("Submit Report")
        output_label = gr.Textbox(label="Detected Issue")
        output_sdg = gr.Textbox(label="Relevant SDG")
        # multiselect so the full indicator list can be displayed as the value.
        output_indicators = gr.Dropdown(label="Indicators", multiselect=True)
        output_impact = gr.Textbox(label="Environmental Impact")
        output_problems = gr.Textbox(label="Problems Resolved")

        def update_cities(selected_country):
            """Limit city choices to *selected_country* (all cities if unknown) and clear the pick."""
            return gr.update(choices=cities_by_country.get(selected_country, cities), value=None)

        country.change(fn=update_cities, inputs=[country], outputs=[city])

        def process_report(img):
            """Classify the uploaded image and fill in the five SDG output fields."""
            if img is None:
                # Show a readable message in the UI instead of a stack trace.
                raise gr.Error("Please upload an image before submitting a report.")
            label, sdg = classify_image(img)
            indicators = get_indicators(sdg)
            return (
                label,
                sdg,
                # A bare list is not a valid single-choice Dropdown value; set
                # the choices and select all of them.
                gr.update(choices=indicators, value=indicators),
                get_environmental_impact(sdg),
                ", ".join(problems_resolved(sdg)),
            )

        submit_btn.click(
            fn=process_report,
            inputs=[image],
            outputs=[output_label, output_sdg, output_indicators, output_impact, output_problems],
        )
        chatbot = gr.ChatInterface(fn=chat_with_user, chatbot=gr.Chatbot(), title="Ask About Reporting")
    with gr.Tab("Government Officials"):
        gr.Markdown("## Submitted Reports (Simulated)")
        gr.Dataframe(
            headers=["Country", "City", "Area", "Issue", "SDG", "Status"],
            value=[["Pakistan", "Lahore", "Gulberg", "Garbage", "SDG 11", "Pending"]],
        )
        gr.Markdown("### Suggested Resources: Garbage truck, Cleanup crew")
        gr.Markdown("### View by SDG, City, or Status coming soon...")
# --- LAUNCH ---
# Only start the server when run as a script (``python app.py``). Importing the
# module, e.g. for Gradio reload mode or tests, then does not block on launch().
if __name__ == "__main__":
    demo.launch()