# NOTE: scrape artifacts from the Hugging Face Spaces file viewer (status lines,
# blob hashes, line-number gutter) were removed from the top of this file.
from fastapi import FastAPI, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import time
#from dotenv import load_dotenv
import os
#load_dotenv()
app = FastAPI()

# API key comes from the environment (dotenv loading is commented out above).
# NOTE(review): if the API_KEY env var is unset this is None, the dict below
# becomes {None: "user1"}, and no header value can ever match — every request
# would be rejected. Confirm the variable is configured in the Space settings.
API_KEY = os.getenv("API_KEY")

# Configuration
API_KEYS = {
    API_KEY : "user1" # In production, use a secure database
}

# Initialize model and tokenizer with smaller model for Spaces
MODEL_NAME = "tuner007/pegasus_paraphrase"
print("Loading model and tokenizer...")
# cache_dir keeps downloaded weights in a local "model_cache" directory so the
# Space does not re-download them on every restart.
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
device = "cpu" # Force CPU for Spaces deployment
model = model.to(device)
print("Model and tokenizer loaded successfully!")
class TextRequest(BaseModel):
    """Request body for POST /api/paraphrase."""
    # Text to paraphrase.
    text: str
    # Paraphrase style; generate_paraphrase recognizes "standard", "formal",
    # "casual" and "creative" — anything else falls back to its defaults.
    style: Optional[str] = "standard"
    # Number of paraphrase variations to generate for the text.
    num_variations: Optional[int] = 1
class BatchRequest(BaseModel):
    """Request body for POST /api/batch-paraphrase."""
    # Texts to paraphrase; each is processed independently.
    texts: List[str]
    # Paraphrase style applied to every text (see TextRequest.style).
    style: Optional[str] = "standard"
    # Number of paraphrase variations to generate per text.
    num_variations: Optional[int] = 1
async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
    """FastAPI dependency that validates the ``X-API-Key`` request header.

    Bug fix: ``Header()`` binds a custom header name via ``alias=``, not
    ``name=``. With the original ``name="X-API-Key"`` kwarg the parameter was
    resolved from the default ``api-key`` header (derived from the argument
    name), so clients sending ``X-API-Key`` were never authenticated.

    Returns:
        The validated API key string.

    Raises:
        HTTPException: 403 when the supplied key is not in ``API_KEYS``.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
    """Generate ``num_variations`` paraphrases of ``text`` with the Pegasus model.

    Args:
        text: Input text to paraphrase.
        style: One of "standard"/"formal"/"casual"/"creative"; unknown styles
            fall back to conservative sampling defaults.
        num_variations: How many paraphrases to return.

    Returns:
        A list of ``num_variations`` decoded paraphrase strings.

    Raises:
        HTTPException: 500 wrapping any tokenizer/model failure.
    """
    # Sampling parameters per style; .get() supplies defaults for unknown styles.
    params = {
        "standard": {"temperature": 1.5, "top_k": 80},
        "formal": {"temperature": 1.0, "top_k": 50},
        "casual": {"temperature": 1.6, "top_k": 100},
        "creative": {"temperature": 2.8, "top_k": 170},
    }.get(style, {"temperature": 1.0, "top_k": 50})
    try:
        # Bug fix: truncation=False let inputs longer than the model's position
        # embeddings through, crashing generation. Truncate to the tokenizer's
        # declared maximum instead.
        inputs = tokenizer(
            text,
            truncation=True,
            max_length=tokenizer.model_max_length,
            padding="longest",
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # Bug fix: cap output length at the model's supported maximum;
                # the hard-coded 10000 exceeds Pegasus position embeddings.
                max_length=min(10000, tokenizer.model_max_length),
                num_return_sequences=num_variations,
                # generate() requires num_beams >= num_return_sequences.
                num_beams=max(10, num_variations),
                temperature=params["temperature"],
                top_k=params["top_k"],
                do_sample=True,
                early_stopping=True,
            )
        # Decode each beam/sample into plain text.
        return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    except Exception as e:
        # Surface model/tokenizer failures to the endpoints as an HTTP 500.
        raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}") from e
@app.get("/")
async def root():
return {"message": "Paraphrase API is running. Use /docs for API documentation."}
@app.post("/api/paraphrase")
async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
try:
start_time = time.time()
paraphrases = generate_paraphrase(
request.text,
request.style,
request.num_variations
)
processing_time = time.time() - start_time
return {
"status": "success",
"original_text": request.text,
"paraphrased_texts": paraphrases,
"style": request.style,
"processing_time": f"{processing_time:.2f} seconds",
"timestamp": datetime.now().isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/batch-paraphrase")
async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
try:
start_time = time.time()
results = []
for text in request.texts:
paraphrases = generate_paraphrase(
text,
request.style,
request.num_variations
)
results.append({
"original_text": text,
"paraphrased_texts": paraphrases,
"style": request.style
})
processing_time = time.time() - start_time
return {
"status": "success",
"results": results,
"total_texts_processed": len(request.texts),
"processing_time": f"{processing_time:.2f} seconds",
"timestamp": datetime.now().isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) |