# NOTE: removed extraction artifact — "Spaces: Sleeping" was the Hugging Face
# Spaces page status banner captured by the scraper, not part of this module.
"""
HuggingFace LLM Service
Provides LLM capabilities using HuggingFace Inference API.
Designed to work with the free tier and integrate with the RAG pipeline.
"""
import logging
import os
import time
from typing import Any, Dict, List

import requests

logger = logging.getLogger(__name__)
class HFLLMService:
    """
    LLM service using HuggingFace Inference API.

    Uses free-tier models and provides fallback capabilities: when the remote
    API fails (non-200 status, timeout, network error), a canned policy-topic
    response is returned instead of raising, with ``success`` still True and
    the underlying problem recorded under ``error_message``.
    """

    def __init__(self, model_name: str = "gpt2"):
        """
        Initialize HF LLM service.

        Args:
            model_name: HuggingFace model to use for text generation
        """
        self.model_name = model_name
        self.api_url = f"https://router.huggingface.co/hf-inference/models/{model_name}"
        self.hf_token = os.getenv("HF_TOKEN")
        self.headers = {"Content-Type": "application/json"}
        if self.hf_token:
            self.headers["Authorization"] = f"Bearer {self.hf_token}"
        else:
            # Unauthenticated calls still work but are heavily rate-limited.
            logger.warning("No HF_TOKEN found - will use rate-limited public inference")
        logger.info(f"HFLLMService initialized with model: {model_name}")

    @staticmethod
    def _strip_prompt(text: str, prompt: str) -> str:
        """Remove an echoed input prompt from the model output, if present."""
        if text.startswith(prompt):
            return text[len(prompt):].strip()
        return text

    def _extract_text(self, result: Any, prompt: str) -> str:
        """
        Normalize the HF Inference API response to a plain string.

        The API may return a list of generation dicts, a single dict (some
        task types use an "answer" key), or something else entirely; anything
        unrecognized is stringified rather than raising.
        """
        if isinstance(result, list) and result:
            first = result[0]
            # Standard text-generation format: [{"generated_text": "..."}]
            if isinstance(first, dict) and "generated_text" in first:
                return self._strip_prompt(first["generated_text"], prompt)
            return str(first)
        if isinstance(result, dict):
            if "generated_text" in result:
                return self._strip_prompt(result["generated_text"], prompt)
            if "answer" in result:
                return result["answer"]
            return str(result)
        return str(result)

    @staticmethod
    def _count_usage(prompt: str, completion: str) -> Dict[str, int]:
        """Approximate token usage by whitespace word count (no tokenizer used)."""
        prompt_tokens = len(prompt.split())
        completion_tokens = len(completion.split())
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

    def _fallback_result(self, prompt: str, response_time: float, error_msg: str) -> Dict[str, Any]:
        """Build the canned-response payload used when the API is unavailable."""
        fallback_response = self._get_fallback_response(prompt)
        return {
            "content": fallback_response,
            "provider": "huggingface_fallback",
            "model": "fallback",
            "response_time": response_time,
            # The fallback is deliberately reported as a success so callers
            # treat it as a usable (degraded) answer; the real error is kept
            # under "error_message" for diagnostics.
            "success": True,
            "error_message": error_msg,
            "usage": self._count_usage(prompt, fallback_response),
        }

    def generate_response(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate response using HuggingFace Inference API.

        Args:
            prompt: Input prompt for the model
            **kwargs: Additional parameters (max_tokens, temperature, timeout, etc.)

        Returns:
            Dict containing the response and metadata. On API failure a
            fallback payload is returned (see _fallback_result) instead of
            raising.
        """
        start_time = time.time()
        try:
            # Prepare request payload for text generation.
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": min(kwargs.get("max_tokens", 150), 300),  # Cap at 300 for free tier
                    "temperature": kwargs.get("temperature", 0.8),
                    "do_sample": True,
                    "top_p": 0.9,
                    "repetition_penalty": 1.2,
                    "return_full_text": False,  # Important for GPT-2: don't echo the prompt
                },
            }
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=kwargs.get("timeout", 30),
            )
            response_time = time.time() - start_time

            if response.status_code == 200:
                generated_text = self._extract_text(response.json(), prompt).strip()
                if not generated_text:
                    generated_text = (
                        "I apologize, but I couldn't generate a proper response. Please try rephrasing your question."
                    )
                return {
                    "content": generated_text,
                    "provider": "huggingface",
                    "model": self.model_name,
                    "response_time": response_time,
                    "success": True,
                    "usage": self._count_usage(prompt, generated_text),
                }

            error_msg = f"HF API error {response.status_code}: {response.text}"
            logger.error(error_msg)
            logger.error(f"Failed API request to: {self.api_url}")
            logger.error(f"Request payload: {payload}")
            return self._fallback_result(prompt, response_time, error_msg)
        except Exception as e:
            response_time = time.time() - start_time
            error_msg = f"HF LLM service error: {str(e)}"
            logger.error(error_msg)
            return self._fallback_result(prompt, response_time, error_msg)

    def _get_fallback_response(self, prompt: str) -> str:
        """
        Generate a fallback response when the API is unavailable.

        Routes on simple keyword matching against the lowercased prompt;
        order matters (e.g. "leave" hits the vacation branch before "sick"
        is ever checked).

        Args:
            prompt: The original prompt

        Returns:
            A helpful fallback response
        """
        prompt_lower = prompt.lower()
        # Check for common policy topics.
        if any(word in prompt_lower for word in ["vacation", "time off", "pto", "leave"]):
            return (
                "Based on company policy documents, employees typically accrue vacation time "
                "based on their length of service. New employees usually start with 2-3 weeks of "
                "vacation per year, with additional time earned based on tenure. Please consult your "
                "employee handbook or HR department for specific details about your vacation accrual "
                "rate and current balance."
            )
        elif any(word in prompt_lower for word in ["sick", "medical", "health"]):
            return (
                "Our company provides sick leave benefits as required by law and company policy. "
                "Sick time can be used for your own illness or to care for qualifying family members. "
                "Please refer to your employee handbook for specific details about sick leave accrual, "
                "usage policies, and any required documentation."
            )
        elif any(word in prompt_lower for word in ["benefits", "insurance", "health plan"]):
            return (
                "The company offers a comprehensive benefits package including health insurance, dental, "
                "vision, and other benefits. Enrollment periods and benefit details are outlined in your "
                "employee handbook. Contact HR or check the employee portal for current benefit options "
                "and enrollment information."
            )
        elif any(word in prompt_lower for word in ["remote", "work from home", "wfh"]):
            return (
                "Remote work policies vary by role and department. Many positions offer flexible work "
                "arrangements including hybrid or full remote options. Please consult your manager and HR "
                "about remote work eligibility for your specific position and any required approvals or equipment."
            )
        else:
            return (
                "I apologize, but I'm unable to provide a specific answer to your question at the "
                "moment due to technical limitations. For accurate information about company policies, "
                "please consult your employee handbook, contact HR directly, or check the employee portal. "
                "Your question is important and HR will be able to provide you with the most current "
                "and relevant policy information."
            )

    def chat_completion(self, messages: List[Dict[str, str]], **kwargs: Any) -> str:
        """
        Process a conversation with multiple messages.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            **kwargs: Additional parameters forwarded to generate_response

        Returns:
            Generated response string
        """
        # Flatten the chat transcript into a single plain-text prompt; the
        # text-generation endpoint takes raw text, not structured messages.
        # Messages with unrecognized roles are skipped.
        role_labels = {"user": "User", "assistant": "Assistant", "system": "System"}
        prompt = ""
        for message in messages:
            label = role_labels.get(message.get("role", "user"))
            if label is not None:
                prompt += f"{label}: {message.get('content', '')}\n"
        prompt += "Assistant: "
        response = self.generate_response(prompt, **kwargs)
        return response.get(
            "content",
            "I apologize, but I'm unable to generate a response at the moment.",
        )

    def health_check(self) -> bool:
        """
        Check if the HF LLM service is operational.

        Returns:
            True if service is healthy, False otherwise
        """
        try:
            # Simple test with minimal prompt and a short timeout.
            test_response = self.generate_response("Hello", max_tokens=10, timeout=5)
            return test_response.get("success", False)
        except Exception as e:
            logger.error(f"HF LLM health check failed: {e}")
            return False