Spaces:
Runtime error
Runtime error
import os
import re
from typing import List, Tuple

import gradio as gr
import torch
from PIL import Image
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor
# Configuration: which checkpoint to load and where/how to run it.
MODEL_NAME = "Salesforce/instructblip-flan-t5-xl"
_CUDA_AVAILABLE = torch.cuda.is_available()
# Run on the GPU in half precision when one exists; fall back to float32 on CPU.
DEVICE = "cuda" if _CUDA_AVAILABLE else "cpu"
TORCH_DTYPE = torch.float16 if _CUDA_AVAILABLE else torch.float32
class RiverPollutionAnalyzer:
    """Assess river pollution in a photo with InstructBLIP.

    The vision-language model is prompted to list pollutants from a fixed
    vocabulary and to give a 1-10 severity score. Its free-text reply is parsed
    leniently and rendered as a Markdown report for the UI.
    """

    def __init__(self):
        # Load the processor and model once; the model is moved to DEVICE.
        self.processor = InstructBlipProcessor.from_pretrained(MODEL_NAME)
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            torch_dtype=TORCH_DTYPE,
        ).to(DEVICE)
        # Closed vocabulary: the prompt offers exactly these labels and the
        # parser discards anything else.
        self.pollutants = [
            "plastic waste", "chemical foam", "industrial discharge",
            "sewage water", "oil spill", "organic debris",
            "construction waste", "medical waste", "floating trash",
            "algal bloom", "toxic sludge", "agricultural runoff",
        ]
        # Human-readable label shown under the bar for each severity score.
        self.severity_descriptions = {
            1: "Minimal pollution - Slightly noticeable",
            2: "Minor pollution - Small amounts visible",
            3: "Moderate pollution - Clearly visible",
            4: "Significant pollution - Affecting water quality",
            5: "Heavy pollution - Obvious environmental impact",
            6: "Severe pollution - Large accumulation",
            7: "Very severe pollution - Major ecosystem impact",
            8: "Extreme pollution - Dangerous levels",
            9: "Critical pollution - Immediate action needed",
            10: "Disaster level - Ecological catastrophe",
        }

    def analyze_image(self, image):
        """Analyze a river image and return a Markdown pollution report.

        Args:
            image: A PIL image, or an array accepted by ``Image.fromarray``.
                ``None`` (nothing uploaded yet) returns a short prompt asking
                for an image instead of raising.

        Returns:
            The Markdown report built by ``_format_analysis``.
        """
        if image is None:
            return "Please upload a river image to analyze."
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize RGBA / grayscale / palette uploads to 3-channel RGB.
        image = image.convert("RGB")

        # The label list is generated from self.pollutants so the prompt and
        # the parser's vocabulary cannot drift apart.
        prompt = (
            "Analyze this river pollution scene and provide:\n"
            f"1. List ALL visible pollutants ONLY from: [{', '.join(self.pollutants)}]\n"
            "2. Estimate pollution severity from 1-10\n"
            "Respond EXACTLY in this format:\n"
            "Pollutants: [comma separated list]\n"
            "Severity: [number]"
        )
        # Match the model's device and dtype.
        inputs = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt",
        ).to(DEVICE, TORCH_DTYPE)

        with torch.no_grad():
            # Greedy decoding: the same image always yields the same report,
            # instead of a severity that changes on every click.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=False,
            )
        analysis = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        pollutants, severity = self._parse_response(analysis)
        return self._format_analysis(pollutants, severity)

    def analyze_chat(self, message):
        """Answer a chat question with canned facts about severity and pollutants.

        Matching is a case-insensitive keyword check. "severity" takes
        precedence over "pollutant", and anything else gets a help message.
        """
        text = (message or "").lower()
        if "severity" in text:
            return (
                "Severity levels range from 1 (minimal) to 10 (disaster). "
                "The analyzer automatically detects the appropriate level."
            )
        if "pollutant" in text:
            return f"Detectable pollutants: {', '.join(self.pollutants)}"
        return "I can answer questions about pollution severity levels and detectable pollutants."

    def _parse_response(self, analysis: str) -> Tuple[List[str], int]:
        """Extract ``(pollutants, severity)`` from the model's free-text reply.

        Pollutants are read from the text after ``Pollutants:``, up to the end
        of that line or a following ``Severity:``. If the header is missing,
        the whole reply is scanned instead. Only labels from
        ``self.pollutants`` are kept, ordered by first mention and without
        duplicates.

        Severity comes from ``Severity: N``, clamped to 1-10. When that is
        absent, it is estimated from the pollutants via ``_calculate_severity``.
        """
        # The lookahead forces the lazy group to run to the end of the list.
        # Without it, a lazy group followed only by optional tokens matches
        # the empty string, and no pollutant is ever found.
        pollutant_match = re.search(
            r"Pollutants:\s*\[?(.*?)\]?\s*(?=Severity\s*:|\n|$)",
            analysis,
            re.IGNORECASE,
        )
        section = pollutant_match.group(1) if pollutant_match else analysis
        pollutants = self._find_known_pollutants(section)

        severity_match = re.search(r"Severity:\s*(\d{1,2})", analysis, re.IGNORECASE)
        if severity_match:
            severity = min(max(int(severity_match.group(1)), 1), 10)
        else:
            severity = self._calculate_severity(pollutants)
        return pollutants, severity

    def _find_known_pollutants(self, text: str) -> List[str]:
        """Return the vocabulary labels mentioned in *text*, ordered by first mention."""
        # Case- and whitespace-insensitive, so "Oil  Spill." still matches.
        normalized = " ".join(text.lower().split())
        hits = [(normalized.find(label), label) for label in self.pollutants]
        return [label for pos, label in sorted(hits) if pos >= 0]

    def _calculate_severity(self, pollutants: List[str]) -> int:
        """Estimate a 1-10 severity from the detected pollutants.

        Each pollutant has a hazard weight; unknown labels count as 1. The
        mean weight is multiplied by 3, rounded, and clamped to 1-10. An empty
        list yields 1.
        """
        if not pollutants:
            return 1
        weights = {
            "medical waste": 3, "toxic sludge": 3, "oil spill": 2.5,
            "chemical foam": 2, "industrial discharge": 2, "sewage water": 2,
            "plastic waste": 1.5, "construction waste": 1.5, "algal bloom": 1.5,
            "agricultural runoff": 1.5, "floating trash": 1, "organic debris": 1,
        }
        avg_weight = sum(weights.get(p, 1) for p in pollutants) / len(pollutants)
        return min(10, max(1, round(avg_weight * 3)))

    def _format_analysis(self, pollutants: List[str], severity: int) -> str:
        """Render pollutants and severity as a Markdown report (max 8 pollutants).

        The output is displayed by a Markdown component, which joins lines
        separated by a single newline into one paragraph. Sections are
        therefore separated by blank lines, and pollutants form a Markdown
        list.
        """
        if pollutants:
            pollutants_list = "\n".join(f"- {p.capitalize()}" for p in pollutants[:8])
        else:
            pollutants_list = "🔍 No pollutants detected"
        # Ten-cell bar: solid cells for the score, light cells for the rest.
        severity_bar = "█" * severity + "░" * (10 - severity)
        description = self.severity_descriptions.get(severity, "")
        return (
            "🌊 River Pollution Analysis 🌊\n\n"
            f"{pollutants_list}\n\n"
            f"📊 Severity: {severity}/10\n\n"
            f"{severity_bar}\n\n"
            f"{description}"
        )
