File size: 4,611 Bytes
ea9802b
 
 
 
 
01807b0
ea9802b
99838a4
 
 
 
e7739fb
ea9802b
e7739fb
99838a4
 
ea9802b
 
99838a4
ea9802b
ef0e5c3
01807b0
 
 
 
 
 
 
 
ef0e5c3
ea9802b
 
defb5c4
ea9802b
ef0e5c3
ea9802b
 
defb5c4
ea9802b
 
 
 
 
 
 
defb5c4
ea9802b
defb5c4
 
01807b0
 
 
 
 
defb5c4
ea9802b
19db783
488ca19
16297d9
ea9802b
 
01807b0
19db783
ea9802b
19db783
defb5c4
16297d9
488ca19
 
ea9802b
488ca19
ea9802b
defb5c4
01807b0
ea9802b
 
488ca19
defb5c4
ea9802b
 
01807b0
ea9802b
 
 
defb5c4
ea9802b
defb5c4
 
ea9802b
 
 
defb5c4
ea9802b
defb5c4
ea9802b
 
 
 
 
 
 
 
defb5c4
 
ea9802b
 
 
 
 
01807b0
ea9802b
defb5c4
 
ea9802b
 
 
 
16297d9
defb5c4
ea9802b
defb5c4
ea9802b
 
 
 
 
defb5c4
 
ea9802b
 
 
 
 
 
 
 
 
 
 
 
 
01807b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from fastapi import FastAPI, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import time
#from dotenv import load_dotenv
import os

#load_dotenv()

app = FastAPI()

# API key is read from the environment at import time.
# NOTE(review): if the API_KEY env var is unset this is None, and the
# API_KEYS dict below becomes {None: "user1"} — verify deployments always
# set API_KEY, otherwise key validation is effectively misconfigured.
API_KEY = os.getenv("API_KEY")

# Configuration
# Maps API key -> user identifier; membership is checked in verify_api_key.
API_KEYS = {
    API_KEY : "user1"  # In production, use a secure database
}

# Initialize model and tokenizer with smaller model for Spaces
# Model weights are downloaded/cached under ./model_cache on first run;
# loading happens at import time, so startup blocks until it completes.
MODEL_NAME = "tuner007/pegasus_paraphrase"
print("Loading model and tokenizer...")
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
device = "cpu"  # Force CPU for Spaces deployment
model = model.to(device)
print("Model and tokenizer loaded successfully!")

class TextRequest(BaseModel):
    """Request body for POST /api/paraphrase: one input text plus options."""
    # Text to paraphrase.
    text: str
    # Decoding style; generate_paraphrase knows "standard", "formal",
    # "casual", "creative" — anything else falls back to formal-like params.
    style: Optional[str] = "standard"
    # Number of paraphrases to return per input text.
    num_variations: Optional[int] = 1

class BatchRequest(BaseModel):
    """Request body for POST /api/batch-paraphrase: several input texts sharing options."""
    # Texts to paraphrase; each is processed independently and in order.
    texts: List[str]
    # Same semantics as TextRequest.style, applied to every text.
    style: Optional[str] = "standard"
    # Paraphrases to generate per text.
    num_variations: Optional[int] = 1

async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
    """FastAPI dependency: require a valid key in the ``X-API-Key`` header.

    Returns the key on success; raises HTTPException 403 when the key is
    not present in API_KEYS.
    """
    # BUG FIX: fastapi.Header takes the custom header name via `alias=`,
    # not `name=` — the original call raised a TypeError (and without an
    # alias the header name would default to "api-key").
    if api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key

def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
    """Generate ``num_variations`` paraphrases of ``text`` with Pegasus.

    Args:
        text: Input sentence/passage to paraphrase.
        style: One of "standard", "formal", "casual", "creative"; unknown
            styles fall back to conservative (formal-like) parameters.
        num_variations: How many paraphrases to return.

    Returns:
        List of decoded paraphrase strings (length == num_variations).

    Raises:
        HTTPException: 500 wrapping any model/tokenizer failure.
    """
    try:
        # Sampling parameters per style; .get() provides the fallback.
        params = {
            "standard": {"temperature": 1.5, "top_k": 80},
            "formal": {"temperature": 1.0, "top_k": 50},
            "casual": {"temperature": 1.6, "top_k": 100},
            "creative": {"temperature": 2.8, "top_k": 170},
        }.get(style, {"temperature": 1.0, "top_k": 50})

        # BUG FIX: truncation=False makes the tokenizer emit sequences
        # longer than the model's maximum input length for long texts,
        # which crashes generation; truncate to the model limit instead.
        inputs = tokenizer(text, truncation=True, padding='longest', return_tensors="pt").to(device)

        # BUG FIX: beam search requires num_beams >= num_return_sequences;
        # the hard-coded 10 beams made num_variations > 10 raise inside
        # transformers. Widen the beam count when necessary.
        num_beams = max(10, num_variations)

        # Generate paraphrases (no gradients needed at inference time).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=10000,
                num_return_sequences=num_variations,
                num_beams=num_beams,
                temperature=params["temperature"],
                top_k=params["top_k"],
                do_sample=True,
                early_stopping=True,
            )

        # Decode the generated token ids back to plain strings.
        paraphrases = [
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

        return paraphrases

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")

@app.get("/")
async def root():
    """Landing/health-check endpoint pointing clients at the docs."""
    payload = {
        "message": "Paraphrase API is running. Use /docs for API documentation."
    }
    return payload

@app.post("/api/paraphrase")
async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase a single text.

    Returns the original text, the generated paraphrases, the style used,
    wall-clock processing time, and a timestamp. Requires a valid
    X-API-Key header (enforced by the verify_api_key dependency).
    """
    try:
        start_time = time.time()

        paraphrases = generate_paraphrase(
            request.text,
            request.style,
            request.num_variations
        )

        processing_time = time.time() - start_time

        return {
            "status": "success",
            "original_text": request.text,
            "paraphrased_texts": paraphrases,
            "style": request.style,
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }

    except HTTPException:
        # BUG FIX: the blanket handler below was swallowing HTTPExceptions
        # raised by generate_paraphrase and re-wrapping them as a generic
        # 500 with a mangled detail string; propagate them unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/batch-paraphrase")
async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase several texts in one call.

    Each text in request.texts is paraphrased independently (same style
    and num_variations for all); results preserve input order. An empty
    texts list yields an empty results list. Requires a valid X-API-Key
    header (enforced by the verify_api_key dependency).
    """
    try:
        start_time = time.time()
        results = []

        for text in request.texts:
            paraphrases = generate_paraphrase(
                text,
                request.style,
                request.num_variations
            )

            results.append({
                "original_text": text,
                "paraphrased_texts": paraphrases,
                "style": request.style
            })

        processing_time = time.time() - start_time

        return {
            "status": "success",
            "results": results,
            "total_texts_processed": len(request.texts),
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }

    except HTTPException:
        # BUG FIX: the blanket handler below was swallowing HTTPExceptions
        # raised by generate_paraphrase and re-wrapping them as a generic
        # 500 with a mangled detail string; propagate them unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))