Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| from typing import Dict, Any, Optional | |
| import time | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| import warnings | |
| logger = logging.getLogger(__name__) | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
class GemmaInferenceClient:
    """Small-footprint text-generation client intended for HuggingFace Spaces.

    On construction the client walks an ordered list of candidate model ids
    and keeps the first one that loads. Generation goes through a
    ``transformers`` pipeline; the answer is post-processed into at most
    three de-duplicated sentences.
    """

    # Hard limits applied when building prompts / generation arguments.
    _GEMMA_CONTEXT_CHARS = 1200
    _SIMPLE_CONTEXT_CHARS = 800
    _MAX_NEW_TOKENS_CAP = 100

    # Labels a model tends to echo back at the start of its answer.
    _ECHOED_LABELS = ("Answer:", "Response:", "Output:", "A:", "Question:", "Context:")

    def __init__(self, model_name: Optional[str] = None):
        """Build the candidate list and load the first model that works.

        Args:
            model_name: Optional preferred model id. When given it is tried
                first; the built-in candidates remain as fallbacks.

        Raises:
            RuntimeError: If none of the candidate models can be loaded.
        """
        # Built-in fallbacks, tried in this order.
        self.available_models = [
            "microsoft/DialoGPT-small",  # small conversational model
            "distilgpt2",                # distilled GPT-2 variant
            "gpt2",                      # original GPT-2
        ]

        # Gemma is gated, so only try it when access looks plausible.
        if self._check_gemma_access():
            self.available_models.insert(0, "google/gemma-3-1b-it")

        # Honour the caller's preference (previously this argument was
        # accepted but silently ignored).
        if model_name:
            if model_name in self.available_models:
                self.available_models.remove(model_name)
            self.available_models.insert(0, model_name)

        self.model_name = None
        self.tokenizer = None
        self.model = None
        self.pipeline = None

        self._initialize_best_model()

    def _check_gemma_access(self) -> bool:
        """Return True when gated Gemma models are worth attempting.

        Logs in to the Hub when ``HF_TOKEN`` or ``HUGGINGFACE_HUB_TOKEN`` is
        set. Without a token, running inside a Space (``SPACE_ID`` set) is
        treated as "may have access". Any failure yields False.
        """
        hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN')
        try:
            from huggingface_hub import login

            if hf_token:
                login(token=hf_token)
                return True
        except Exception as exc:  # best-effort probe: never fatal
            logger.warning("Gemma access check failed: %s", exc)
            return False
        return bool(os.getenv('SPACE_ID'))

    def _initialize_best_model(self) -> None:
        """Try each candidate in order and keep the first that loads.

        Raises:
            RuntimeError: If every candidate fails; chained to the last
                underlying error so the root cause is not lost.
        """
        last_error: Optional[Exception] = None
        for candidate in self.available_models:
            try:
                logger.info(f"🚀 Trying model: {candidate}")
                self._load_simple_model(candidate)
            except Exception as exc:
                last_error = exc
                logger.warning(f"⚠️ Model {candidate} failed: {str(exc)[:100]}")
                self._cleanup_failed_model()
                continue
            self.model_name = candidate
            logger.info(f"✅ Successfully loaded: {candidate}")
            return
        raise RuntimeError("❌ All models failed to load") from last_error

    def _cleanup_failed_model(self) -> None:
        """Drop model, tokenizer and pipeline references and reclaim memory."""
        import gc

        # Rebinding to None releases the references; no truthiness test on
        # the objects themselves (a tokenizer's truthiness follows __len__).
        self.model = None
        self.tokenizer = None
        self.pipeline = None

        # Collect first so freed tensors can actually leave the CUDA cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _build_text_generation_pipeline(self):
        """Create a text-generation pipeline that returns only new text."""
        return pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            return_full_text=False,
        )

    def _load_simple_model(self, model_name: str) -> None:
        """Load tokenizer, model and pipeline for ``model_name``.

        Sets ``self.tokenizer``, ``self.model`` and ``self.pipeline``.
        Exceptions from ``transformers`` propagate to the caller.
        """
        # NOTE(security): trust_remote_code=True lets a Hub repository run
        # its own Python code at load time. Kept for backward compatibility;
        # only pass model ids you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            trust_remote_code=True,
        )

        # GPT-2 style tokenizers ship without a pad token; reuse EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # FP32 for maximum stability
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

        # No device arguments: transformers decides placement.
        if "gemma" in model_name.lower():
            self.pipeline = self._build_text_generation_pipeline()
            return

        # DialoGPT / GPT-2: prefer the conversational task and fall back to
        # plain text generation when it cannot be created.
        try:
            self.pipeline = pipeline(
                "conversational",
                model=self.model,
                tokenizer=self.tokenizer,
            )
        except Exception:
            self.pipeline = self._build_text_generation_pipeline()

    def generate_response(
        self,
        query: str,
        context: str,
        temperature: float = 0.3,
        max_tokens: int = 128,
        **kwargs
    ) -> Dict[str, Any]:
        """Answer ``query`` using ``context`` and report timing.

        Args:
            query: The user's question.
            context: Supporting text; truncated when the prompt is built.
            temperature: Sampling temperature; values <= 0 disable sampling.
            max_tokens: Requested generation budget (capped internally).
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A dict with ``response``, ``generation_time``, ``model`` and
            ``success``. On failure ``success`` is False, ``response`` is an
            apology message and ``error`` holds the exception text. This
            method does not raise.
        """
        start_time = time.time()
        try:
            if "gemma" in self.model_name.lower():
                prompt = self._create_gemma_prompt(query, context)
            else:
                prompt = self._create_simple_prompt(query, context)

            if getattr(self.pipeline, 'task', None) == "conversational":
                raw = self._generate_conversational(prompt, max_tokens)
            else:
                raw = self._generate_text(prompt, temperature, max_tokens)

            return {
                "response": self._clean_response(raw),
                "generation_time": time.time() - start_time,
                "model": self.model_name,
                "success": True
            }
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception(f"❌ Generation error: {e}")
            return {
                "response": "I apologize, but I encountered an error. Please try rephrasing your question.",
                "generation_time": time.time() - start_time,
                "model": self.model_name,
                "error": str(e),
                "success": False
            }

    def _create_gemma_prompt(self, query: str, context: str) -> str:
        """Build the instruction-style prompt used for Gemma models."""
        return "\n".join(
            (
                "Based on the following context, answer the question concisely and accurately.",
                f"Context: {context[:self._GEMMA_CONTEXT_CHARS]}",
                f"Question: {query}",
                "Answer:",
            )
        )

    def _create_simple_prompt(self, query: str, context: str) -> str:
        """Build the compact prompt used for non-Gemma models."""
        return f"Context: {context[:self._SIMPLE_CONTEXT_CHARS]}\n\nQuestion: {query}\n\nAnswer:"

    def _generate_conversational(self, prompt: str, max_tokens: int) -> str:
        """Generate a reply through the conversational pipeline.

        Returns the most recent generated response, or "" when none exists.
        """
        # Imported lazily: only needed when the conversational task loaded.
        from transformers import Conversation

        conversation = Conversation(prompt)
        result = self.pipeline(conversation, max_length=min(max_tokens + 50, 200))
        return result.generated_responses[-1] if result.generated_responses else ""

    def _generate_text(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Generate a continuation through the text-generation pipeline.

        ``max_tokens`` is clamped to the range 1.._MAX_NEW_TOKENS_CAP.
        Returns the generated text of the first sequence, or "" when the
        pipeline returns nothing.
        """
        do_sample = temperature > 0
        generation_args: Dict[str, Any] = {
            "max_new_tokens": max(1, min(max_tokens, self._MAX_NEW_TOKENS_CAP)),
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "num_return_sequences": 1,
            "clean_up_tokenization_spaces": True,
        }
        # Temperature only applies to sampling; sending a non-positive value
        # with greedy decoding triggers warnings/errors in transformers.
        if do_sample:
            generation_args["temperature"] = temperature

        outputs = self.pipeline(prompt, **generation_args)
        return outputs[0]["generated_text"] if outputs else ""

    def _clean_response(self, text: str) -> str:
        """Normalise raw model output into a short answer.

        Strips echoed labels (``Answer:``, ``Question:`` ...), keeps the
        first three sentences, removes repeated ones and guarantees closing
        punctuation. Returns a fixed fallback message for empty input or
        when nothing usable remains.
        """
        import re

        if not text or not text.strip():
            return "I couldn't provide a specific answer based on the available information."

        text = text.strip()

        # Labels can be stacked ("Answer: Response: ..."), so keep stripping
        # until a full pass removes nothing.
        removed = True
        while removed:
            removed = False
            for label in self._ECHOED_LABELS:
                if text.startswith(label):
                    text = text[len(label):].strip()
                    removed = True

        # Split after sentence-ending punctuation followed by whitespace so
        # decimals such as "3.14" stay intact.
        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

        seen = set()
        kept = []
        for sentence in sentences[:3]:  # limit to 3 sentences max
            key = sentence.rstrip(".!?").strip()
            if not key or key in seen:
                continue
            seen.add(key)
            kept.append(sentence)

        result = " ".join(kept)
        if result and not result.endswith((".", "!", "?")):
            result += "."
        return result if result else "I couldn't generate a complete response."

    def get_model_info(self) -> Dict[str, Any]:
        """Return the loaded model id, candidates, load state and pipeline task."""
        return {
            "model_name": self.model_name,
            "available_models": list(self.available_models),
            "loaded": self.model is not None,
            "pipeline_task": getattr(self.pipeline, 'task', 'unknown') if self.pipeline else None
        }

    def clear_cache(self) -> None:
        """Run garbage collection and empty the CUDA cache when available."""
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("🧹 Memory cache cleared")

    def __del__(self):
        """Release model resources when the client is garbage-collected."""
        # The instance may be partially constructed and module globals may
        # already be torn down at interpreter exit; never raise from here.
        try:
            self._cleanup_failed_model()
        except Exception:
            pass