"""
Style Model Module

Handles caption styling using the Groq API with a rule-based fallback.
Applies different writing styles to generated captions.
"""

import time
from typing import Optional

from groq import APITimeoutError, Groq
import requests

from config import groq_config, style_config


class StyleModelError(Exception):
    """Raised when caption styling through the Groq API fails."""


class StyleModel:
    """
    Caption styling using the Groq LLM API.

    Features:
    - Multiple style options (prompts come from style_config.STYLES)
    - Automatic retry with a fixed delay on any API failure
    - Fallback to rule-based styling (style_config.FALLBACK_TEMPLATES)
    """

    # Preambles the model may prepend to its answer; stripped in order by
    # _clean_response (matched case-insensitively).
    _RESPONSE_PREFIXES = (
        "Styled caption:",
        "Caption:",
        "Here's the styled caption:",
        "Here is the caption:",
    )

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize style model.

        Args:
            api_key: Groq API key (uses groq_config.API_KEY if not provided)
        """
        self.api_key = api_key or groq_config.API_KEY
        self.model_name = groq_config.MODEL_NAME
        self.max_tokens = groq_config.MAX_TOKENS
        self.temperature = groq_config.TEMPERATURE
        self.timeout = groq_config.TIMEOUT_SECONDS

        # Pessimistic defaults; _init_client flips them on success.
        self.client = None
        self._api_available = False

        if self.api_key:
            self._init_client()
        else:
            print("Warning: No Groq API key provided")

        # Retry configuration
        self.max_retries = groq_config.MAX_RETRIES
        self.retry_delay = groq_config.RETRY_DELAY_SECONDS

    def _init_client(self) -> None:
        """
        Create the Groq client and mark the API available on success.

        The primary path verifies the client with a models.list() call. If
        construction or that call fails, a second client is built via
        groq.Client WITHOUT verification. If that also fails, the client is
        left as None and the API is marked unavailable.

        NOTE(review): because the fallback client is not verified, a bad key
        or unreachable network still ends with _api_available = True, and
        every style_caption call then pays max_retries attempts before
        falling back. Confirm this is intended.
        """
        try:
            self.client = Groq(api_key=self.api_key)
            # Round-trip call to surface bad keys / connectivity early.
            self.client.models.list()
            self._api_available = True
        except Exception as e:
            print(f"Warning: Groq client initialization failed: {e}")
            print("Attempting alternative initialization...")
            try:
                # Alternative: Create client without extra params
                import groq
                self.client = groq.Client(api_key=self.api_key)
                self._api_available = True
            except Exception as e2:
                print(f"Alternative initialization also failed: {e2}")
                self.client = None
                self._api_available = False

    def style_caption(
        self,
        caption: str,
        style: str = "Professional"
    ) -> str:
        """
        Apply a style to a caption.

        Args:
            caption: Original caption
            style: Style to apply; the literal "None" returns the caption
                unchanged

        Returns:
            str: Styled caption. Uses rule-based fallback styling when the
            API is unavailable or every API attempt fails.
        """
        if style == "None":
            return caption
        if not self._api_available:
            # Use fallback styling if API unavailable
            return self._fallback_style(caption, style)

        # Try API styling with retries
        for attempt in range(1, self.max_retries + 1):
            try:
                return self._style_with_api(caption, style)
            except Exception as e:
                print(f"API styling attempt {attempt} failed: {e}")
                if attempt == self.max_retries:
                    print(f"Using fallback styling for: {style}")
                    break
                # Wait before retry (not after the final attempt)
                time.sleep(self.retry_delay)

        # Reached when every attempt failed, or when max_retries <= 0.
        return self._fallback_style(caption, style)

    def _style_with_api(self, caption: str, style: str) -> str:
        """
        Style a caption using the Groq chat completions API.

        Args:
            caption: Original caption
            style: Style to apply; unknown styles use the prompt of
                style_config.DEFAULT_STYLE

        Returns:
            str: Styled caption, cleaned by _clean_response

        Raises:
            StyleModelError: If the API is unavailable, the call fails or
                times out, or the model returns an empty caption.
        """
        if not self._api_available:
            raise StyleModelError("API not available")

        # Get style prompt (unknown styles fall back to the default style)
        style_prompt = style_config.STYLES.get(
            style, style_config.STYLES[style_config.DEFAULT_STYLE]
        )

        messages = [
            {
                "role": "system",
                "content": (
                    "You are an expert at rewriting image captions in "
                    "different styles. Keep the core meaning but adapt the "
                    "tone and style as requested. Be concise."
                ),
            },
            {
                "role": "user",
                "content": f"{style_prompt}\n\nOriginal caption: {caption}\n\nStyled caption:",
            },
        ]

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                top_p=groq_config.TOP_P,
                timeout=self.timeout
            )
            # message.content may be None when the model produces no text.
            content = response.choices[0].message.content or ""
        except (APITimeoutError, requests.exceptions.Timeout) as e:
            # The Groq SDK raises its own APITimeoutError; requests.Timeout
            # is kept for any transport that surfaces requests exceptions.
            raise StyleModelError("API request timed out") from e
        except requests.exceptions.RequestException as e:
            raise StyleModelError(f"API request failed: {e}") from e
        except Exception as e:
            raise StyleModelError(f"Unexpected error: {e}") from e

        styled_caption = self._clean_response(content.strip())
        if not styled_caption:
            # Treat an empty answer as a failure so the caller retries or
            # falls back instead of returning a blank caption.
            raise StyleModelError("API returned an empty caption")
        return styled_caption

    def _fallback_style(self, caption: str, style: str) -> str:
        """
        Apply rule-based styling as fallback.

        Args:
            caption: Original caption
            style: Style to apply; unknown styles use the "Professional"
                template

        Returns:
            str: The style's template with {caption} substituted
        """
        template = style_config.FALLBACK_TEMPLATES.get(
            style, style_config.FALLBACK_TEMPLATES["Professional"]
        )
        return template.format(caption=caption)

    def _clean_response(self, text: str) -> str:
        """
        Clean up an API response.

        Strips known preambles (case-insensitively, in order) and removes a
        single pair of matching surrounding quotes.

        Args:
            text: Raw response text

        Returns:
            str: Cleaned, whitespace-stripped text
        """
        for prefix in self._RESPONSE_PREFIXES:
            # Compare a same-length slice so the cut below always matches
            # exactly what was compared.
            if text[:len(prefix)].lower() == prefix.lower():
                text = text[len(prefix):].strip()

        # Remove quotes if the entire text is quoted. Require two characters
        # so a lone quote is not treated as an opening+closing pair.
        if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
            text = text[1:-1]

        return text.strip()

    def batch_style_captions(
        self,
        captions: dict,
        style: str = "Professional"
    ) -> dict:
        """
        Style multiple captions at once.

        Args:
            captions: Dictionary of {model_name: caption}
            style: Style to apply

        Returns:
            dict: Dictionary of {model_name: styled_caption}; a caption whose
            styling raises is passed through unchanged.
        """
        styled_captions = {}
        for model_name, caption in captions.items():
            try:
                styled_captions[model_name] = self.style_caption(caption, style)
            except Exception as e:
                print(f"Error styling {model_name} caption: {e}")
                # Use original caption on error
                styled_captions[model_name] = caption
        return styled_captions

    def is_api_available(self) -> bool:
        """Check if API is available"""
        return self._api_available

    def test_connection(self) -> bool:
        """
        Test the API connection with a minimal chat completion.

        Returns:
            bool: True if the call succeeds, False if the API is unavailable
            or the call raises
        """
        if not self._api_available:
            return False

        try:
            # Simple test call; only success/failure matters.
            self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": "Hello"}
                ],
                max_tokens=10,
                timeout=5
            )
            return True
        except Exception as e:
            print(f"API connection test failed: {e}")
            return False

    def get_available_styles(self) -> list:
        """Get list of available styles (the keys of style_config.STYLES)"""
        return list(style_config.STYLES.keys())

    def get_info(self) -> dict:
        """Get model information"""
        return {
            "model_name": self.model_name,
            "api_available": self._api_available,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "available_styles": self.get_available_styles()
        }


# Singleton instance
_style_model = None


def get_style_model() -> StyleModel:
    """
    Get the shared StyleModel instance, creating it on first call.

    No lock guards the lazy creation, so concurrent first calls may each
    construct an instance.
    """
    global _style_model
    if _style_model is None:
        _style_model = StyleModel()
    return _style_model


if __name__ == "__main__":
    # Test the style model
    print("=" * 60)
    print("STYLE MODEL - TEST MODE")
    print("=" * 60)

    # Initialize model
    style_model = StyleModel()
    print("\n✓ Style model initialized")
    print(f" API Available: {style_model.is_api_available()}")
    print(f" Model: {style_model.model_name}")

    # Get info
    print("\nModel Info:")
    info = style_model.get_info()
    for key, value in info.items():
        if isinstance(value, list):
            print(f" {key}:")
            for item in value:
                print(f" - {item}")
        else:
            print(f" {key}: {value}")

    test_caption = "A cat sitting on a windowsill looking outside"
    demo_styles = ["Professional", "Creative", "Social Media"]

    # Test connection if API available
    if style_model.is_api_available():
        print("\nTesting API connection...")
        connection_ok = style_model.test_connection()
        print(f" Connection: {'✓ Success' if connection_ok else '✗ Failed'}")

        if connection_ok:
            # Test styling
            print("\nTesting caption styling:")
            for style in demo_styles:
                print(f"\n {style}:")
                try:
                    styled = style_model.style_caption(test_caption, style)
                    print(f" Original: {test_caption}")
                    print(f" Styled: {styled}")
                except Exception as e:
                    print(f" Error: {e}")
    else:
        print("\n⚠️ API not available, testing fallback styling:")
        for style in demo_styles:
            styled = style_model.style_caption(test_caption, style)
            print(f"\n {style}: {styled}")

    print("\n" + "=" * 60)
    print("✓ Style model test complete")
    print("=" * 60)