Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Ultra-Low Latency Multilingual Caption Simplification API with Phi-2 | |
| Optimized for <1 second response time on Hugging Face Spaces | |
| """ | |
| import os | |
| import time | |
| import gc | |
| import re | |
| import json | |
| import torch | |
| from typing import Dict, List, Optional, Tuple | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, ConfigDict | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import uvicorn | |
| from contextlib import asynccontextmanager | |
| import logging | |
| import tempfile | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create a temporary directory for the Hugging Face download cache.
# NOTE(review): `transformers` is imported before these variables are set, so
# its import-time cache defaults may not pick them up; the loader passes
# `cache_dir=temp_dir` explicitly, which is what routes downloads here.
temp_dir = tempfile.mkdtemp()
os.environ["TRANSFORMERS_CACHE"] = temp_dir
os.environ["HF_HOME"] = temp_dir

# Global variables for model caching (filled lazily on first model use).
model = None
tokenizer = None

# --- Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "microsoft/phi-2"
logger.info(f"Running on device: {DEVICE}")

# Pre-compiled regex patterns for speed.
# English fillers are plain ASCII, so `\b` word boundaries behave as expected.
ENGLISH_FILLERS_PATTERN = re.compile(r'\b(?:um|uh|er|ah|like|you know|so|well|basically|actually|literally|sort of|kind of)\b[,\.\?!]*\s*', re.IGNORECASE)
# Hindi fillers cannot use `\b`: Devanagari vowel signs (matras such as "ा",
# "ी", "ो") are combining marks, which Python's `re` does not treat as word
# characters.  With `\b`, "की" matched the start of "कीमत" and "वो" the start
# of "वोट", while standalone fillers ending in a matra ("वो", "की", "हाँ", ...)
# never matched at all.  The look-arounds below treat the whole Devanagari
# block (except the danda punctuation U+0964/U+0965) as word characters.
# NOTE(review): several listed words (की, वो, तो, ना, जी) are also ordinary
# grammatical words in Hindi and will now actually be removed - review the list.
HINDI_FILLERS_PATTERN = re.compile(
    r'(?<![\w\u0900-\u0963\u0966-\u097F])'
    r'(?:उम|उह|मतलब|आप समझे होंगे|वो|की|तोह|जी|ना|तो|हाँ|मैंने सोचा|अरे|वैसे)'
    r'(?![\w\u0900-\u0963\u0966-\u097F])[,\.\?!]*\s*',
    re.IGNORECASE,
)
SPACE_PATTERN = re.compile(r'\s+')
PUNCT_PATTERN = re.compile(r'\s+([,.?!])')
JSON_PATTERN = re.compile(r'\{.*\}', re.DOTALL)
# Request/Response models with protected namespace configuration
# Request body for simplifying a single caption string.
class TextInput(BaseModel):
    text: str  # caption text to simplify
# One timed caption segment as submitted by the client.
class CaptionSegment(BaseModel):
    timestamp_start: float  # segment start; units not fixed here (presumably seconds - confirm)
    timestamp_end: float  # segment end, same units as timestamp_start
    original_text: str  # caption text for this segment
# Request body for batch simplification of timed caption segments.
class CaptionSegmentsInput(BaseModel):
    segments: List[CaptionSegment]  # segments to simplify, handled in list order
# One processed caption segment: the input fields plus the simplified text
# and a short same-language meaning.
class SimplifiedSegment(BaseModel):
    timestamp_start: float  # copied from the input segment
    timestamp_end: float  # copied from the input segment
    original_text: str  # caption text as received
    simplified_text: str  # filler-free / simplified version of original_text
    meaning: str  # short paraphrase; may be empty when the model is unavailable
# Response body for single-text simplification.
class SimplifyResponse(BaseModel):
    # Disable pydantic's protected "model_" namespace so the `model_used`
    # field name is accepted without a warning.
    model_config = ConfigDict(protected_namespaces=())
    simplified_text: str  # simplified output text
    language: str  # detected language: "English", "Hindi" or "Hinglish"
    latency_ms: float  # processing time in milliseconds
    model_used: str  # "Phi-2" or "Rule-based"
# Response body for batch segment simplification.
class SimplifySegmentsResponse(BaseModel):
    # Allow the `model_used` field name despite pydantic's "model_" namespace.
    model_config = ConfigDict(protected_namespaces=())
    segments: List[SimplifiedSegment]  # one entry per input segment, same order
    total_processing_time_ms: float  # time for the whole batch, in milliseconds
    model_used: str  # "Phi-2" or "Rule-based"
# Shape of the health-check payload.
class HealthResponse(BaseModel):
    # Allow the `model_loaded` field name despite pydantic's "model_" namespace.
    model_config = ConfigDict(protected_namespaces=())
    status: str  # service status string
    model_loaded: str  # "Phi-2" when the model is cached in memory, else "None"
    device: str  # torch device in use: "cuda" or "cpu"
# Application lifecycle
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: log startup and release model resources on shutdown.

    Wrapped with ``asynccontextmanager`` so FastAPI receives an async context
    manager; Starlette deprecates passing a bare async-generator function as
    ``lifespan``. The model is intentionally not loaded here; it is loaded
    lazily by the first request that needs it. Shutdown cleanup runs in
    ``finally`` so it also happens when the app exits via an error.
    """
    # Startup
    logger.info("Starting up caption simplification API with Phi-2...")
    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down...")
        cleanup_resources()
