# Hugging Face Spaces status banner (scrape artifact): Sleeping
from fastapi import FastAPI, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import time
#from dotenv import load_dotenv
import os
#load_dotenv()

app = FastAPI()

# The single accepted API key comes from the environment; clients send it
# in the X-API-Key request header.
API_KEY = os.getenv("API_KEY")
if API_KEY is None:
    # BUG FIX: with no key configured, API_KEYS would be {None: "user1"} and
    # no client-supplied header could ever match — every request would get a
    # 403 with no hint why. Warn loudly at startup instead of failing silently.
    print("WARNING: API_KEY environment variable is not set; all requests will be rejected.")

# Configuration
API_KEYS = {
    API_KEY: "user1"  # In production, use a secure database
}
# Load the (relatively small) paraphrase model — sized for a Spaces container.
MODEL_NAME = "tuner007/pegasus_paraphrase"
device = "cpu"  # Spaces deployment has no GPU; pin to CPU explicitly.

print("Loading model and tokenizer...")
# Both artifacts are cached locally so container restarts skip the download.
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache").to(device)
print("Model and tokenizer loaded successfully!")
class TextRequest(BaseModel):
    """Request body for paraphrasing a single text.

    `style` selects a sampling preset; `num_variations` is how many
    paraphrases to return.
    """

    text: str
    style: Optional[str] = "standard"
    num_variations: Optional[int] = 1
class BatchRequest(BaseModel):
    """Request body for paraphrasing several texts in one call.

    The same `style` and `num_variations` apply to every entry in `texts`.
    """

    texts: List[str]
    style: Optional[str] = "standard"
    num_variations: Optional[int] = 1
async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
    """FastAPI dependency that validates the X-API-Key request header.

    Returns the key unchanged on success so dependents could look the caller
    up in API_KEYS; raises HTTPException(403) for an unknown key.
    """
    # BUG FIX: fastapi.Header takes `alias=`, not `name=`. With `name=` the
    # alias was ignored and the header was matched as "api-key" (derived from
    # the parameter name), so clients sending X-API-Key per the intended
    # contract would never have been recognized under strict matching.
    if api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
    """Generate `num_variations` paraphrases of `text` with the given style.

    Each style maps to a (temperature, top_k) sampling preset; unknown styles
    fall back to the conservative "formal" settings. Raises HTTPException(500)
    wrapping any underlying model/tokenizer error.
    """
    try:
        # Sampling preset per style; unknown styles get the conservative default.
        params = {
            "standard": {"temperature": 1.5, "top_k": 80},
            "formal": {"temperature": 1.0, "top_k": 50},
            "casual": {"temperature": 1.6, "top_k": 100},
            "creative": {"temperature": 2.8, "top_k": 170},
        }.get(style, {"temperature": 1.0, "top_k": 50})

        # BUG FIX: truncation=False let arbitrarily long inputs through to the
        # model, which fails (or silently degrades) past the encoder's maximum
        # position range. Truncate to the tokenizer's model max length instead.
        inputs = tokenizer(text, truncation=True, padding='longest', return_tensors="pt").to(device)

        # Generate paraphrases (beam sampling: do_sample with beams).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # NOTE(review): 10000 far exceeds typical Pegasus decoding
                # lengths; kept as-is since generation stops at EOS anyway —
                # confirm whether a tighter cap was intended.
                max_length=10000,
                num_return_sequences=num_variations,
                # BUG FIX: generate() requires num_return_sequences <= num_beams;
                # a request with num_variations > 10 used to raise. Scale the
                # beam count with the requested variation count.
                num_beams=max(10, num_variations),
                temperature=params["temperature"],
                top_k=params["top_k"],
                do_sample=True,
                early_stopping=True,
            )

        # Decode each beam/sample into plain text.
        return [
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
# NOTE(review): no handler in this file carried a route decorator, so the app
# exposed no endpoints at all; the decorators appear to have been lost in a
# copy/paste. "/" matches the handler's landing-page message — confirm path.
@app.get("/")
async def root():
    """Landing endpoint: confirms the API is up and points at /docs."""
    return {"message": "Paraphrase API is running. Use /docs for API documentation."}
# NOTE(review): route decorator restored — no handler in this file was
# registered with the app. Confirm the intended path.
@app.post("/paraphrase")
async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase a single text.

    Requires a valid X-API-Key header (via verify_api_key). Returns the
    original text, the generated paraphrases, and timing metadata.
    """
    try:
        start_time = time.time()
        paraphrases = generate_paraphrase(
            request.text,
            request.style,
            request.num_variations
        )
        processing_time = time.time() - start_time
        return {
            "status": "success",
            "original_text": request.text,
            "paraphrased_texts": paraphrases,
            "style": request.style,
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }
    except HTTPException:
        # BUG FIX: the broad `except Exception` below used to catch the
        # HTTPException raised inside generate_paraphrase and rewrap it,
        # discarding its "Paraphrase generation error:" detail. Re-raise
        # HTTP errors untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): route decorator restored — no handler in this file was
# registered with the app. Confirm the intended path.
@app.post("/batch_paraphrase")
async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase each text in the batch with the same style/variation count.

    Requires a valid X-API-Key header (via verify_api_key). Returns one
    result entry per input text plus overall timing metadata.
    """
    try:
        start_time = time.time()
        results = []
        for text in request.texts:
            paraphrases = generate_paraphrase(
                text,
                request.style,
                request.num_variations
            )
            results.append({
                "original_text": text,
                "paraphrased_texts": paraphrases,
                "style": request.style
            })
        processing_time = time.time() - start_time
        return {
            "status": "success",
            "results": results,
            "total_texts_processed": len(request.texts),
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }
    except HTTPException:
        # BUG FIX: don't rewrap HTTPExceptions raised by generate_paraphrase
        # as a bare 500 — re-raise them with their original detail intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))