# securedocai/inference_client.py
import os
import logging
from typing import Dict, Any, Optional
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import warnings
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=UserWarning)
class GemmaInferenceClient:
    """
    Ultra-simplified inference client optimized for HuggingFace Spaces.

    Tries a prioritized list of small causal-LM checkpoints and keeps the
    first one that loads successfully, preferring reliability over
    advanced features.
    """

    def __init__(self, model_name: str = None):
        """Initialize with the most reliable model configuration.

        Args:
            model_name: Optional checkpoint id to try first. If None, the
                built-in priority list is used as-is. (Previously this
                parameter was accepted but silently ignored.)
        """
        # Prioritize small, reliable models (smallest first).
        self.available_models = [
            "microsoft/DialoGPT-small",  # 117MB - very reliable
            "distilgpt2",                # 353MB - stable GPT-2 variant
            "gpt2",                      # 548MB - original GPT-2
        ]
        # Try Gemma only if we appear to have access (it is a gated model).
        if self._check_gemma_access():
            self.available_models.insert(0, "google/gemma-3-1b-it")
        # Honor an explicit caller preference by trying it first.
        if model_name:
            if model_name in self.available_models:
                self.available_models.remove(model_name)
            self.available_models.insert(0, model_name)
        self.model_name = None   # id of the successfully loaded checkpoint
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        # Initialize the best available model
        self._initialize_best_model()

    def _check_gemma_access(self) -> bool:
        """Return True if an HF token (or the Spaces runtime) suggests Gemma access.

        Best-effort: any failure (missing huggingface_hub, bad token, ...)
        is treated as "no access" rather than raised.
        """
        try:
            from huggingface_hub import login
            hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN')
            if hf_token:
                login(token=hf_token)
                return True
            elif os.getenv('SPACE_ID'):
                return True  # May have access in HF Spaces
        except Exception:
            # Deliberately broad (but not bare): this is an access probe,
            # never a hard failure.
            pass
        return False

    def _initialize_best_model(self):
        """Try models in priority order until one loads.

        Raises:
            RuntimeError: if every candidate model fails to load.
        """
        for model_name in self.available_models:
            try:
                logger.info(f"🚀 Trying model: {model_name}")
                self._load_simple_model(model_name)
                self.model_name = model_name
                logger.info(f"✅ Successfully loaded: {model_name}")
                return
            except Exception as e:
                # Truncate the error so one bad model doesn't flood logs.
                logger.warning(f"⚠️ Model {model_name} failed: {str(e)[:100]}")
                self._cleanup_failed_model()
                continue
        raise RuntimeError("❌ All models failed to load")

    def _cleanup_failed_model(self):
        """Release model/tokenizer/pipeline references and force GC.

        Uses getattr so it is safe to call even if __init__ aborted before
        the attributes were assigned (e.g. from __del__).
        """
        if getattr(self, 'model', None) is not None:
            del self.model
            self.model = None
        if getattr(self, 'tokenizer', None) is not None:
            del self.tokenizer
            self.tokenizer = None
        if getattr(self, 'pipeline', None) is not None:
            del self.pipeline
            self.pipeline = None
        # Force memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        import gc
        gc.collect()

    def _load_simple_model(self, model_name: str):
        """Load tokenizer, model, and pipeline with an ultra-simple configuration.

        Args:
            model_name: HuggingFace checkpoint id.

        Raises:
            Exception: propagated from transformers if the download/load fails.
        """
        # Load tokenizer with minimal config
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            trust_remote_code=True
        )
        # Ensure we have required tokens (GPT-2 family has no pad token).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Load model with absolute minimal config
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # Use FP32 for maximum stability
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        # Create pipeline without any device specifications;
        # let transformers handle device placement automatically.
        if "gemma" in model_name.lower():
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                return_full_text=False
            )
        else:
            # For DialoGPT and GPT-2 models, try conversational first.
            # NOTE(review): the "conversational" task was removed in
            # transformers >= 4.42, so on recent versions this always
            # takes the text-generation fallback — confirm pinned version.
            try:
                self.pipeline = pipeline(
                    "conversational",
                    model=self.model,
                    tokenizer=self.tokenizer
                )
            except Exception:
                # Fallback to text generation if conversational fails
                self.pipeline = pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    return_full_text=False
                )

    def generate_response(
        self,
        query: str,
        context: str,
        temperature: float = 0.3,
        max_tokens: int = 128,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate an answer to `query` grounded in `context`.

        Args:
            query: User question.
            context: Retrieved document text (truncated inside the prompt).
            temperature: Sampling temperature; 0 disables sampling.
            max_tokens: Upper bound on new tokens (further capped internally).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Dict with keys: response, generation_time, model, success,
            and (on failure) error. Never raises — errors are reported
            in the returned dict.
        """
        start_time = time.time()
        try:
            # Create appropriate prompt for the model
            if "gemma" in self.model_name.lower():
                prompt = self._create_gemma_prompt(query, context)
            else:
                prompt = self._create_simple_prompt(query, context)
            # Generate with the appropriate pipeline
            if hasattr(self.pipeline, 'task') and self.pipeline.task == "conversational":
                response = self._generate_conversational(prompt, max_tokens)
            else:
                response = self._generate_text(prompt, temperature, max_tokens)
            # Clean and validate response
            response = self._clean_response(response)
            generation_time = time.time() - start_time
            return {
                "response": response,
                "generation_time": generation_time,
                "model": self.model_name,
                "success": True
            }
        except Exception as e:
            logger.error(f"❌ Generation error: {e}")
            return {
                "response": "I apologize, but I encountered an error. Please try rephrasing your question.",
                "generation_time": time.time() - start_time,
                "model": self.model_name,
                "error": str(e),
                "success": False
            }

    def _create_gemma_prompt(self, query: str, context: str) -> str:
        """Create a Gemma-optimized prompt; context is capped at 1200 chars."""
        return f"""Based on the following context, answer the question concisely and accurately.
Context: {context[:1200]}
Question: {query}
Answer:"""

    def _create_simple_prompt(self, query: str, context: str) -> str:
        """Create a simple prompt for GPT-2-family models; context capped at 800 chars."""
        return f"Context: {context[:800]}\n\nQuestion: {query}\n\nAnswer:"

    def _generate_conversational(self, prompt: str, max_tokens: int) -> str:
        """Generate using the conversational pipeline.

        NOTE(review): `transformers.Conversation` no longer exists in
        transformers >= 4.42; this path only works on older versions.
        """
        from transformers import Conversation
        conversation = Conversation(prompt)
        result = self.pipeline(conversation, max_length=min(max_tokens + 50, 200))
        return result.generated_responses[-1] if result.generated_responses else ""

    def _generate_text(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Generate using the text-generation pipeline; returns only new text."""
        outputs = self.pipeline(
            prompt,
            max_new_tokens=min(max_tokens, 100),  # hard cap for Spaces CPU
            temperature=temperature,
            do_sample=temperature > 0,  # greedy decode when temperature == 0
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            num_return_sequences=1,
            clean_up_tokenization_spaces=True
        )
        return outputs[0]["generated_text"] if outputs else ""

    def _clean_response(self, text: str) -> str:
        """Strip prompt artifacts, deduplicate sentences, and cap at 3 sentences.

        Returns a fallback message when the model produced nothing usable.
        """
        if not text or not text.strip():
            return "I couldn't provide a specific answer based on the available information."
        # Remove prompt artifacts
        text = text.strip()
        # Remove common prefixes the model may echo back.
        prefixes = ["Answer:", "Response:", "Output:", "A:", "Question:", "Context:"]
        for prefix in prefixes:
            if text.startswith(prefix):
                text = text[len(prefix):].strip()
        # Basic deduplication: small models often repeat themselves.
        sentences = [s.strip() for s in text.split('.') if s.strip()]
        unique_sentences = []
        for sentence in sentences[:3]:  # Limit to 3 sentences max
            if sentence and sentence not in unique_sentences:
                unique_sentences.append(sentence)
        result = '. '.join(unique_sentences)
        if result and not result.endswith('.'):
            result += '.'
        return result if result else "I couldn't generate a complete response."

    def get_model_info(self) -> Dict[str, Any]:
        """Return a summary of the currently loaded model for diagnostics."""
        return {
            "model_name": self.model_name,
            "available_models": self.available_models,
            "loaded": self.model is not None,
            "pipeline_task": getattr(self.pipeline, 'task', 'unknown') if self.pipeline else None
        }

    def clear_cache(self):
        """Clear GPU cache (if any) and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        import gc
        gc.collect()
        logger.info("🧹 Memory cache cleared")

    def __del__(self):
        """Best-effort cleanup when the object is destroyed.

        Guarded because during interpreter shutdown (or after a failed
        __init__) referenced modules/attributes may already be gone.
        """
        try:
            self._cleanup_failed_model()
        except Exception:
            pass