# stellar-search / label_generation_service.py
# X1ng1's picture
# fixed imports
# 95a501f
"""
Label generation service using Hugging Face Inference API
"""
import os
from typing import List, Optional
import logging
from huggingface_hub import InferenceClient
from config import config
# Configure the root logger at import time with INFO as the minimum level.
# Note that this is a module-level side effect: importing this file touches
# process-wide logging configuration.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the service below.
logger = logging.getLogger(__name__)
class LabelGenerationService:
    """Service for generating human-readable labels for clusters using HF Inference API.

    Labels and tags are requested from a chat-completion model through
    ``InferenceClient``.  When ``config.HF_TOKEN`` is not set, or when an API
    call raises, the service falls back to word-frequency heuristics so the
    public methods always return a usable value instead of raising.
    """

    # Labels longer than this many characters are cut and suffixed with "...".
    _MAX_LABEL_CHARS = 50

    # Words ignored by both frequency-based fallbacks.  Sharing one list keeps
    # fallback labels and fallback tags consistent with each other.
    _STOPWORDS = frozenset({
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
        "of", "with", "is", "are", "was", "were", "be", "been", "have", "has",
        "had", "do", "does", "did", "will", "would", "could", "should", "i", "you",
        "he", "she", "it", "we", "they", "me", "him", "her", "us", "them", "my",
        "your", "his", "its", "our", "their", "this", "that", "these", "those",
        "can", "can't", "don't", "doesn't", "didn't", "won't", "wouldn't", "shouldn't",
        "what", "when", "where", "who", "why", "how", "just", "so", "if", "about",
    })

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize label generation service

        Args:
            model_name: Name of the HuggingFace model to use. Defaults to
                ``config.LLM_MODEL`` when omitted or empty.
        """
        self.model_name = model_name or config.LLM_MODEL
        self.api_key = config.HF_TOKEN

        if not self.api_key:
            # Without a token no client is created; every request is then
            # served by the local fallback heuristics.
            logger.warning("HF_TOKEN not set! Label generation will use fallback methods.")
            self.client = None
        else:
            logger.info(
                "Initializing HuggingFace Inference API client for model: %s",
                self.model_name,
            )
            self.client = InferenceClient(api_key=self.api_key)
            logger.info("Label generation service initialized successfully")

    def generate_cluster_label(
        self,
        messages: List[str],
        max_messages: int = 10,
        max_length: int = 50
    ) -> str:
        """
        Generate a descriptive label for a cluster of messages

        Args:
            messages: List of message texts in the cluster
            max_messages: Maximum number of messages to include in prompt
            max_length: Token budget passed to the model as ``max_tokens``.
                The returned label is additionally capped at 50 characters
                by ``_clean_label``.

        Returns:
            Generated label string. ``"Empty Cluster"`` when ``messages`` is
            empty; a frequency-based label when no API client is available
            or the API call fails.
        """
        if not messages:
            return "Empty Cluster"

        # Only the first ``max_messages`` messages are used as the sample.
        selected_messages = messages[:max_messages]

        if not self.client:
            logger.info("No API client available, using fallback label generation")
            return self._fallback_label(selected_messages)

        prompt = self._create_label_prompt(selected_messages)

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_length,
                temperature=0.7,
            )
            label = response.choices[0].message.content.strip()
            return self._clean_label(label)
        except Exception as e:
            # Deliberate best-effort boundary: any failure (network error,
            # malformed response, missing content) degrades to the fallback.
            logger.error("Error generating label via API: %s", e)
            logger.info("Falling back to frequency-based labeling")
            return self._fallback_label(selected_messages)

    def generate_tags(
        self,
        messages: List[str],
        max_messages: int = 10,
        num_tags: int = 3
    ) -> List[str]:
        """
        Generate topic tags for a cluster

        Args:
            messages: List of message texts in the cluster
            max_messages: Maximum number of messages to include
            num_tags: Number of tags to generate

        Returns:
            List of at most ``num_tags`` tags without duplicates. Empty when
            ``messages`` is empty; frequency-based tags when no API client is
            available or the API call fails.
        """
        if not messages:
            return []

        selected_messages = messages[:max_messages]

        if not self.client:
            return self._fallback_tags(selected_messages, num_tags)

        prompt = self._create_tags_prompt(selected_messages, num_tags)

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=100,
                temperature=0.7,
            )
            tags_text = response.choices[0].message.content.strip()
            return self._parse_tags(tags_text)[:num_tags]
        except Exception as e:
            # Same best-effort policy as generate_cluster_label.
            logger.error("Error generating tags via API: %s", e)
            return self._fallback_tags(selected_messages, num_tags)

    def _create_label_prompt(self, messages: List[str]) -> str:
        """Create prompt for label generation.

        Each message is rendered as a ``- `` bullet and cut to its first 150
        characters to keep the prompt short.
        """
        messages_text = "\n".join(f"- {msg[:150]}" for msg in messages)
        # Built with explicit newlines so the prompt text does not depend on
        # how this method is indented in the source file.
        return (
            "Analyze these chat messages and create a clear, descriptive topic label in 3-6 words.\n"
            "The label should capture the main theme or subject of the conversation.\n"
            "Be specific and avoid generic terms.\n"
            "Messages:\n"
            f"{messages_text}\n"
            "Create a descriptive topic label (3-6 words):"
        )

    def _create_tags_prompt(self, messages: List[str], num_tags: int) -> str:
        """Create prompt for tag generation.

        Each message is rendered as a ``- `` bullet and cut to its first 150
        characters; the model is asked for ``num_tags`` comma-separated keywords.
        """
        messages_text = "\n".join(f"- {msg[:150]}" for msg in messages)
        return (
            f"Analyze these chat messages and generate {num_tags} specific topic keywords.\n"
            "Use concrete terms that describe the main subjects discussed.\n"
            "Messages:\n"
            f"{messages_text}\n"
            f"Generate {num_tags} keywords (comma-separated):"
        )

    def _clean_label(self, label: str) -> str:
        """Clean and format generated label.

        Removes a "Topic:" prefix, keeps only the first non-empty line,
        strips wrapping quotes, upper-cases the first character and truncates
        to 50 characters (47 characters plus "...").

        Returns:
            The cleaned label, or ``"General Discussion"`` if nothing is left.
        """
        label = label.replace("Topic:", "").replace("topic:", "").strip()

        # A label is a single line: drop any explanation the model appended
        # on following lines.
        label = next(
            (line.strip() for line in label.splitlines() if line.strip()),
            "",
        )

        # Chat models frequently wrap the answer in quotes or backticks.
        label = label.strip("\"'`").strip()

        if label:
            label = label[0].upper() + label[1:]

        if len(label) > self._MAX_LABEL_CHARS:
            label = label[: self._MAX_LABEL_CHARS - 3] + "..."

        return label or "General Discussion"

    def _parse_tags(self, tags_text: str) -> List[str]:
        """Parse tags from generated text.

        Strips a "Keywords:" prefix, splits on commas and newlines, lower-cases
        and normalises each piece with ``_clean_tag``, and drops empty results
        and duplicates while preserving the first-seen order.
        """
        tags_text = tags_text.replace("Keywords:", "").replace("keywords:", "").strip()

        pieces = tags_text.replace("\n", ",").split(",")
        cleaned = (self._clean_tag(piece.strip().lower()) for piece in pieces)

        # dict.fromkeys de-duplicates while keeping insertion order.
        return list(dict.fromkeys(tag for tag in cleaned if tag))

    def _clean_tag(self, tag: str) -> str:
        """Clean a single tag.

        Keeps only alphanumeric characters, hyphens and spaces, then joins the
        remaining whitespace-separated parts with single hyphens.
        """
        kept = "".join(c for c in tag if c.isalnum() or c in ("-", " "))
        return "-".join(kept.split())

    def _content_words(self, messages: List[str]) -> List[str]:
        """Return the lower-cased, meaningful words of ``messages`` in order.

        Tokens are split on whitespace and stripped of leading/trailing
        punctuation so that "deploy," and "deploy" count as the same word.
        Stopwords and words of three characters or fewer are discarded.
        """
        import string

        words = []
        for msg in messages:
            for token in msg.lower().split():
                word = token.strip(string.punctuation)
                if len(word) > 3 and word not in self._STOPWORDS:
                    words.append(word)
        return words

    def _fallback_label(self, messages: List[str]) -> str:
        """Generate fallback label using simple heuristics.

        The label is made of the (up to) four most frequent meaningful words,
        each capitalised. Returns ``"General Discussion"`` when no meaningful
        word is found.
        """
        from collections import Counter

        words = self._content_words(messages)
        if not words:
            return "General Discussion"

        # Top 4 words give a more descriptive label than a single keyword.
        common_words = Counter(words).most_common(4)
        label = " ".join(word.capitalize() for word, _ in common_words)

        return label or "General Discussion"

    def _fallback_tags(self, messages: List[str], num_tags: int) -> List[str]:
        """Generate fallback tags using simple heuristics.

        Tags are the most frequent meaningful words, normalised with
        ``_clean_tag`` and de-duplicated. Returns ``["general"]`` when no
        meaningful word is found.
        """
        from collections import Counter

        words = self._content_words(messages)
        if not words:
            return ["general"]

        # Fetch twice as many candidates as requested because cleaning can
        # empty a candidate or collapse two candidates into the same tag.
        common_words = Counter(words).most_common(num_tags * 2)
        cleaned = (self._clean_tag(word) for word, _ in common_words)

        return list(dict.fromkeys(tag for tag in cleaned if tag))[:num_tags]
# Lazily created, module-wide service instance (None until first requested)
_label_service: Optional[LabelGenerationService] = None


def get_label_service() -> LabelGenerationService:
    """Return the shared LabelGenerationService, building it on first use."""
    global _label_service
    if _label_service is not None:
        return _label_service
    # First call: construct the service once and remember it for later calls.
    _label_service = LabelGenerationService()
    return _label_service