Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| BlipProcessor, | |
| BlipForConditionalGeneration, | |
| BitsAndBytesConfig | |
| ) | |
| import gradio as gr | |
| from PIL import Image | |
| import re | |
| import os | |
| from typing import List, Tuple | |
# Local directories: downloaded model weights are cached in model_cache/,
# and the UI looks for example images in examples/.
os.makedirs("model_cache", exist_ok=True)
os.makedirs("examples", exist_ok=True)

# 4-bit NF4 quantization config for FLAN-T5-XL. bitsandbytes 4-bit loading
# requires a CUDA GPU, so on CPU-only hardware (or if building the config fails,
# e.g. because bitsandbytes is not installed) this stays None and the model is
# loaded unquantized instead of crashing the app at import time.
quant_config = None
if torch.cuda.is_available():
    try:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    except (ImportError, ValueError) as err:
        # Visible in the Space logs; the app continues in full precision.
        print(f"4-bit quantization unavailable, loading full precision: {err}")
class RiverPollutionAnalyzer:
    """Detect and grade river pollution in photos and answer pollution questions.

    Image analysis is two-step: BLIP captions the photo, then FLAN-T5-XL reads
    the caption and is asked for a pollutant list and a 1-10 severity score,
    which are rendered as a Markdown report.
    """

    # Fallback per-pollutant severities, averaged when the model's own score is
    # missing or outside 1-10. Keys are lowercase; lookups are case-insensitive.
    _SEVERITY_WEIGHTS = {
        "plastic waste": 4,
        "chemical foam": 7,
        "industrial discharge": 8,
        "sewage water": 6,
        "oil spill": 9,
        "organic debris": 3,
        "construction waste": 5,
        "medical waste": 8,
        "floating trash": 4,
        "algal bloom": 6,
        "toxic sludge": 9,
        "agricultural runoff": 5,
    }

    # Model answers meaning "nothing found" rather than naming a pollutant.
    _NO_POLLUTANT_ANSWERS = frozenset({"none", "n/a", "nothing", "no pollutants"})

    def __init__(self):
        """Load BLIP and FLAN-T5-XL (cached in ``model_cache``) and lookup tables.

        Raises:
            RuntimeError: if any processor, tokenizer or model fails to load.
        """
        use_cuda = torch.cuda.is_available()
        try:
            # BLIP for image captioning. Half precision only on GPU: float16
            # support for CPU kernels is poor in many PyTorch builds.
            self.blip_processor = BlipProcessor.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                cache_dir="model_cache",
            )
            self.blip_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto",
                cache_dir="model_cache",
            )
            # FLAN-T5-XL; the module-level 4-bit config is applied only with
            # CUDA because bitsandbytes quantization needs a GPU.
            self.tokenizer = AutoTokenizer.from_pretrained(
                "google/flan-t5-xl",
                cache_dir="model_cache",
            )
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                "google/flan-t5-xl",
                device_map="auto",
                quantization_config=quant_config if use_cuda else None,
                cache_dir="model_cache",
            )
        except Exception as e:
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

        # Candidate pollutants offered to the model in the analysis prompt.
        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff",
        ]
        # Human-readable label for each severity level 1-10.
        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe",
        }

    def analyze_image(self, image):
        """Analyse a river photo and return a Markdown pollution report.

        Args:
            image: a PIL image, or anything ``Image.fromarray`` accepts
                (e.g. a numpy array); ``None`` when nothing was uploaded.

        Returns:
            The Markdown report, or a user-facing message describing why the
            analysis could not be done. Errors are reported, never raised.
        """
        if image is None:
            return "Please upload a river image to analyze."
        try:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            caption = self._caption_image(image)
            analysis = self._generate_text(self._analysis_prompt(caption), max_new_tokens=200)
            pollutants, severity = self._parse_response(analysis)
            return self._format_analysis(pollutants, severity)
        except Exception as e:
            return f"β οΈ Analysis failed: {str(e)}"

    def _caption_image(self, image) -> str:
        """Return BLIP's caption for a PIL image."""
        # Cast pixel values to the model's own dtype (float16 on GPU, float32 on CPU).
        inputs = self.blip_processor(image, return_tensors="pt").to(
            self.blip_model.device, self.blip_model.dtype
        )
        token_ids = self.blip_model.generate(**inputs, max_new_tokens=100)[0]
        return self.blip_processor.decode(token_ids, skip_special_tokens=True)

    def _analysis_prompt(self, caption: str) -> str:
        """Build the FLAN-T5 instruction asking for pollutants and severity."""
        return "\n".join([
            f"Analyze this river scene: '{caption}'",
            f"1. List visible pollutants from: {', '.join(self.pollutants)}",
            "2. Estimate severity (1-10)",
            "Respond EXACTLY as:",
            "Pollutants: [comma separated list]",
            "Severity: [number]",
        ])

    def _generate_text(self, prompt: str, max_new_tokens: int) -> str:
        """Run FLAN-T5 on *prompt* and return the decoded answer."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Parse model output into ``(pollutants, severity)``.

        Brackets around the pollutant list are optional and both fields may
        share one line, since the model does not reliably follow the format.
        "None"-style answers yield an empty list. When no severity in 1-10 is
        found, it is derived from the pollutants via ``_calculate_severity``.
        """
        pollutants: List[str] = []
        severity = 0

        # The list runs until "Severity:", a newline, or the end of the text.
        pollutants_match = re.search(
            r"Pollutants:\s*\[?(.*?)\]?\s*(?=Severity:|\n|$)",
            analysis,
            re.IGNORECASE,
        )
        if pollutants_match:
            for item in pollutants_match.group(1).split(","):
                name = item.strip().strip("[]'\"").strip()
                if name and name.lower() not in self._NO_POLLUTANT_ANSWERS:
                    pollutants.append(name)

        severity_match = re.search(r"Severity:\s*(\d+)", analysis, re.IGNORECASE)
        if severity_match:
            severity = int(severity_match.group(1))

        # Missing or out-of-range score: fall back to the weighted estimate.
        if not 1 <= severity <= 10:
            severity = self._calculate_severity(pollutants)
        return pollutants, severity

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Estimate a 1-10 severity as the rounded mean of per-pollutant weights.

        Unknown pollutants weigh 3 and an empty list scores 1. Python's
        ``round`` is used, so exact .5 means round to the nearest even integer.
        """
        if not pollutants:
            return 1
        weights = [self._SEVERITY_WEIGHTS.get(p.strip().lower(), 3) for p in pollutants]
        return min(10, max(1, round(sum(weights) / len(weights))))

    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Render pollutants and severity as a Markdown report.

        Lines are joined explicitly (not an indented triple-quoted string) so no
        line carries leading whitespace that Markdown would treat as code.
        """
        items = pollutants or ["No visible pollution detected"]
        description = self.severity_descriptions.get(severity, "Unknown severity level")
        lines = [
            "## Pollution Analysis Report",
            "",
            "### Identified Pollutants:",
            *(f"- {p}" for p in items),
            "",
            "### Severity Assessment:",
            f"**Level {severity}/10** - {description}",
            "",
            "### Recommended Actions:",
            self._get_recommendations(severity),
        ]
        return "\n".join(lines) + "\n"

    def _get_recommendations(self, severity: int) -> str:
        """Return the recommended action text for a severity level."""
        if severity <= 3:
            return "Monitor the situation. Consider community clean-up efforts."
        elif severity <= 5:
            return "Local authorities should investigate. Basic remediation needed."
        elif severity <= 7:
            return "Immediate containment required. Environmental assessment needed."
        elif severity <= 9:
            return "Emergency response required. Notify environmental agencies."
        else:
            return "Disaster response needed. Evacuation may be necessary."

    def analyze_chat(self, message: str) -> str:
        """Answer a free-form river-pollution question with FLAN-T5.

        Blank questions are not sent to the model; generation errors are
        returned as a message instead of raised, like ``analyze_image``.
        """
        if not message or not message.strip():
            return "Please enter a question about river pollution."
        prompt = (
            f"You are an environmental expert. Answer this question about river pollution: {message}\n"
            "Provide a concise, factual response in under 100 words."
        )
        try:
            return self._generate_text(prompt, max_new_tokens=150)
        except Exception as e:
            return f"Sorry, the question could not be answered: {e}"
