# VcRlAgent — Generator Refactor HF Inference Client API (commit a1544bb)
"""LLM generation service using Hugging Face Inference Client SDK"""
import os
from typing import Optional
from huggingface_hub import InferenceClient
from app.config import settings
from app.utils.logger import setup_logger
logger = setup_logger(__name__)
class GeneratorService:
    """Handles text generation using Hugging Face InferenceClient.

    Wraps a single reusable ``InferenceClient`` and exposes plain and
    RAG-style generation helpers. API failures are caught at this boundary
    and converted into a static fallback message instead of propagating.
    """

    def __init__(self):
        # Create a single reusable inference client (token from app settings).
        self.client = InferenceClient(api_key=settings.HF_TOKEN)
        # Use model from settings or fallback to a sensible default.
        self.model = getattr(settings, "HF_MODEL", "meta-llama/Llama-3.1-8B-Instruct")

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text using the HF chat-completion API.

        Args:
            prompt: User message sent as a single-turn chat.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.

        Returns:
            The stripped completion text, or a fallback message when the
            API call fails or returns empty content.
        """
        try:
            # Lazy %-args: no string formatting when the log level is disabled.
            logger.info("Calling HF InferenceClient (model=%s)...", self.model)
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
            )
            generated_text = completion.choices[0].message.content
            # The API may return None/empty content; handle it explicitly
            # instead of letting .strip() raise AttributeError, which the
            # broad except below would mislabel as an API failure.
            if not generated_text:
                logger.warning("HF returned empty content; using fallback")
                return self._fallback_response(prompt)
            logger.info("Generation successful")
            return generated_text.strip()
        except Exception as e:
            # Boundary catch: degrade gracefully rather than crash callers.
            logger.error("HF Generation failed: %s", e)
            return self._fallback_response(prompt)

    def _fallback_response(self, prompt: str) -> str:
        """Return a static apology when the LLM API fails (prompt unused)."""
        return (
            "I apologize, but I'm unable to generate a response at the moment. "
            "Please try again later."
        )

    def generate_rag_response(
        self,
        query: str,
        context: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate a response using RAG-style prompt formatting.

        Backward-compatible generalization: sampling parameters may now be
        overridden per call; the defaults match ``generate``.
        """
        prompt = self._build_rag_prompt(query, context)
        return self.generate(prompt, max_tokens=max_tokens, temperature=temperature)

    def _build_rag_prompt(self, query: str, context: str) -> str:
        """Build the WorkWise-style RAG prompt embedding context and query."""
        return f"""
You are WorkWise, an AI assistant specialized in analyzing Jira project data.
Answer the user's question based only on the context.
Context:
{context}
User Question: {query}
Provide a clear, concise answer.
If the context doesn't contain enough information, say so.
""".strip()
# Global singleton shared by importers of this module.
# NOTE(review): instantiated at import time — __init__ reads settings.HF_TOKEN,
# so importing this module requires valid app configuration; confirm intended.
generator = GeneratorService()