# Create the FastAPI application; `lifespan` wires in startup/shutdown handling.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Caption Simplification API",
    description="Ultra-low latency multilingual caption simplification with Phi-2",
)
async def load_phi2_model():
    """Load the Phi-2 tokenizer and model and publish them as module globals.

    The tokenizer and model are built in local variables. They are assigned
    to the ``tokenizer``/``model`` globals only after every step (download,
    device move, ``eval()``) has succeeded. Publishing ``model`` before the
    device move would leave a half-initialised model in the global after a
    failure, and ``ensure_model_loaded`` would then treat it as ready forever.

    NOTE(review): everything here runs synchronously inside the event loop,
    so the first request blocks the whole server while weights download/load.

    Returns:
        True if the model is ready for inference, False if loading failed
        (the error is logged together with its traceback).
    """
    global model, tokenizer
    logger.info("Loading Phi-2 model...")
    try:
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the model repository; consider pinning a `revision=`.
        new_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            cache_dir=temp_dir,
            trust_remote_code=True,
        )
        # Reuse EOS as the pad token when the tokenizer defines none.
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token

        # Half precision on GPU only; on CUDA, placement is delegated to
        # device_map="auto", on CPU the model is moved explicitly below.
        new_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            cache_dir=temp_dir,
        )
        if DEVICE != "cuda":
            new_model = new_model.to(DEVICE)
        # Inference mode (disables dropout etc.).
        new_model.eval()
        if DEVICE == "cuda":
            new_model.config.use_cache = True
    except Exception as e:
        logger.exception(f"Error loading Phi-2 model: {str(e)}")
        return False

    # Publish only fully-initialised objects.
    tokenizer = new_tokenizer
    model = new_model
    logger.info("Phi-2 model loaded successfully!")
    return True
def cleanup_resources():
    """Drop the cached model/tokenizer and free GPU memory.

    Resets the ``model`` and ``tokenizer`` globals to ``None`` rather than
    ``del``-ing them. ``del`` on a ``global`` name removes the module
    attribute itself, so any later read (e.g. the health check's
    ``model is not None``) would raise ``NameError``. Safe to call repeatedly.
    """
    global model, tokenizer
    model = None
    tokenizer = None
    # Collect first so the dropped tensors are actually released before the
    # CUDA caching allocator is asked to return unused blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def detect_language(text: str) -> str:
    """Classify *text* as "English", "Hindi" or "Hinglish" using cheap heuristics.

    Text without any Devanagari character is "English". Otherwise the share
    of ASCII-letter words among the whitespace-separated tokens decides:
    more than 40% gives "Hinglish", anything else "Hindi".
    """
    if re.search(r'[\u0900-\u097F]', text) is None:
        return "English"

    token_count = len(text.split())
    latin_word_count = len(re.findall(r'\b[a-zA-Z]+\b', text))
    latin_share = latin_word_count / token_count if token_count > 0 else 0
    return "Hinglish" if latin_share > 0.4 else "Hindi"