# Initialize the analyzer once at startup; every request reuses the loaded model.
analyzer = RiverPollutionAnalyzer()

# Gradio Interface
css = """
.header { text-align: center; margin-bottom: 20px; }
.header h1 { font-size: 2.2rem; margin-bottom: 0; }
.header h3 { font-size: 1.1rem; font-weight: normal; margin-top: 0.5rem; }
.side-by-side { display: flex; gap: 20px; }
.left-panel, .right-panel { flex: 1; }
.analysis-box { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; margin-top: 20px; }
.chat-container { border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px; height: 100%; }
"""

# Example images are optional assets. With cache_examples=True, gr.Examples
# runs analyze_image on every example while the app starts, so a missing file
# can abort startup. Offer only the files that are actually present.
EXAMPLE_IMAGES = [
    path
    for path in ("examples/pollution1.jpg", "examples/pollution2.jpg")
    if os.path.isfile(path)
]

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🌊 River Pollution Analyzer")
        gr.Markdown("### AI-powered water quality assessment")

    with gr.Row(elem_classes="side-by-side"):
        # Image Analysis Panel
        with gr.Column(elem_classes="left-panel"):
            gr.Markdown("### 📸 Image Analysis")
            with gr.Group():
                image_input = gr.Image(type="pil", label="Upload River Image", height=300)
                analyze_btn = gr.Button("🔍 Analyze", variant="primary")
            with gr.Group(elem_classes="analysis-box"):
                analysis_output = gr.Markdown()

        # Chat Panel
        with gr.Column(elem_classes="right-panel"):
            gr.Markdown("### 💬 Pollution Q&A")
            with gr.Group(elem_classes="chat-container"):
                chatbot = gr.Chatbot(height=350)
                with gr.Row():
                    chat_input = gr.Textbox(placeholder="Ask about pollution...", show_label=False)
                    chat_btn = gr.Button("Send", variant="secondary")
                clear_btn = gr.Button("Clear Chat")

    # Event handlers
    analyze_btn.click(
        analyzer.analyze_image,
        inputs=image_input,
        outputs=analysis_output,
    )

    def respond(message, chat_history):
        """Answer *message* and append the (question, answer) pair to the chat.

        Returns an empty string, which clears the textbox, and the updated
        history. Blank messages are ignored. A ``None`` history, as left by
        Clear, starts a fresh conversation.
        """
        chat_history = chat_history or []
        if not message or not message.strip():
            return "", chat_history
        chat_history.append((message, analyzer.analyze_chat(message)))
        return "", chat_history

    chat_input.submit(respond, [chat_input, chatbot], [chat_input, chatbot])
    chat_btn.click(respond, [chat_input, chatbot], [chat_input, chatbot])
    clear_btn.click(lambda: None, None, chatbot, queue=False)

    # Examples (only when the image files are present; see EXAMPLE_IMAGES)
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=[[path] for path in EXAMPLE_IMAGES],
            inputs=image_input,
            outputs=analysis_output,
            fn=analyzer.analyze_image,
            cache_examples=True,
            label="Example Images",
        )

demo.launch()