import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    BlipProcessor,
    BlipForConditionalGeneration,
    BitsAndBytesConfig
)
import gradio as gr
from PIL import Image
import re
import os
from typing import List, Tuple

# Create cache directory
os.makedirs("model_cache", exist_ok=True)
os.makedirs("examples", exist_ok=True)  # Create examples directory

# Configuration for 4-bit quantization
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)


class RiverPollutionAnalyzer:
    def __init__(self):
        try:
            # Initialize BLIP for image captioning with caching
            self.blip_processor = BlipProcessor.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                cache_dir="model_cache"
            )
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                torch_dtype=torch.float16,
                device_map="auto",
                cache_dir="model_cache"
            )

            # Initialize FLAN-T5-XL with quantization
            self.tokenizer = AutoTokenizer.from_pretrained(
                "google/flan-t5-xl",
                cache_dir="model_cache"
            )
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                "google/flan-t5-xl",
                device_map="auto",
                quantization_config=quant_config,
                cache_dir="model_cache"
            )
        except Exception as e:
            raise RuntimeError(f"Model loading failed: {str(e)}")

        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff"
        ]

        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe"
        }

    def analyze_image(self, image):
        """Two-step analysis: BLIP captioning + FLAN-T5 analysis"""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        try:
            # Step 1: Generate image caption with BLIP
            inputs = self.blip_processor(image, return_tensors="pt").to(
                self.blip_model.device, torch.float16
            )
            caption = self.blip_model.generate(**inputs, max_new_tokens=100)[0]
            caption = self.blip_processor.decode(caption, skip_special_tokens=True)

            # Step 2: Analyze caption with FLAN-T5
            prompt = f"""Analyze this river scene: '{caption}'
1. List visible pollutants from: {self.pollutants}
2. Estimate severity (1-10)
Respond EXACTLY as:
Pollutants: [comma separated list]
Severity: [number]"""
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=200)
            analysis = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            pollutants, severity = self._parse_response(analysis)
            return self._format_analysis(pollutants, severity)
        except Exception as e:
            return f"⚠️ Analysis failed: {str(e)}"

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Parse the model response into pollutants list and severity score"""
        pollutants = []
        severity = 0

        # Extract pollutants
        pollutants_match = re.search(r"Pollutants:\s*\[(.*?)\]", analysis)
        if pollutants_match:
            pollutants_str = pollutants_match.group(1)
            pollutants = [p.strip() for p in pollutants_str.split(",") if p.strip()]

        # Extract severity
        severity_match = re.search(r"Severity:\s*(\d+)", analysis)
        if severity_match:
            severity = int(severity_match.group(1))

        # If parsing failed, fall back to calculating severity
        if not severity or severity < 1 or severity > 10:
            severity = self._calculate_severity(pollutants)

        return pollutants, severity

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Calculate severity based on pollutants"""
        if not pollutants:
            return 1

        severity_map = {
            "plastic waste": 4, "chemical foam": 7, "industrial discharge": 8,
            "sewage water": 6, "oil spill": 9, "organic debris": 3,
            "construction waste": 5, "medical waste": 8, "floating trash": 4,
            "algal bloom": 6, "toxic sludge": 9, "agricultural runoff": 5
        }

        base_score = sum(severity_map.get(p, 3) for p in pollutants)
        avg_score = base_score / len(pollutants)
        return min(10, max(1, round(avg_score)))
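    # Worked example of the fallback scoring above (illustrative only, not
    # executed): for ["oil spill", "toxic sludge"] the mapped scores are
    # 9 + 9 = 18, the average is 18 / 2 = 9, so the reported severity is 9;
    # any pollutant label missing from severity_map defaults to 3 via
    # severity_map.get(p, 3).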
    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Format the analysis results into a markdown report"""
        if not pollutants:
            pollutants = ["No visible pollution detected"]

        pollutants_list = "\n".join(f"- {p}" for p in pollutants)
        severity_desc = self.severity_descriptions.get(severity, "Unknown severity level")

        return f"""
## Pollution Analysis Report

### Identified Pollutants:
{pollutants_list}

### Severity Assessment:
**Level {severity}/10** - {severity_desc}

### Recommended Actions:
{self._get_recommendations(severity)}
"""

    def _get_recommendations(self, severity: int) -> str:
        """Get recommendations based on severity level"""
        if severity <= 3:
            return "Monitor the situation. Consider community clean-up efforts."
        elif severity <= 5:
            return "Local authorities should investigate. Basic remediation needed."
        elif severity <= 7:
            return "Immediate containment required. Environmental assessment needed."
        elif severity <= 9:
            return "Emergency response required. Notify environmental agencies."
        else:
            return "Disaster response needed. Evacuation may be necessary."

    def analyze_chat(self, message: str) -> str:
        """Handle chat questions about pollution"""
        prompt = f"""You are an environmental expert.
Answer this question about river pollution: {message}
Provide a concise, factual response in under 100 words."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=150)
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
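# Minimal standalone usage sketch (illustrative, left commented out; it
# assumes a test image exists at examples/polluted_river1.jpg, which is not
# created by this script):
#
#   analyzer = RiverPollutionAnalyzer()
#   report = analyzer.analyze_image(Image.open("examples/polluted_river1.jpg"))
#   print(report)
#   print(analyzer.analyze_chat("What usually causes algal blooms in rivers?"))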
# Initialize with error handling
try:
    analyzer = RiverPollutionAnalyzer()
    model_status = "✅ Models loaded successfully"
except Exception as e:
    analyzer = None
    model_status = f"❌ Model loading failed: {str(e)}"

# Gradio Interface
css = """
.header { text-align: center; max-width: 800px; margin: auto; }
.header img { max-width: 100%; }
.side-by-side { display: flex; flex-wrap: wrap; gap: 20px; }
.left-panel, .right-panel { flex: 1; min-width: 300px; }
.analysis-box { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; margin-top: 15px; background: #f9f9f9; }
.chat-container { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; background: #f9f9f9; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🌍 River Pollution Analyzer")
        gr.Markdown(f"### {model_status}")

    with gr.Row(elem_classes="side-by-side"):
        # Left Panel
        with gr.Column(elem_classes="left-panel"):
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("🔍 Analyze Pollution", variant="primary")

            with gr.Group(elem_classes="analysis-box"):
                gr.Markdown("### 📊 Analysis Report")
                analysis_output = gr.Markdown()

        # Right Panel
        with gr.Column(elem_classes="right-panel"):
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask about pollution sources...",
                        label="Your Question",
                        container=False,
                        scale=5
                    )
                    chat_btn = gr.Button("💬 Ask", variant="secondary", scale=1)
                clear_btn = gr.Button("🧹 Clear Chat History", size="sm")

    # Connect functions
    analyze_btn.click(
        analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
        inputs=image_input,
        outputs=analysis_output
    )

    def respond(message, chat_history):
        if not analyzer:
            return chat_history + [(message, "Models not loaded. Please try again later.")]
        response = analyzer.analyze_chat(message)
        return chat_history + [(message, response)]

    chat_btn.click(
        respond,
        [chat_input, chatbot],
        [chatbot],
    )
    chat_input.submit(
        respond,
        [chat_input, chatbot],
        [chatbot],
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)

    # Update examples to use local files
    gr.Examples(
        examples=[
            ["examples/polluted_river1.jpg"],
            ["examples/polluted_river2.jpg"]
        ],
        inputs=image_input,
        outputs=analysis_output,
        fn=analyzer.analyze_image if analyzer else lambda x: "Model not loaded",
        cache_examples=True,
        label="Try example images:"
    )

# Launch with queue for stability and allowed paths
demo.queue(max_size=3).launch(allowed_paths=["examples"])