File size: 5,613 Bytes
2380f6f
 
 
 
 
 
a453c29
2380f6f
 
870d2ba
2380f6f
 
 
a453c29
 
 
 
7218dd0
2380f6f
 
a453c29
2380f6f
 
 
ce133a0
e2251fd
 
870d2ba
ce133a0
 
 
870d2ba
2380f6f
 
 
a453c29
2380f6f
 
 
 
 
 
 
 
 
 
870d2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2380f6f
 
 
 
a453c29
2380f6f
 
 
 
 
a453c29
2380f6f
 
 
 
 
 
 
 
 
 
 
a453c29
 
 
 
 
 
 
 
 
 
2380f6f
 
 
 
 
 
 
 
 
a453c29
2380f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Service for topic extraction from text using LangChain Groq"""

import logging
from typing import Optional, List
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from pydantic import BaseModel, Field
from langsmith import traceable

from config import GROQ_API_KEY, GROQ_TOPIC_MODEL

logger = logging.getLogger(__name__)


class TopicOutput(BaseModel):
    """Pydantic schema for topic extraction output"""
    # Required field (Field(...) = no default): the single-sentence topic
    # the LLM must return; structured output is bound to this schema in
    # TopicService.initialize().
    topic: str = Field(..., description="A specific, detailed topic description")

class TopicService:
    """Service for extracting topics from text arguments.

    Wraps a LangChain ``ChatGroq`` model bound to the ``TopicOutput``
    structured-output schema. Initialization is lazy: the LLM is only
    constructed on the first call to :meth:`extract_topic` (or an explicit
    :meth:`initialize`), trying the configured model first and then a list
    of fallback models.
    """

    def __init__(self):
        # LLM handle; created lazily in initialize(), None until then.
        self.llm = None
        # Use valid Groq model - defaults from config, fallback to stable model
        self.model_name = GROQ_TOPIC_MODEL if GROQ_TOPIC_MODEL else "llama3-70b-8192"
        # Fallback models to try if primary fails (using current/available Groq models)
        self.fallback_models = [
            "llama3-70b-8192",  # Stable production model (same as chat)
            "llama-3.1-8b-instant",  # Faster, smaller alternative
            "openai/gpt-oss-20b"  # Alternative OpenAI OSS model
        ]
        self.initialized = False

    def initialize(self, model_name: Optional[str] = None):
        """Initialize the Groq LLM with structured output.

        Args:
            model_name: Optional override for the configured model name.
                Tried first, before the fallback list.

        Raises:
            ValueError: If GROQ_API_KEY is missing from the environment.
            RuntimeError: If neither the primary model nor any fallback
                could be constructed (last underlying error is chained).
        """
        if self.initialized:
            logger.info("Topic service already initialized")
            return

        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        if model_name:
            self.model_name = model_name

        # Try primary model first, then fallbacks (deduplicated against primary).
        models_to_try = [self.model_name] + [
            m for m in self.fallback_models if m != self.model_name
        ]

        last_error: Optional[Exception] = None
        for model_to_try in models_to_try:
            try:
                # Lazy %s args: formatting deferred until the record is emitted.
                logger.info("Initializing topic extraction service with model: %s", model_to_try)

                llm = ChatGroq(
                    model=model_to_try,
                    api_key=GROQ_API_KEY,
                    temperature=0.0,  # deterministic extraction
                    max_tokens=512,
                )

                # Bind structured output directly to the model
                self.llm = llm.with_structured_output(TopicOutput)
                self.model_name = model_to_try  # Update to successful model
                self.initialized = True

                logger.info("✓ Topic extraction service initialized successfully with model: %s", model_to_try)
                return

            except Exception as e:
                last_error = e
                logger.warning("Failed to initialize with model %s: %s", model_to_try, e)
                continue

        # All models failed: surface the last underlying error, chained as the
        # cause so the original traceback is preserved for debugging.
        logger.error("Error initializing topic service with all models: %s", last_error)
        raise RuntimeError(
            f"Failed to initialize topic service with any model. Last error: {str(last_error)}"
        ) from last_error

    @traceable(name="extract_topic")
    def extract_topic(self, text: str) -> str:
        """
        Extract a topic from the given text/argument

        Args:
            text: The input text/argument to extract topic from

        Returns:
            The extracted topic string

        Raises:
            ValueError: If ``text`` is not a non-empty string.
            RuntimeError: If the LLM call fails (cause chained).
        """
        if not self.initialized:
            self.initialize()

        if not text or not isinstance(text, str):
            raise ValueError("Text must be a non-empty string")

        text = text.strip()
        if not text:  # whitespace-only input
            raise ValueError("Text cannot be empty")

        system_message = """You are an information extraction model.
Extract a topic from the user text. The topic should be a single sentence that captures the main idea of the text in simple english.

Examples:
- Text: "Governments should subsidize electric cars to encourage adoption."
  Output: topic="government subsidies for electric vehicle adoption"

- Text: "Raising the minimum wage will hurt small businesses and cost jobs."
  Output: topic="raising the minimum wage and its economic impact on small businesses"
"""

        try:
            result = self.llm.invoke(
                [
                    SystemMessage(content=system_message),
                    HumanMessage(content=text),
                ]
            )

            # result is a TopicOutput instance (structured output binding).
            return result.topic

        except Exception as e:
            logger.error("Error extracting topic: %s", e)
            raise RuntimeError(f"Topic extraction failed: {str(e)}") from e

    def batch_extract_topics(self, texts: List[str]) -> List[Optional[str]]:
        """
        Extract topics from multiple texts

        Failures are per-item: a text whose extraction raises yields ``None``
        in the corresponding output slot instead of aborting the batch
        (hence the ``Optional[str]`` elements).

        Args:
            texts: List of input texts/arguments

        Returns:
            List of extracted topics, with ``None`` for failed items.

        Raises:
            ValueError: If ``texts`` is not a non-empty list.
        """
        if not self.initialized:
            self.initialize()

        if not texts or not isinstance(texts, list):
            raise ValueError("Texts must be a non-empty list")

        results: List[Optional[str]] = []
        for text in texts:
            try:
                topic = self.extract_topic(text)
                results.append(topic)
            except Exception as e:
                logger.error("Error extracting topic for text '%s...': %s", text[:50], e)
                results.append(None)  # Or raise, depending on desired behavior

        return results


# Module-level singleton shared by all importers. Constructing it is cheap
# and side-effect free: the Groq LLM is only created lazily on first use
# (initialize() / extract_topic()), so importing this module does no I/O.
topic_service = TopicService()