lahiruchamika27 commited on
Commit
ef0e5c3
·
verified ·
1 Parent(s): 303983a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -7
app.py CHANGED
@@ -1,13 +1,149 @@
1
- from fastapi import FastAPI, HTTPException
2
- import uvicorn
 
 
 
 
 
 
3
 
4
- # Initialize FastAPI app
5
  app = FastAPI()
6
 
7
- @app.route('/')
8
- def index():
9
- return "Hello"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Run the app with Uvicorn (use this command in terminal: uvicorn your_script_name:app --reload)
12
  if __name__ == "__main__":
 
13
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
# Standard library
import time
from datetime import datetime
from typing import Optional, List

# Third-party
import torch
from fastapi import FastAPI, HTTPException, Header, Depends
from pydantic import BaseModel
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
9
 
 
10
app = FastAPI()

# Configuration
# SECURITY NOTE(review): the API key is hard-coded in source; in production
# load keys from an environment variable, secrets manager, or database.
API_KEYS = {
    "your-secret-api-key": "user1"  # In production, use a secure database
}

# Load model and tokenizer globally, once at import time.
MODEL_NAME = "tuner007/pegasus_paraphrase"
# Resolve the device a single time and reuse it (the original called
# torch.cuda.is_available() twice: once to place the model, once for `device`).
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
22
+
23
class TextRequest(BaseModel):
    """Request body for /api/paraphrase (single-text paraphrasing)."""
    # Input sentence to paraphrase.
    text: str
    # Paraphrase preset; see get_paraphrase_params for the known values
    # ("standard", "formal", "casual", "creative"); unknown values fall
    # back to "standard".
    style: Optional[str] = "standard"
    # Number of paraphrase variations to generate.
    num_variations: Optional[int] = 1
27
+
28
class BatchRequest(BaseModel):
    """Request body for /api/batch-paraphrase (multiple texts, one style)."""
    # Input sentences to paraphrase; each is processed independently.
    texts: List[str]
    # Paraphrase preset applied to every text; unknown values fall back
    # to "standard" (see get_paraphrase_params).
    style: Optional[str] = "standard"
    # Number of paraphrase variations generated per text.
    num_variations: Optional[int] = 1
32
+
33
def get_paraphrase_params(style: str):
    """Return the sampling hyperparameters for a paraphrase style.

    Known styles are "standard", "formal", "casual" and "creative";
    any other value silently falls back to the "standard" preset.
    """
    standard_preset = {"temperature": 1.0, "top_k": 50, "top_p": 0.95}
    presets = {
        "standard": standard_preset,
        "formal": {"temperature": 0.7, "top_k": 30, "top_p": 0.9},
        "casual": {"temperature": 1.3, "top_k": 100, "top_p": 0.95},
        "creative": {"temperature": 1.5, "top_k": 120, "top_p": 0.99},
    }
    return presets.get(style, standard_preset)
42
+
43
async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
    """FastAPI dependency that validates the `X-API-Key` request header.

    Bug fix: `Header` binds a custom header name via `alias=`, not `name=`.
    With `name="X-API-Key"` the argument was not honored as the header name,
    so FastAPI looked the key up under the default header `api-key` instead
    of `X-API-Key`.

    Returns:
        The validated API key string.

    Raises:
        HTTPException: 403 when the supplied key is not in API_KEYS.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
47
+
48
def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
    """Generate `num_variations` paraphrases of `text` with the Pegasus model.

    The sampling parameters (temperature/top_k/top_p) come from the style
    preset; beam-sampling uses twice as many beams as requested variations.

    Raises:
        HTTPException: 500 wrapping any tokenization/generation failure.
    """
    try:
        style_params = get_paraphrase_params(style)

        # Tokenize and move tensors to the same device as the model.
        encoded = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)

        generation_kwargs = dict(
            max_length=60,
            num_return_sequences=num_variations,
            num_beams=num_variations * 2,
            temperature=style_params["temperature"],
            top_k=style_params["top_k"],
            top_p=style_params["top_p"],
            do_sample=True,
            early_stopping=True,
        )

        # Inference only -- no gradient tracking needed.
        with torch.no_grad():
            generated = model.generate(**encoded, **generation_kwargs)

        return [
            tokenizer.decode(sequence, skip_special_tokens=True)
            for sequence in generated
        ]

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
80
+
81
@app.post("/api/paraphrase")
async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase a single text.

    Returns the original text, the generated paraphrases, the style used,
    the wall-clock processing time, and a timestamp.

    Raises:
        HTTPException: propagated from generate_paraphrase, or 500 for any
        other unexpected failure.
    """
    try:
        start_time = time.time()

        paraphrases = generate_paraphrase(
            request.text,
            request.style,
            request.num_variations
        )

        processing_time = time.time() - start_time

        return {
            "status": "success",
            "original_text": request.text,
            "paraphrased_texts": paraphrases,
            "style": request.style,
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }

    except HTTPException:
        # Bug fix: the blanket handler below used to swallow HTTPExceptions
        # raised by generate_paraphrase and re-wrap them as a generic 500,
        # losing the original status code and detail. Re-raise them as-is.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
105
+
106
@app.post("/api/batch-paraphrase")
async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase a list of texts sequentially with one shared style.

    Returns a per-text result list plus the total count, wall-clock
    processing time, and a timestamp. An empty `texts` list yields an
    empty result list.

    Raises:
        HTTPException: propagated from generate_paraphrase, or 500 for any
        other unexpected failure.
    """
    try:
        start_time = time.time()
        results = []

        for text in request.texts:
            paraphrases = generate_paraphrase(
                text,
                request.style,
                request.num_variations
            )

            results.append({
                "original_text": text,
                "paraphrased_texts": paraphrases,
                "style": request.style
            })

        processing_time = time.time() - start_time

        return {
            "status": "success",
            "results": results,
            "total_texts_processed": len(request.texts),
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }

    except HTTPException:
        # Bug fix: the blanket handler below used to swallow HTTPExceptions
        # raised by generate_paraphrase and re-wrap them as a generic 500,
        # losing the original status code and detail. Re-raise them as-is.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
137
+
138
@app.get("/api/health")
async def health_check(api_key: str = Depends(verify_api_key)):
    """Authenticated liveness probe reporting the loaded model and device."""
    payload = {
        "status": "healthy",
        "model": MODEL_NAME,
        "device": device,
        "timestamp": datetime.now().isoformat(),
    }
    return payload
146
 
 
147
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this file directly
    # (in deployment the server is typically launched externally, e.g.
    # `uvicorn app:app`).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)