Spaces:
Runtime error
Runtime error
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)

# --------- Load BLIP (Image Captioning) ---------
BLIP_MODEL = "Salesforce/blip-image-captioning-base"

# Prefer GPU when available; everything below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = BlipProcessor.from_pretrained(BLIP_MODEL)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL).to(device)
# --------- Load Small Language Model for reasoning ---------
LLM_MODEL = "tiiuae/falcon-7b-instruct"  # or "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)

# fp16 halves memory on GPU; fall back to fp32 on CPU, where fp16 is poorly supported.
llm_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
llm_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL, torch_dtype=llm_dtype)
llm_model.to(device)
# Safety notice appended to every answer and shown at the top of the UI.
DISCLAIMER = (
    "⚠️ **Disclaimer:** This tool provides general information and is "
    "**not** a substitute for official emergency guidance. "
    "In an emergency, follow directions from local authorities."
)
def generate_caption(image: Image.Image) -> str:
    """Return a BLIP-generated caption describing *image*."""
    batch = processor(images=image, return_tensors="pt").to(device)
    token_ids = blip_model.generate(**batch, max_new_tokens=64)
    caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return caption
def generate_precautions(caption: str, user_question: str) -> str:
    """Ask the LLM to identify the disaster in *caption* and suggest precautions.

    Args:
        caption: BLIP caption describing the uploaded image.
        user_question: Optional free-text question from the user (may be "").

    Returns:
        The model's generated answer only — the prompt is stripped from the
        decoded output.
    """
    prompt = (
        f"The image shows: {caption}. "
        f"Based on this, identify what type of natural disaster this is and provide "
        f"immediate and long-term precautionary measures. {user_question}"
    )
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = llm_model.generate(**inputs, max_new_tokens=256)
    # Causal LMs return prompt + completion tokens; decoding output_ids[0]
    # whole would echo the entire prompt back to the user. Slice off the
    # prompt-length prefix and decode only the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    generated = output_ids[0][prompt_len:]
    return llm_tokenizer.decode(generated, skip_special_tokens=True).strip()
def analyze_image(image, user_question, history):
    """Caption the image, ask the LLM for precautions, and record a chat turn.

    Returns (history, textbox_value); on missing image the history is returned
    untouched and an error message is placed in the textbox.
    """
    if image is None:
        return history, "❌ Please upload an image."

    caption = generate_caption(image)
    answer = f"{generate_precautions(caption, user_question or '')}\n\n{DISCLAIMER}"

    # Mutate the state list in place: the gr.State object is not among the
    # outputs, so accumulation across turns relies on this append.
    history.append((user_question or caption, answer))
    return history, ""
def clear_chat():
    """Reset the conversation: empty chatbot history and a blank textbox."""
    empty_history, empty_text = [], ""
    return empty_history, empty_text
# --------- Gradio UI ---------
with gr.Blocks(title="Disaster Precaution Chatbot") as demo:
    gr.Markdown("# 🌪️ Disaster Precaution Chatbot (BLIP + LLM)")
    gr.Markdown("Upload a disaster image and get advice on how to handle it.")
    gr.Markdown(DISCLAIMER)

    # Per-session chat history; analyze_image mutates this list in place.
    state = gr.State([])

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Disaster Image")
            txt_input = gr.Textbox(
                label="Your question (optional)",
                placeholder="What should I do?",
            )
            analyze_btn = gr.Button("Analyze Image", variant="primary")
            clear_btn = gr.Button("Clear Chat")
        # NOTE(review): original indentation was lost in extraction; the
        # chatbot is assumed to be the Row's second pane — confirm layout.
        chatbot = gr.Chatbot(label="Chatbot", height=400)

    analyze_btn.click(analyze_image, [img_input, txt_input, state], [chatbot, txt_input])
    clear_btn.click(clear_chat, [], [chatbot, txt_input])

if __name__ == "__main__":
    demo.launch()