# ui-regression-testing-2 / hf_vision_analyzer.py
# Uploaded by riazmo (commit a67549f, "Upload 4 files")
"""
HF Vision Analyzer Module
Uses Hugging Face vision models for semantic image analysis and comparison
"""
import os
from typing import Dict, List, Any, Tuple, Optional
from pathlib import Path
import logging
try:
    from transformers import pipeline
    from PIL import Image
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    # transformers is missing; PIL may still be present on its own.
    try:
        from PIL import Image
    except ImportError:
        # Stub standing in for the PIL ``Image`` module. The previous
        # fallback (``Image = Any``) broke at class-definition time because
        # the ``Image.Image`` annotations below would raise AttributeError
        # on ``typing.Any``; expose an ``Image`` attribute instead so the
        # annotations still evaluate.
        class Image:  # type: ignore[no-redef]
            Image = Any
    logging.warning("Hugging Face transformers not available. Install with: pip install transformers torch")
class HFVisionAnalyzer:
    """
    Analyzes images using Hugging Face vision models.

    Supports three analysis types, selected by ``model_type``:
    "captioning" (BLIP), "classification" (ViT), and "detection" (DETR).
    Per-image results are memoized in ``self.analysis_cache`` keyed by path.
    """

    def __init__(self, hf_token: Optional[str] = None, model_type: str = "captioning"):
        """
        Initialize HF Vision Analyzer.

        Args:
            hf_token: Hugging Face API token (optional; falls back to the
                HUGGINGFACE_API_KEY environment variable).
            model_type: Type of analysis - "captioning", "classification",
                or "detection".
        """
        self.hf_token = hf_token or os.getenv("HUGGINGFACE_API_KEY")
        self.model_type = model_type
        self.pipeline = None  # set by _initialize_pipeline on success
        # image_path -> analysis result dict (avoids re-running the model)
        self.analysis_cache: Dict[str, Dict[str, Any]] = {}
        if HF_AVAILABLE:
            self._initialize_pipeline()
        else:
            logging.error("Hugging Face transformers not available")

    def _initialize_pipeline(self):
        """Initialize the HF pipeline matching ``self.model_type``.

        On any failure (download error, bad token, unknown type) the
        pipeline is left as None and the error is logged; callers detect
        this via the ``{"error": ...}`` result of ``analyze_image``.
        """
        try:
            pipeline_kwargs = {
                # First CUDA device when available, otherwise CPU.
                "device": 0 if self._has_gpu() else -1
            }
            if self.hf_token:
                pipeline_kwargs["token"] = self.hf_token
            if self.model_type == "captioning":
                self.pipeline = pipeline(
                    "image-to-text",
                    model="Salesforce/blip-image-captioning-base",
                    **pipeline_kwargs
                )
                logging.info("βœ… Initialized image captioning pipeline")
            elif self.model_type == "classification":
                self.pipeline = pipeline(
                    "image-classification",
                    model="google/vit-base-patch16-224",
                    **pipeline_kwargs
                )
                logging.info("βœ… Initialized image classification pipeline")
            elif self.model_type == "detection":
                self.pipeline = pipeline(
                    "object-detection",
                    # Fixed model id: the hub model is
                    # "facebook/detr-resnet-50", not "facebook/detr-resnet50".
                    model="facebook/detr-resnet-50",
                    **pipeline_kwargs
                )
                logging.info("βœ… Initialized object detection pipeline")
            else:
                # Previously an unknown model_type silently left the
                # pipeline unset; surface the misconfiguration.
                logging.error(f"Unknown model type: {self.model_type}")
        except Exception as e:
            logging.error(f"Failed to initialize pipeline: {str(e)}")
            self.pipeline = None

    def _has_gpu(self) -> bool:
        """Check if a CUDA-capable GPU is available (False if torch missing)."""
        try:
            import torch
            return torch.cuda.is_available()
        except Exception:
            # Narrowed from a bare ``except:`` (which also swallowed
            # KeyboardInterrupt/SystemExit); any failure means "no GPU".
            return False

    def analyze_image(self, image_path: str) -> Dict[str, Any]:
        """
        Analyze a single image.

        Args:
            image_path: Path to image file.

        Returns:
            Dictionary with analysis results, or ``{"error": ...}`` when
            the pipeline is unavailable or analysis fails.
        """
        if not self.pipeline:
            return {"error": "Pipeline not initialized"}
        # Check cache
        if image_path in self.analysis_cache:
            return self.analysis_cache[image_path]
        try:
            image = Image.open(image_path)
            if self.model_type == "captioning":
                result = self._analyze_captioning(image)
            elif self.model_type == "classification":
                result = self._analyze_classification(image)
            elif self.model_type == "detection":
                result = self._analyze_detection(image)
            else:
                result = {"error": "Unknown model type"}
            # Cache result (error dicts are cached too, matching prior behavior)
            self.analysis_cache[image_path] = result
            return result
        except Exception as e:
            logging.error(f"Error analyzing image {image_path}: {str(e)}")
            return {"error": str(e)}

    def _analyze_captioning(self, image: "Image.Image") -> Dict[str, Any]:
        """Image captioning analysis via the image-to-text pipeline.

        Annotation is a forward reference so the class still defines when
        PIL is absent.
        """
        try:
            results = self.pipeline(image)
            caption = results[0]["generated_text"] if results else "No caption generated"
            return {
                "type": "captioning",
                "caption": caption,
                # Heuristic placeholder: the BLIP captioning pipeline does
                # not return a confidence score.
                "confidence": 0.85,
                "keywords": self._extract_keywords(caption)
            }
        except Exception as e:
            return {"error": str(e)}

    def _analyze_classification(self, image: "Image.Image") -> Dict[str, Any]:
        """Image classification analysis (top-5 labels with scores)."""
        try:
            results = self.pipeline(image)
            return {
                "type": "classification",
                "classes": [
                    {
                        "label": r["label"],
                        "score": r["score"]
                    }
                    for r in results[:5]  # Top 5 classes
                ],
                "top_class": results[0]["label"] if results else "Unknown"
            }
        except Exception as e:
            return {"error": str(e)}

    def _analyze_detection(self, image: "Image.Image") -> Dict[str, Any]:
        """Object detection analysis (labels, scores, bounding boxes)."""
        try:
            results = self.pipeline(image)
            return {
                "type": "detection",
                "objects": [
                    {
                        "label": obj["label"],
                        "score": obj["score"],
                        "box": obj["box"]
                    }
                    for obj in results
                ],
                "object_count": len(results),
                # De-duplicated label names (set order is unspecified)
                "object_types": list(set(obj["label"] for obj in results))
            }
        except Exception as e:
            return {"error": str(e)}

    def _extract_keywords(self, text: str) -> List[str]:
        """Extract de-duplicated keywords from text.

        Simple heuristic: lowercase, drop stop words and words of 3 chars
        or fewer. Can be enhanced with NLP later.
        """
        stop_words = {"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "is", "are"}
        words = text.lower().split()
        keywords = [w for w in words if w not in stop_words and len(w) > 3]
        return list(set(keywords))

    def compare_images(self, figma_path: str, website_path: str) -> Dict[str, Any]:
        """
        Compare two images using HF vision analysis.

        Args:
            figma_path: Path to Figma screenshot.
            website_path: Path to website screenshot.

        Returns:
            Comparison results with differences, or an error dict when
            either image could not be analyzed.
        """
        figma_analysis = self.analyze_image(figma_path)
        website_analysis = self.analyze_image(website_path)
        if "error" in figma_analysis or "error" in website_analysis:
            return {
                "error": "Failed to analyze one or both images",
                "figma_error": figma_analysis.get("error"),
                "website_error": website_analysis.get("error")
            }
        if self.model_type == "captioning":
            return self._compare_captions(figma_analysis, website_analysis)
        elif self.model_type == "classification":
            return self._compare_classifications(figma_analysis, website_analysis)
        elif self.model_type == "detection":
            return self._compare_detections(figma_analysis, website_analysis)
        return {"error": "Unknown comparison type"}

    def _compare_captions(self, figma_analysis: Dict, website_analysis: Dict) -> Dict[str, Any]:
        """Compare image captions via keyword-set Jaccard similarity."""
        figma_caption = figma_analysis.get("caption", "")
        website_caption = website_analysis.get("caption", "")
        figma_keywords = set(figma_analysis.get("keywords", []))
        website_keywords = set(website_analysis.get("keywords", []))
        missing_keywords = figma_keywords - website_keywords
        extra_keywords = website_keywords - figma_keywords
        common_keywords = figma_keywords & website_keywords
        # Jaccard index; two empty keyword sets count as identical.
        if figma_keywords or website_keywords:
            similarity = len(common_keywords) / len(figma_keywords | website_keywords)
        else:
            similarity = 1.0
        return {
            "comparison_type": "captioning",
            "figma_caption": figma_caption,
            "website_caption": website_caption,
            "similarity_score": similarity * 100,
            "missing_elements": list(missing_keywords),
            "extra_elements": list(extra_keywords),
            "common_elements": list(common_keywords),
            "differences_detected": len(missing_keywords) + len(extra_keywords)
        }

    def _compare_classifications(self, figma_analysis: Dict, website_analysis: Dict) -> Dict[str, Any]:
        """Compare image classifications by label-set difference."""
        figma_classes = set(c["label"] for c in figma_analysis.get("classes", []))
        website_classes = set(c["label"] for c in website_analysis.get("classes", []))
        missing_classes = figma_classes - website_classes
        extra_classes = website_classes - figma_classes
        common_classes = figma_classes & website_classes
        return {
            "comparison_type": "classification",
            "figma_top_class": figma_analysis.get("top_class"),
            "website_top_class": website_analysis.get("top_class"),
            "missing_classes": list(missing_classes),
            "extra_classes": list(extra_classes),
            "common_classes": list(common_classes),
            "differences_detected": len(missing_classes) + len(extra_classes)
        }

    def _compare_detections(self, figma_analysis: Dict, website_analysis: Dict) -> Dict[str, Any]:
        """Compare object detections by object-type-set difference."""
        figma_objects = figma_analysis.get("object_types", [])
        website_objects = website_analysis.get("object_types", [])
        figma_set = set(figma_objects)
        website_set = set(website_objects)
        missing_objects = figma_set - website_set
        extra_objects = website_set - figma_set
        return {
            "comparison_type": "detection",
            "figma_object_count": figma_analysis.get("object_count", 0),
            "website_object_count": website_analysis.get("object_count", 0),
            "figma_objects": figma_objects,
            "website_objects": website_objects,
            "missing_objects": list(missing_objects),
            "extra_objects": list(extra_objects),
            "differences_detected": len(missing_objects) + len(extra_objects)
        }

    def generate_difference_report(self, comparison: Dict[str, Any]) -> str:
        """Generate a human-readable difference report.

        Note: only "captioning" and "detection" comparisons produce a
        detailed body; other types yield an empty report.
        """
        lines = []
        if "error" in comparison:
            return f"Error: {comparison['error']}"
        comp_type = comparison.get("comparison_type", "unknown")
        if comp_type == "captioning":
            lines.append("πŸ“Έ Image Captioning Comparison\n")
            lines.append(f"Design Caption: {comparison.get('figma_caption', 'N/A')}")
            lines.append(f"Website Caption: {comparison.get('website_caption', 'N/A')}")
            lines.append(f"Similarity: {comparison.get('similarity_score', 0):.1f}%\n")
            if comparison.get("missing_elements"):
                lines.append(f"Missing Elements: {', '.join(comparison['missing_elements'])}")
            if comparison.get("extra_elements"):
                lines.append(f"Extra Elements: {', '.join(comparison['extra_elements'])}")
        elif comp_type == "detection":
            lines.append("πŸ” Object Detection Comparison\n")
            lines.append(f"Design Objects: {comparison.get('figma_object_count', 0)}")
            lines.append(f"Website Objects: {comparison.get('website_object_count', 0)}\n")
            if comparison.get("missing_objects"):
                lines.append(f"Missing Objects: {', '.join(comparison['missing_objects'])}")
            if comparison.get("extra_objects"):
                lines.append(f"Extra Objects: {', '.join(comparison['extra_objects'])}")
        return "\n".join(lines)
def create_hf_analyzer(hf_token: Optional[str] = None, model_type: str = "captioning") -> Optional[HFVisionAnalyzer]:
    """
    Factory function to create HF Vision Analyzer.

    Args:
        hf_token: Hugging Face API token.
        model_type: Type of analysis model.

    Returns:
        HFVisionAnalyzer instance, or None when the transformers stack
        is not installed.
    """
    if HF_AVAILABLE:
        return HFVisionAnalyzer(hf_token=hf_token, model_type=model_type)
    logging.warning("Hugging Face not available. Install with: pip install transformers torch")
    return None