# disaster-app / app.py
# (Hugging Face Space scrape header: uploaded by iffazainab — "Update app.py", commit 317d9f9, verified)
import gradio as gr
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
# --------- Load BLIP (Image Captioning) ---------
# BLIP turns the uploaded photo into a short natural-language description.
BLIP_MODEL = "Salesforce/blip-image-captioning-base"
# Prefer GPU when available; every model and input tensor is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(device)
# --------- Load Small Language Model for reasoning ---------
# NOTE(review): falcon-7b-instruct is a 7B-parameter model — heavy for CPU-only
# Spaces; the TinyLlama alternative below is the lightweight option.
LLM_MODEL = "tiiuae/falcon-7b-instruct" # or "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
# fp16 halves memory on GPU; fall back to fp32 on CPU where fp16 inference is poorly supported.
llm_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
llm_model.to(device)
# Safety notice shown in the UI and appended to every generated answer.
DISCLAIMER = (
"⚠️ **Disclaimer:** This tool provides general information and is **not** a substitute for "
"official emergency guidance. In an emergency, follow directions from local authorities."
)
def generate_caption(image: Image.Image) -> str:
    """Describe *image* in one sentence using the BLIP captioning model."""
    batch = processor(images=image, return_tensors="pt").to(device)
    token_ids = blip_model.generate(**batch, max_new_tokens=64)
    caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return caption
def generate_precautions(caption: str, user_question: str) -> str:
    """Ask the LLM for disaster identification and precautionary advice.

    Builds a prompt from the BLIP *caption* plus the optional *user_question*,
    generates a continuation, and returns ONLY the newly generated text.

    Bug fix: causal LMs return the prompt tokens followed by the continuation,
    so decoding ``output_ids[0]`` directly echoed the whole instruction back to
    the user. We now slice off the prompt before decoding.
    """
    prompt = (
        f"The image shows: {caption}. "
        f"Based on this, identify what type of natural disaster this is and provide "
        f"immediate and long-term precautionary measures. {user_question}"
    )
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = llm_model.generate(**inputs, max_new_tokens=256)
    # Keep only the tokens generated after the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][prompt_length:]
    return llm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def analyze_image(image, user_question, history):
    """Gradio click handler: caption the image, get advice, update the chat.

    Parameters
    ----------
    image : PIL.Image.Image | None
        The uploaded image, or ``None`` if nothing was uploaded.
    user_question : str | None
        Optional free-text question from the textbox.
    history : list | None
        Chat history held in ``gr.State``.

    Returns
    -------
    tuple[list, str]
        (updated history for the Chatbot, text for the question textbox —
        an error message on failure, empty string on success).
    """
    # Defensive copy: never mutate the caller's state list in place, and
    # tolerate a None history (e.g. a fresh/reset session).
    history = list(history) if history else []
    if image is None:
        return history, "❌ Please upload an image."
    caption = generate_caption(image)
    answer = generate_precautions(caption, user_question or "")
    answer += "\n\n" + DISCLAIMER
    # Show the user's question if they asked one, otherwise the caption.
    history.append((user_question if user_question else caption, answer))
    return history, ""
def clear_chat():
    """Reset the conversation: empty chat history and a cleared textbox."""
    empty_history, empty_text = [], ""
    return empty_history, empty_text
# --------- Gradio UI ---------
# --------- Gradio UI ---------
with gr.Blocks(title="Disaster Precaution Chatbot") as demo:
    gr.Markdown("# 🌪️ Disaster Precaution Chatbot (BLIP + LLM)")
    gr.Markdown("Upload a disaster image and get advice on how to handle it.")
    gr.Markdown(DISCLAIMER)

    # Per-session chat history, passed to and returned from the handlers.
    chat_history = gr.State([])

    with gr.Row():
        with gr.Column():
            image_box = gr.Image(type="pil", label="Upload Disaster Image")
            question_box = gr.Textbox(label="Your question (optional)", placeholder="What should I do?")
            run_button = gr.Button("Analyze Image", variant="primary")
            reset_button = gr.Button("Clear Chat")
        chatbot = gr.Chatbot(label="Chatbot", height=400)

    # Analyze: consumes image + question + state, refreshes chat and textbox.
    run_button.click(analyze_image, [image_box, question_box, chat_history], [chatbot, question_box])
    # Clear: wipes the chat display and the textbox.
    reset_button.click(clear_chat, [], [chatbot, question_box])

if __name__ == "__main__":
    demo.launch()