arjunanand13's picture
Update app.py
030841a verified
raw
history blame
6.15 kB
# app.py - Hugging Face Spaces App
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
import tempfile
import os
# Model configuration
# Base LLaVA checkpoint the LoRA adapters were trained on.
MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
# Hub repo holding the fine-tuned PEFT (LoRA) adapters and the processor.
PEFT_MODEL_ID = "arjunanand13/gas_pipe_llava_finetunedv2"
@torch.no_grad()
def load_model():
    """Load the 4-bit quantized base LLaVA model and attach the LoRA adapters.

    Returns:
        tuple: ``(processor, model)`` — the AutoProcessor from the adapter
        repo and the PEFT-wrapped LlavaNext model, ready for inference.
    """
    print("Loading PEFT model...")

    # 4-bit NF4 quantization so the 7B model fits in limited GPU memory.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.uint8,
    )

    # Tokenizer + image processor ship with the fine-tuned adapter repo.
    proc = AutoProcessor.from_pretrained(PEFT_MODEL_ID)

    # Quantized base model, placed automatically across available devices.
    base = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=quant_cfg,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    # Layer the fine-tuned LoRA adapters on top of the quantized base.
    peft_model = PeftModel.from_pretrained(base, PEFT_MODEL_ID)
    print("Model loaded successfully!")
    return proc, peft_model
# Load processor and model once at module import (Spaces startup) so every
# Gradio request reuses the same instances instead of reloading the 7B model.
processor, model = load_model()
def preprocess_video(video_path):
    """Sample 4 evenly spaced frames from a video and tile them into a 2x2 grid.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        PIL.Image.Image: A 224x224 RGB image with 4 frames tiled 2x2 (each
        resized to 112x112). If fewer than 4 frames decode, the last frame
        is repeated; if none decode, black tiles are used.

    Raises:
        ValueError: If OpenCV cannot open the video.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")
    frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Evenly spaced indices; max(0, ...) guards containers reporting 0 frames.
        frame_indices = np.linspace(0, max(0, total_frames - 1), 4, dtype=int)
        for frame_idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes BGR; the model expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb).resize((112, 112)))
    finally:
        # Release the capture even if decoding raises, so the handle never leaks.
        cap.release()
    # Pad to exactly 4 tiles.
    while len(frames) < 4:
        if frames:
            frames.append(frames[-1].copy())
        else:
            frames.append(Image.new('RGB', (112, 112), color='black'))
    grid_image = Image.new('RGB', (224, 224))
    positions = [(0, 0), (112, 0), (0, 112), (112, 112)]
    for i, frame in enumerate(frames[:4]):
        grid_image.paste(frame, positions[i])
    return grid_image
def predict_cheating(video_file):
    """Analyze a gas-pipe testing video and classify it as compliant or cheating.

    Args:
        video_file: Filepath of the uploaded video (Gradio passes a path),
            or None when nothing was uploaded.

    Returns:
        tuple: (markdown result string, 2x2 frame-grid PIL image or None).
        On any processing error, returns the error message and None instead
        of raising, so the Gradio UI stays responsive.
    """
    if video_file is None:
        return "Please upload a video file", None
    try:
        # Collapse the video into the single 2x2 frame grid the model expects.
        grid_image = preprocess_video(video_file)
        # Prompt is pre-seeded with the opening of a JSON object to force the
        # model to complete a {"cheating": ...} answer.
        prompt = "[INST] <image>\nGas pipe test result? [/INST] {\"cheating\":"
        inputs = processor(text=prompt, images=grid_image, return_tensors="pt")
        # Move tensors to the model's device; leave non-tensor entries as-is.
        inputs = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
        # Greedy decoding; a short completion is enough for the JSON field.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=16,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens (skip the echoed prompt).
        result = processor.decode(generated_ids[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
        # Reconstruct the full JSON-ish output for display.
        full_result = '{"cheating":' + result
        if "true" in result.lower():
            prediction = "CHEATING DETECTED"
            explanation = "The system detected violations in the testing procedure."
            status = "🚨"
        else:
            prediction = "COMPLIANT PROCEDURE"
            explanation = "The testing procedure appears to follow proper protocols."
            status = "✅"  # was mojibake ("βœ…") — restored proper check mark
        formatted_result = f"{status} **{prediction}**\n\n{explanation}\n\nModel output: {full_result}"
        return formatted_result, grid_image
    except Exception as e:
        # Surface errors in the UI rather than crashing the Space.
        return f"Error processing video: {str(e)}", None
# Gradio Interface: left column for upload/trigger, right column for results.
with gr.Blocks(
    title="Gas Pipe Quality Control Detection",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    """
) as demo:
    # Page header (restored the wrench emoji, which was mojibake: "πŸ”§").
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🔧 Gas Pipe Quality Control Detection</h1>
        <p>AI-powered detection of compliance violations in gas pipe testing procedures</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=2):
            video_input = gr.Video(
                label="Upload Gas Pipe Testing Video",
                height=320
            )
            analyze_btn = gr.Button(
                "Analyze Video",
                variant="primary",
                size="lg"
            )
            gr.Markdown("""
            **Supported formats:** MP4, AVI, MOV, WMV
            **Max duration:** 60 seconds recommended
            """)
        with gr.Column(scale=2):
            result_text = gr.Textbox(
                label="Detection Result",
                lines=8,
                max_lines=12
            )
            processed_image = gr.Image(
                label="Processed Video Frames (2×2 Grid)",  # was mojibake "2Γ—2"
                height=320
            )
    # Wire the button to the inference function.
    analyze_btn.click(
        fn=predict_cheating,
        inputs=[video_input],
        outputs=[result_text, processed_image]
    )
    gr.HTML("""
    <div style="text-align: center; margin-top: 20px; padding: 20px; background-color: #f8f9fa; border-radius: 10px;">
        <p><strong>Model:</strong> Fine-tuned LLaVA v1.6 Mistral 7B with LoRA adapters</p>
        <p><strong>Repository:</strong> <a href="https://huggingface.co/arjunanand13/gas_pipe_llava_finetunedv2" target="_blank">arjunanand13/gas_pipe_llava_finetunedv2</a></p>
        <p><strong>Developed by:</strong> Arjun Anand</p>
    </div>
    """)
# Launch the Gradio server when run as a script (Spaces executes app.py directly).
if __name__ == "__main__":
    demo.launch()