# Build the analyzer once at import time. If loading fails, the UI still starts
# with analyzer=None and the header shows why the models are unavailable.
analyzer = None
try:
    analyzer = RiverPollutionAnalyzer()
except Exception as load_error:
    model_status = f"β Model loading failed: {str(load_error)}"
else:
    model_status = "β Models loaded successfully"
# Gradio Interface
# Custom stylesheet passed to gr.Blocks(css=...). The class selectors match the
# elem_classes names used in the layout: header, side-by-side, left-panel,
# right-panel, analysis-box and chat-container. flex-wrap plus min-width: 300px
# lets the two panels stack vertically when the window is too narrow.
css = """
.header {
text-align: center;
max-width: 800px;
margin: auto;
}
.header img {
max-width: 100%;
}
.side-by-side {
display: flex;
flex-wrap: wrap;
gap: 20px;
}
.left-panel, .right-panel {
flex: 1;
min-width: 300px;
}
.analysis-box {
border: 1px solid #e0e0e0;
border-radius: 8px;
padding: 15px;
margin-top: 15px;
background: #f9f9f9;
}
.chat-container {
border: 1px solid #e0e0e0;
border-radius: 8px;
padding: 15px;
background: #f9f9f9;
}
"""
# Example images offered below the analyzer, shown only if present on disk.
EXAMPLE_IMAGES = ["examples/polluted_river1.jpg", "examples/polluted_river2.jpg"]


