lahiruchamika27 commited on
Commit
ea9802b
·
verified ·
1 Parent(s): b3ae7cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -11
app.py CHANGED
@@ -1,16 +1,134 @@
1
- from transformers import T5ForConditionalGeneration, T5Tokenizer
 
 
 
 
 
 
2
 
3
- model_name = "t5-base" # Change model as needed
4
- tokenizer = T5Tokenizer.from_pretrained(model_name)
5
- model = T5ForConditionalGeneration.from_pretrained(model_name)
6
 
7
- def paraphrase(text, num_variations=1, style="standard"):
8
- input_text = f"paraphrase: {text} </s>"
9
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
 
10
 
11
- outputs = model.generate(**inputs, max_length=150, num_return_sequences=num_variations, temperature=1.5, top_k=100)
12
- paraphrased_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
 
 
 
 
 
 
13
 
14
- return paraphrased_texts
 
 
 
 
15
 
16
- print(paraphrase("We are a company that uses our skills and creativity to assist businesses in expanding online."))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime
import torch
from transformers import BartForConditionalGeneration, BartTokenizer
import time

app = FastAPI()

# Configuration
# NOTE(review): this API key is hardcoded and committed to source control —
# it should be rotated and loaded from an environment variable or secret
# store before any real deployment (the inline comment below agrees).
API_KEYS = {
    "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW": "user1" # In production, use a secure database
}

# Initialize model and tokenizer once at import time (module-level side effect).
# Weights are cached under ./model_cache so subsequent startups avoid re-downloading.
MODEL_NAME = "facebook/bart-large-cnn"
print("Loading model and tokenizer...")
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
# Prefer GPU when available; inputs built later must be moved to this same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Model and tokenizer loaded successfully on {device}!")
24
 
25
class TextRequest(BaseModel):
    """Request body for POST /api/summarize (a single text)."""
    text: str  # source text to summarize
    max_length: Optional[int] = 150  # upper bound passed to model.generate
    min_length: Optional[int] = 40  # lower bound passed to model.generate
    num_variations: Optional[int] = 1  # number of alternative summaries to return
30
 
31
class BatchRequest(BaseModel):
    """Request body for POST /api/batch-summarize (many texts, shared settings)."""
    texts: List[str]  # source texts, each summarized independently
    max_length: Optional[int] = 150  # upper bound passed to model.generate
    min_length: Optional[int] = 40  # lower bound passed to model.generate
    num_variations: Optional[int] = 1  # number of alternative summaries per text
36
+
37
async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
    """FastAPI dependency that authenticates a request via the X-API-Key header.

    Returns the key unchanged on success; raises HTTPException(403) when the
    header value is not a known key (a missing header is a 422 from FastAPI).
    """
    # Bug fix: fastapi.Header maps a parameter to a custom header via `alias=`,
    # not `name=`. With the original `name="X-API-Key"` kwarg the header was
    # derived from the parameter name instead (looked up as "api-key"), so the
    # intended X-API-Key header was never read.
    if api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
41
+
42
def generate_summary(text: str, max_length: int = 150, min_length: int = 40, num_variations: int = 1) -> List[str]:
    """Summarize *text* with the module-level BART model.

    Uses beam search; when more than one variation is requested, switches to
    diverse (group) beam search with one group per variation so the returned
    summaries differ from each other. Returns one decoded string per variation.

    Raises HTTPException(500) if tokenization or generation fails.
    """
    try:
        encoded = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)

        # Diversity settings only apply when several variations are wanted;
        # with a single variation this is a plain 2-beam search.
        wants_diversity = num_variations > 1
        generation_kwargs = dict(
            max_length=max_length,
            min_length=min_length,
            num_return_sequences=num_variations,
            num_beams=num_variations * 2,  # 2 beams per group keeps beams divisible by groups
            early_stopping=True,
            diversity_penalty=0.5 if wants_diversity else 0.0,
            num_beam_groups=num_variations if wants_diversity else 1,
        )

        # Inference only — no gradients needed.
        with torch.no_grad():
            generated = model.generate(**encoded, **generation_kwargs)

        return [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in generated]

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Summary generation error: {str(e)}")
70
+
71
@app.get("/")
async def root():
    """Landing/health endpoint directing callers to the interactive docs."""
    status_message = "Summarization API is running. Use /docs for API documentation."
    return {"message": status_message}
74
+
75
@app.post("/api/summarize")
async def summarize(request: TextRequest, api_key: str = Depends(verify_api_key)):
    """Summarize a single text (auth enforced by the verify_api_key dependency).

    Returns the generated summaries, an echo of the request parameters, the
    wall-clock processing time, and an ISO timestamp.
    """
    try:
        start_time = time.time()

        summaries = generate_summary(
            request.text,
            request.max_length,
            request.min_length,
            request.num_variations
        )

        processing_time = time.time() - start_time

        return {
            "status": "success",
            "original_text": request.text,
            "summarized_texts": summaries,
            "max_length": request.max_length,
            "min_length": request.min_length,
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }

    except HTTPException:
        # Bug fix: generate_summary signals failure with HTTPException, which
        # the original bare `except Exception` caught and re-wrapped, mangling
        # the detail into "500: Summary generation error: ...". Propagate it.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
101
+
102
@app.post("/api/batch-summarize")
async def batch_summarize(request: BatchRequest, api_key: str = Depends(verify_api_key)):
    """Summarize each text in the batch sequentially (auth via dependency).

    Texts share one set of length/variation settings. Returns per-text results
    plus a count, the total wall-clock processing time, and an ISO timestamp.
    """
    try:
        start_time = time.time()
        results = []

        for text in request.texts:
            summaries = generate_summary(
                text,
                request.max_length,
                request.min_length,
                request.num_variations
            )

            results.append({
                "original_text": text,
                "summarized_texts": summaries,
                "max_length": request.max_length,
                "min_length": request.min_length
            })

        processing_time = time.time() - start_time

        return {
            "status": "success",
            "results": results,
            "total_texts_processed": len(request.texts),
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }

    except HTTPException:
        # Bug fix: mirror /api/summarize — let HTTPException from
        # generate_summary propagate instead of re-wrapping it, which turned
        # the detail into "500: Summary generation error: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))