def remove_fillers_fast(text: str) -> str:
    """Strip English and Hindi filler words, then tidy whitespace and punctuation.

    Each pre-compiled filler pattern replaces a match, plus the trailing
    punctuation/whitespace it consumes, with one space. Whitespace runs are
    then collapsed, the ends trimmed, and spaces before ``, . ? !`` removed.
    """
    for filler_pattern in (ENGLISH_FILLERS_PATTERN, HINDI_FILLERS_PATTERN):
        text = filler_pattern.sub(' ', text)
    collapsed = SPACE_PATTERN.sub(' ', text).strip()
    return PUNCT_PATTERN.sub(r'\1', collapsed)
def create_simplify_prompt(text: str, language: str) -> str:
    """Build the few-shot Phi-2 prompt asking for a simplified version of *text*.

    The prompt is a fixed instruction/example preamble, followed by the
    quoted input and a trailing ``Output:`` cue for the model to complete.
    *language* is accepted for symmetry with the other helpers but is unused.
    """
    preamble = """You are a text simplification assistant. Your task is to simplify the given text while preserving the original language.
Instructions:
1. Remove filler words (um, uh, er, ah, like, you know, so, well, उम, उह, मतलब, etc.)
2. Simplify complex vocabulary and sentence structure
3. Preserve the original language completely
4. Keep the meaning intact
Example:
Input: "The project, um, needs, uh, more time."
Output: "The project needs more time."
Input: "मेरा प्रोजेक्ट, उम, अच्छा चल रहा है।"
Output: "मेरा प्रोजेक्ट अच्छा चल रहा है।"
Now process this text:
"""
    return preamble + f'"{text}"\nOutput:'
def create_meaning_prompt(text: str, language: str) -> str:
    """Build the few-shot Phi-2 prompt asking for a short meaning of *text*.

    A fixed instruction/example preamble is followed by the quoted input and
    a trailing ``Output:`` cue. *language* is accepted but not used.
    """
    preamble = """You are a language assistant. Your task is to provide a simple meaning or explanation for the given text in the same language.
Instructions:
1. Provide a simple meaning or explanation
2. Keep it concise (under 15 words)
3. Preserve the original language
Example:
Input: "The project needs more time."
Output: "The project requires additional time to complete."
Input: "मेरा प्रोजेक्ट अच्छा चल रहा है।"
Output: "मेरा प्रोजेक्ट सफलतापूर्वक आगे बढ़ रहा है।"
Now process this text:
"""
    return preamble + f'"{text}"\nOutput:'
def extract_response_from_text(text: str) -> str:
    """Pull the model's answer out of its raw generated continuation.

    Steps:
      1. Drop everything from the first ``Input:`` onward. The few-shot
         prompts end with ``Output:``, and a model that keeps generating
         tends to invent a new ``Input: ... Output: ...`` example. Taking the
         text after the *last* ``Output:`` would return that invented answer.
      2. Take the first non-empty ``Output:``-delimited segment. This handles
         both an echoed leading ``Output:`` label and stray extra labels.
      3. If the result is a JSON object, use its ``simplified_text`` or
         ``meaning`` value.
      4. Remove wrapping quotes/brackets only. Interior apostrophes are kept,
         so "don't" stays intact.

    Args:
        text: Decoded generated tokens (prompt excluded).

    Returns:
        The cleaned answer, or ``""`` if nothing usable was generated.
    """
    response = text.strip()

    response = response.split('Input:', 1)[0]
    segments = (segment.strip() for segment in response.split('Output:'))
    response = next((segment for segment in segments if segment), '')

    if response.startswith('{') and response.endswith('}'):
        try:
            json_data = json.loads(response)
        except json.JSONDecodeError:
            json_data = None  # brace-wrapped but not JSON: keep it as plain text
        if isinstance(json_data, dict):
            if 'simplified_text' in json_data:
                response = str(json_data['simplified_text'])
            elif 'meaning' in json_data:
                response = str(json_data['meaning'])

    response = response.strip().strip('"[]').strip()
    # Single quotes are removed only when they wrap the whole answer, so
    # contractions and possessives inside the text survive.
    if len(response) >= 2 and response[0] == response[-1] == "'":
        response = response[1:-1].strip()
    return response
