import gc
import logging
import os
import time
import warnings
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

# Silence noisy library warnings (tokenizer/deprecation chatter) in Spaces logs.
warnings.filterwarnings("ignore", category=UserWarning)


class GemmaInferenceClient:
    """
    Ultra-simplified inference client optimized for HuggingFace Spaces.

    Focuses on reliability over advanced features: a ranked list of small
    models is tried in order until one loads, and generation failures are
    reported in the result dict rather than raised to the caller.
    """

    def __init__(self, model_name: str = None):
        """Initialize with the most reliable model configuration.

        Args:
            model_name: Optional model id to try first. Falls back to the
                built-in list of small, reliable models if it fails to load.

        Raises:
            RuntimeError: if every candidate model fails to load.
        """
        # Prioritize small, reliable models (ordered by download size).
        self.available_models = [
            "microsoft/DialoGPT-small",  # 117MB - very reliable
            "distilgpt2",                # 353MB - stable GPT-2 variant
            "gpt2",                      # 548MB - original GPT-2
        ]

        # Try Gemma only if we can (probably) access the gated weights.
        if self._check_gemma_access():
            self.available_models.insert(0, "google/gemma-3-1b-it")

        # Bug fix: the caller-supplied model_name used to be silently
        # ignored; honor it by trying it before all fallbacks.
        if model_name:
            self.available_models.insert(0, model_name)

        self.model_name = None
        self.tokenizer = None
        self.model = None
        self.pipeline = None

        # Initialize the best available model (raises if all fail).
        self._initialize_best_model()

    def _check_gemma_access(self) -> bool:
        """Return True if we are likely able to download gated Gemma weights.

        Logs in with an HF token when one is present; inside a HF Space
        (SPACE_ID set) access is assumed. NOTE(review): a successful login
        does not actually prove Gemma access — the real gate is the load
        attempt in _initialize_best_model.
        """
        try:
            from huggingface_hub import login

            hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN')
            if hf_token:
                login(token=hf_token)
                return True
            if os.getenv('SPACE_ID'):
                return True  # May have access in HF Spaces
        except Exception:  # was a bare `except:` — never swallow SystemExit/KeyboardInterrupt
            logger.debug("Gemma access check failed", exc_info=True)
        return False

    def _initialize_best_model(self):
        """Try candidate models in order; raise RuntimeError if none load."""
        for model_name in self.available_models:
            try:
                logger.info(f"🚀 Trying model: {model_name}")
                self._load_simple_model(model_name)
                self.model_name = model_name
                logger.info(f"✅ Successfully loaded: {model_name}")
                return
            except Exception as e:
                logger.warning(f"⚠️ Model {model_name} failed: {str(e)[:100]}")
                self._cleanup_failed_model()
        raise RuntimeError("❌ All models failed to load")

    def _cleanup_failed_model(self):
        """Release a partially loaded model/tokenizer/pipeline and reclaim memory.

        Uses getattr so it is safe to call from __del__ even when __init__
        failed before these attributes were assigned.
        """
        for attr in ("model", "tokenizer", "pipeline"):
            if getattr(self, attr, None) is not None:
                setattr(self, attr, None)

        # Force memory cleanup.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def _load_simple_model(self, model_name: str):
        """Load tokenizer, model and pipeline with an ultra-simple configuration.

        Propagates whatever the underlying transformers loaders raise; the
        caller (_initialize_best_model) treats any exception as "try the
        next model".
        """
        # Load tokenizer with minimal config.
        # NOTE(review): trust_remote_code=True executes code shipped in the
        # model repo; acceptable here only because the model list is
        # hard-coded to well-known repos — confirm before adding user input.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            trust_remote_code=True
        )

        # Ensure we have a pad token (the GPT-2 family ships without one).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model with absolute minimal config.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # FP32 for maximum stability
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )

        # Create pipeline without any device specification; let transformers
        # handle device placement automatically.
        if "gemma" in model_name.lower():
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                return_full_text=False
            )
            return

        # For DialoGPT and GPT-2 models, try conversational first.
        # NOTE(review): the "conversational" task was removed in
        # transformers >= 4.42; on recent versions this always falls
        # through to plain text-generation.
        try:
            self.pipeline = pipeline(
                "conversational",
                model=self.model,
                tokenizer=self.tokenizer
            )
        except Exception:
            # Fallback to text generation if conversational is unavailable.
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                return_full_text=False
            )

    def generate_response(
        self,
        query: str,
        context: str,
        temperature: float = 0.3,
        max_tokens: int = 128,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate an answer for *query* grounded in retrieved *context*.

        Never raises: on failure returns a dict with success=False and a
        generic apology, mirroring the success payload's keys.

        Returns:
            Dict with keys: response, generation_time, model, success
            (plus "error" on failure).
        """
        start_time = time.time()
        try:
            # Create the prompt format appropriate for the loaded model.
            if "gemma" in self.model_name.lower():
                prompt = self._create_gemma_prompt(query, context)
            else:
                prompt = self._create_simple_prompt(query, context)

            # Dispatch on the pipeline task that actually got built.
            if hasattr(self.pipeline, 'task') and self.pipeline.task == "conversational":
                response = self._generate_conversational(prompt, max_tokens)
            else:
                response = self._generate_text(prompt, temperature, max_tokens)

            # Clean and validate response.
            response = self._clean_response(response)

            generation_time = time.time() - start_time
            return {
                "response": response,
                "generation_time": generation_time,
                "model": self.model_name,
                "success": True
            }
        except Exception as e:
            logger.error(f"❌ Generation error: {e}")
            return {
                "response": "I apologize, but I encountered an error. Please try rephrasing your question.",
                "generation_time": time.time() - start_time,
                "model": self.model_name,
                "error": str(e),
                "success": False
            }

    def _create_gemma_prompt(self, query: str, context: str) -> str:
        """Create a Gemma-optimized prompt (context capped at 1200 chars)."""
        return f"""Based on the following context, answer the question concisely and accurately.

Context: {context[:1200]}

Question: {query}

Answer:"""

    def _create_simple_prompt(self, query: str, context: str) -> str:
        """Create a minimal prompt for GPT-2-family models (context capped at 800 chars)."""
        return f"Context: {context[:800]}\n\nQuestion: {query}\n\nAnswer:"

    def _generate_conversational(self, prompt: str, max_tokens: int) -> str:
        """Generate via the conversational pipeline; empty string if no reply."""
        # NOTE(review): Conversation was removed in transformers >= 4.42;
        # this path only runs when the conversational pipeline built at all.
        from transformers import Conversation

        conversation = Conversation(prompt)
        result = self.pipeline(conversation, max_length=min(max_tokens + 50, 200))
        return result.generated_responses[-1] if result.generated_responses else ""

    def _generate_text(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Generate via the text-generation pipeline; empty string on no output."""
        outputs = self.pipeline(
            prompt,
            max_new_tokens=min(max_tokens, 100),  # hard cap for Space CPU budgets
            temperature=temperature,
            do_sample=temperature > 0,  # greedy decoding when temperature == 0
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            num_return_sequences=1,
            clean_up_tokenization_spaces=True
        )
        return outputs[0]["generated_text"] if outputs else ""

    def _clean_response(self, text: str) -> str:
        """Strip prompt artifacts, dedupe sentences, and cap at 3 sentences."""
        if not text or not text.strip():
            return "I couldn't provide a specific answer based on the available information."

        # Remove prompt artifacts.
        text = text.strip()

        # Remove common leading prefixes the model may echo back.
        prefixes = ["Answer:", "Response:", "Output:", "A:", "Question:", "Context:"]
        for prefix in prefixes:
            if text.startswith(prefix):
                text = text[len(prefix):].strip()

        # Basic deduplication. The naive '.'-split also splits on
        # abbreviations ("e.g.") — accepted trade-off for short answers.
        sentences = [s.strip() for s in text.split('.') if s.strip()]
        unique_sentences = []
        for sentence in sentences[:3]:  # Limit to 3 sentences max
            if sentence and sentence not in unique_sentences:
                unique_sentences.append(sentence)

        result = '. '.join(unique_sentences)
        if result and not result.endswith('.'):
            result += '.'

        return result if result else "I couldn't generate a complete response."

    def get_model_info(self) -> Dict[str, Any]:
        """Return a summary of the loaded model for health/debug endpoints."""
        return {
            "model_name": self.model_name,
            "available_models": self.available_models,
            "loaded": self.model is not None,
            "pipeline_task": getattr(self.pipeline, 'task', 'unknown') if self.pipeline else None
        }

    def clear_cache(self):
        """Clear the CUDA cache (if any) and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("🧹 Memory cache cleared")

    def __del__(self):
        """Best-effort cleanup on destruction; a finalizer must never raise."""
        try:
            self._cleanup_failed_model()
        except Exception:
            pass