Spaces: Runtime error
# Simple implementation for translation using the BART model
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BartTokenizer, BartForConditionalGeneration

app = FastAPI()

# Define request model
class TranslationRequest(BaseModel):
    text: str
    max_length: int = 150
    min_length: int = 40

# Download and cache the model during initialization
# This happens only once when the app starts
try:
    # Explicitly download to a specific directory with proper error handling
    cache_dir = "./model_cache"
    model_name = "facebook/bart-large-cnn"
    print(f"Loading tokenizer from {model_name}...")
    tokenizer = BartTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)
    print(f"Loading model from {model_name}...")
    model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise

# Generation endpoint; the route must be registered with a decorator
# (route path here is illustrative)
@app.post("/translate")
async def translate_text(request: TranslationRequest):
    # Process the input text
    inputs = tokenizer(request.text, return_tensors="pt", max_length=1024, truncation=True)
    # Generate summary
    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=request.max_length,
        min_length=request.min_length,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True
    )
    # Decode the generated summary
    translation = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": translation}

# Basic health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "facebook/bart-large-cnn"}