async def ensure_model_loaded():
    """Return True once the Phi-2 model is available, loading it on first use.

    When ``model`` is already populated this is a cheap check. Otherwise the
    result of ``load_phi2_model()`` (True on success, False on failure) is
    passed through.
    """
    if model is not None:
        return True
    return await load_phi2_model()
async def simplify_text_async(text: str) -> Tuple[str, float]:
    """Simplify one caption: regex filler removal, plus a Phi-2 rewrite for English.

    Hindi and Hinglish skip the model to avoid it switching language. A model
    that fails to load, or an empty generation, falls back to the rule-based
    text.

    NOTE(review): tokenization and ``generate`` run synchronously inside the
    event loop, so concurrent requests are served one at a time.

    Args:
        text: Raw caption text.

    Returns:
        ``(simplified_text, latency_ms)``. Latency covers the whole call,
        including a first-call model load.
    """
    # perf_counter is monotonic, so latency cannot jump if the wall clock moves.
    start_time = time.perf_counter()

    language = detect_language(text)
    text_without_fillers = remove_fillers_fast(text)

    # For Hindi and Hinglish, use rule-based approach to avoid language change
    if language in ["Hindi", "Hinglish"]:
        return text_without_fillers, (time.perf_counter() - start_time) * 1000

    model_loaded = await ensure_model_loaded()
    if not model_loaded or model is None:
        logger.warning("Model not available, using rule-based result")
        return text_without_fillers, (time.perf_counter() - start_time) * 1000

    prompt = create_simplify_prompt(text_without_fillers, language)

    # Keep the *tail* of an over-long prompt. Tokenizer-side truncation cuts
    # from the right, which removes the caption and the trailing "Output:"
    # cue first - exactly the part the model must see.
    # NOTE(review): the few-shot preamble contains Devanagari examples that
    # byte-level BPE splits into many tokens; 512 leaves headroom over the
    # previous 256 - confirm against real token counts.
    max_prompt_tokens = 512
    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {
        name: tensor[:, -max_prompt_tokens:].to(DEVICE)
        for name, tensor in encoded.items()
    }
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            min_new_tokens=5,
            # Greedy decoding. temperature/top_p were dropped: they are
            # ignored when do_sample=False and only cause config warnings.
            do_sample=False,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    simplified_text = extract_response_from_text(generated_text)

    # If no result, fall back to rule-based
    if not simplified_text:
        simplified_text = text_without_fillers

    return simplified_text, (time.perf_counter() - start_time) * 1000
async def get_meaning_async(text: str) -> Tuple[str, float]:
    """Ask Phi-2 for a short same-language paraphrase ("meaning") of *text*.

    There is no rule-based fallback: if the model cannot be loaded, an empty
    string is returned.

    NOTE(review): Hindi/Hinglish input is sent to the model here, whereas
    simplify_text_async bypasses the model for those languages to avoid
    language drift - confirm this asymmetry is intended.

    Args:
        text: Text to explain (callers pass already-simplified text).

    Returns:
        ``(meaning, latency_ms)``. ``meaning`` may be empty.
    """
    start_time = time.perf_counter()

    language = detect_language(text)

    model_loaded = await ensure_model_loaded()
    if not model_loaded or model is None:
        logger.warning("Model not available, cannot generate meaning")
        return "", (time.perf_counter() - start_time) * 1000

    prompt = create_meaning_prompt(text, language)

    # Keep the *tail* of an over-long prompt so the input text and the final
    # "Output:" cue survive. Right-side truncation would drop them first.
    # NOTE(review): 512 is headroom for the Devanagari-heavy preamble; confirm.
    max_prompt_tokens = 512
    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {
        name: tensor[:, -max_prompt_tokens:].to(DEVICE)
        for name, tensor in encoded.items()
    }
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            min_new_tokens=5,
            # Greedy decoding; sampling knobs are ignored with do_sample=False.
            do_sample=False,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    meaning = extract_response_from_text(generated_text)
    return meaning, (time.perf_counter() - start_time) * 1000