def _run_image_analysis(image):
    """Gradio handler for the Analyze button and the examples.

    Returns the Markdown report, or "Model not loaded" if model loading failed
    at startup.
    """
    if not analyzer:
        return "Model not loaded"
    return analyzer.analyze_image(image)


def respond(message, chat_history):
    """Return *chat_history* extended with one (question, answer) turn.

    ``chat_history`` may arrive as ``None`` (e.g. after the chat was cleared),
    so it is treated as empty. Blank messages leave the history unchanged.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        return history
    if not analyzer:
        return history + [(message, "Models not loaded. Please try again later.")]
    return history + [(message, analyzer.analyze_chat(message))]


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_classes="header"):
        gr.Markdown("# π River Pollution Analyzer")
        gr.Markdown(f"### {model_status}")
    with gr.Row(elem_classes="side-by-side"):
        # Left Panel: image upload and analysis report
        with gr.Column(elem_classes="left-panel"):
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("π Analyze Pollution", variant="primary")
            with gr.Group(elem_classes="analysis-box"):
                gr.Markdown("### π Analysis report")
                analysis_output = gr.Markdown()
        # Right Panel: free-form Q&A
        with gr.Column(elem_classes="right-panel"):
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(label="Pollution Analysis Q&A", height=400)
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask about pollution sources...",
                        label="Your Question",
                        container=False,
                        scale=5,
                    )
                    chat_btn = gr.Button("π¬ Ask", variant="secondary", scale=1)
                clear_btn = gr.Button("π§Ή Clear Chat History", size="sm")

    # Connect functions
    analyze_btn.click(_run_image_analysis, inputs=image_input, outputs=analysis_output)

    # Answer the question, then empty the textbox for the next one.
    chat_btn.click(respond, [chat_input, chatbot], [chatbot]).then(
        lambda: "", None, chat_input
    )
    chat_input.submit(respond, [chat_input, chatbot], [chatbot]).then(
        lambda: "", None, chat_input
    )
    # Reset to an empty history rather than None.
    clear_btn.click(lambda: [], None, chatbot, queue=False)

    # The examples/ directory is created empty at startup, so only list files
    # that actually exist; missing files break example loading. Caching runs
    # the models on the examples, so it is enabled only when they loaded.
    available_examples = [[path] for path in EXAMPLE_IMAGES if os.path.isfile(path)]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=image_input,
            outputs=analysis_output,
            fn=_run_image_analysis,
            cache_examples=analyzer is not None,
            label="Try example images:",
        )
# Enable request queueing with at most 3 waiting requests (inference is slow),
# then start the server, allowing Gradio to serve files from examples/.
demo.queue(max_size=3)
demo.launch(allowed_paths=["examples"])