import gradio as gr
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
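
# Pipeline: BLIP captions the uploaded image, then an instruction-tuned LLM
# turns the caption (plus any user question) into precautionary advice.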

# --------- Load BLIP (Image Captioning) ---------
BLIP_MODEL = "Salesforce/blip-image-captioning-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(device)

# --------- Load Small Language Model for reasoning ---------
# Falcon-7B-Instruct needs a GPU with roughly 14 GB of memory in fp16;
# swap in the TinyLlama checkpoint below for CPU-only machines.
LLM_MODEL = "tiiuae/falcon-7b-instruct"  # or "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
llm_model.to(device)

DISCLAIMER = (
    "⚠️ **Disclaimer:** This tool provides general information and is **not** a substitute for "
    "official emergency guidance. In an emergency, follow directions from local authorities."
)

def generate_caption(image: Image.Image) -> str:
    """Caption the uploaded image with BLIP."""
    inputs = processor(images=image, return_tensors="pt").to(device)
    out = blip_model.generate(**inputs, max_new_tokens=64)
    return processor.decode(out[0], skip_special_tokens=True)

def generate_precautions(caption: str, user_question: str) -> str:
    """Ask the LLM to identify the disaster in the caption and list precautions."""
    prompt = (
        f"The image shows: {caption}. "
        f"Based on this, identify what type of natural disaster this is and provide "
        f"immediate and long-term precautionary measures. {user_question}"
    )
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = llm_model.generate(**inputs, max_new_tokens=256)
    # Decode only the newly generated tokens; causal LMs echo the prompt
    # at the start of `output_ids`.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return llm_tokenizer.decode(new_tokens, skip_special_tokens=True)

def analyze_image(image, user_question, history):
    """Caption the image, generate advice, and append the exchange to the chat."""
    if image is None:
        # Surface the error in the chat rather than overwriting the question box.
        history.append((user_question or "", "❌ Please upload an image."))
        return history, ""

    caption = generate_caption(image)
    answer = generate_precautions(caption, user_question or "")
    answer += "\n\n" + DISCLAIMER
    history.append((user_question if user_question else caption, answer))
    return history, ""

def clear_chat():
    # Reset the chat display, the stored history, and the question box.
    return [], [], ""

# --------- Gradio UI ---------
with gr.Blocks(title="Disaster Precaution Chatbot") as demo:
    gr.Markdown("# 🌪️ Disaster Precaution Chatbot (BLIP + LLM)")
    gr.Markdown("Upload a disaster image and get advice on how to handle it.")
    gr.Markdown(DISCLAIMER)

    state = gr.State([])
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Disaster Image")
            txt_input = gr.Textbox(label="Your question (optional)", placeholder="What should I do?")
            analyze_btn = gr.Button("Analyze Image", variant="primary")
            clear_btn = gr.Button("Clear Chat")
        chatbot = gr.Chatbot(label="Chatbot", height=400)

    analyze_btn.click(analyze_image, [img_input, txt_input, state], [chatbot, txt_input])
    clear_btn.click(clear_chat, [], [chatbot, state, txt_input])

if __name__ == "__main__":
    demo.launch()
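    # demo.launch(share=True) would expose a temporary public URL, and calling
    # demo.queue() before launch() helps when LLM generation is slow.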