# API Routes
async def root():
    """Return a static banner naming the API and its version.

    NOTE(review): no route decorator is visible for this handler in the
    pasted source; presumably it was registered as a GET on "/" - confirm.
    """
    banner = {"message": "Caption Simplification API with Phi-2"}
    banner["version"] = "1.0.0"
    return banner
async def health_check():
    """Report liveness, whether Phi-2 is cached in memory, and the torch device.

    NOTE(review): no route decorator is visible in the pasted source;
    presumably registered as a GET route using HealthResponse - confirm.
    """
    loaded_name = "None" if model is None else "Phi-2"
    return dict(status="healthy", model_loaded=loaded_name, device=DEVICE)
async def simplify_text(input_data: TextInput):
    """Simplify a single caption text and report which engine produced it.

    NOTE(review): no route decorator is visible for this handler in the
    pasted source; presumably it was registered as a POST route returning
    SimplifyResponse - confirm.

    Returns:
        SimplifyResponse with the simplified text, detected language, latency
        in ms, and ``model_used`` ("Phi-2" or "Rule-based").

    Raises:
        HTTPException: 500 on any processing error (original error chained).
    """
    try:
        simplified_text, latency = await simplify_text_async(input_data.text)
        language = detect_language(input_data.text)
        # simplify_text_async never sends Hindi/Hinglish to the model, so do
        # not report "Phi-2" merely because the model happens to be loaded.
        used_model = model is not None and language not in ["Hindi", "Hinglish"]
        return SimplifyResponse(
            simplified_text=simplified_text,
            language=language,
            latency_ms=latency,
            model_used="Phi-2" if used_model else "Rule-based",
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error processing text: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def simplify_segments(input_data: CaptionSegmentsInput):
    """Simplify a batch of timed caption segments and attach a meaning to each.

    Segments are processed sequentially in request order. Timestamps and the
    original text are copied through unchanged.

    NOTE(review): no route decorator is visible for this handler in the
    pasted source; presumably a POST route returning
    SimplifySegmentsResponse - confirm.

    Returns:
        SimplifySegmentsResponse with one SimplifiedSegment per input segment,
        the elapsed time for the batch in ms, and "Phi-2" if the model is
        loaded when the batch finishes (otherwise "Rule-based").

    Raises:
        HTTPException: 500 if any segment fails (original error chained).
    """
    try:
        # Monotonic timer: elapsed time is immune to wall-clock adjustments.
        start_time = time.perf_counter()
        simplified_segments = []

        for segment in input_data.segments:
            simplified_text, _ = await simplify_text_async(segment.original_text)
            # The meaning is derived from the cleaned-up text, not the raw caption.
            meaning, _ = await get_meaning_async(simplified_text)
            simplified_segments.append(
                SimplifiedSegment(
                    timestamp_start=segment.timestamp_start,
                    timestamp_end=segment.timestamp_end,
                    original_text=segment.original_text,
                    simplified_text=simplified_text,
                    meaning=meaning,
                )
            )

        total_processing_time = (time.perf_counter() - start_time) * 1000
        return SimplifySegmentsResponse(
            segments=simplified_segments,
            total_processing_time_ms=total_processing_time,
            model_used="Phi-2" if model is not None else "Rule-based",
        )
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception(f"Error processing segments: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Serve on all interfaces; 7860 is presumably the Hugging Face Spaces app port.
    uvicorn.run(app, port=7860, host="0.0.0.0")