lahiruchamika27 committed on
Commit
16297d9
·
verified ·
1 Parent(s): 692ae11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -72
app.py CHANGED
@@ -3,10 +3,10 @@ from pydantic import BaseModel
3
  from typing import Optional, List
4
  from datetime import datetime
5
  import torch
6
- from transformers import BartForConditionalGeneration, BartTokenizer
7
  import time
8
  import traceback
9
  import logging
 
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO)
@@ -19,12 +19,12 @@ API_KEYS = {
19
  "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW": "user1" # In production, use a secure database
20
  }
21
 
22
- # Initialize model and tokenizer
23
- MODEL_NAME = "facebook/bart-large-cnn"
24
  try:
25
  print("Loading model and tokenizer...")
26
- tokenizer = BartTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
27
- model = BartForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  model = model.to(device)
30
  print(f"Model and tokenizer loaded successfully on {device}!")
@@ -32,7 +32,6 @@ except Exception as e:
32
  error_msg = f"Error loading model: {str(e)}\n{traceback.format_exc()}"
33
  print(error_msg)
34
  logger.error(error_msg)
35
- # Continue without crashing, we'll handle this in the endpoints
36
 
37
  class TextRequest(BaseModel):
38
  text: str
@@ -51,37 +50,30 @@ async def verify_api_key(api_key: str = Header(..., name="X-API-Key")):
51
 
52
  def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
53
  try:
54
- # Check if model was loaded successfully
55
- if 'model' not in globals() or model is None:
56
- raise Exception("Model failed to load. Check server logs.")
57
-
58
  # Get parameters based on style
59
  params = {
60
- "standard": {"temperature": 1.0, "top_p": 0.9},
61
- "formal": {"temperature": 0.7, "top_p": 0.8},
62
- "casual": {"temperature": 1.3, "top_p": 0.95},
63
- "creative": {"temperature": 1.8, "top_p": 0.99},
64
- }.get(style, {"temperature": 1.0, "top_p": 0.9})
65
-
66
- logger.info(f"Processing text: {text[:50]}... with style {style}")
67
 
68
  # Tokenize the input text
69
- inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
70
- logger.info(f"Input tokenized successfully, shape: {inputs.input_ids.shape}")
71
 
72
- # Generate paraphrases with simplified parameters
73
  with torch.no_grad():
74
  outputs = model.generate(
75
- input_ids=inputs.input_ids,
76
- attention_mask=inputs.attention_mask,
77
- max_length=100, # Reduced max length
78
  num_return_sequences=num_variations,
79
- num_beams=4, # Simplified beam search
80
  temperature=params["temperature"],
81
- do_sample=True,
 
 
 
82
  )
83
-
84
- logger.info(f"Generation completed, output shape: {outputs.shape}")
85
 
86
  # Decode the generated outputs
87
  paraphrases = [
@@ -89,7 +81,6 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
89
  for output in outputs
90
  ]
91
 
92
- logger.info(f"Paraphrases decoded successfully: {len(paraphrases)} variations")
93
  return paraphrases
94
 
95
  except Exception as e:
@@ -104,7 +95,6 @@ async def root():
104
  @app.post("/api/paraphrase")
105
  async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
106
  try:
107
- logger.info(f"Received paraphrase request with style: {request.style}")
108
  start_time = time.time()
109
 
110
  paraphrases = generate_paraphrase(
@@ -114,7 +104,6 @@ async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key
114
  )
115
 
116
  processing_time = time.time() - start_time
117
- logger.info(f"Request processed in {processing_time:.2f} seconds")
118
 
119
  return {
120
  "status": "success",
@@ -126,19 +115,17 @@ async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key
126
  }
127
 
128
  except Exception as e:
129
- error_msg = f"API error: {str(e)}\n{traceback.format_exc()}"
130
- logger.error(error_msg)
131
  raise HTTPException(status_code=500, detail=error_msg)
132
 
133
  @app.post("/api/batch-paraphrase")
134
  async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
135
  try:
136
- logger.info(f"Received batch paraphrase request for {len(request.texts)} texts")
137
  start_time = time.time()
138
  results = []
139
 
140
- for i, text in enumerate(request.texts):
141
- logger.info(f"Processing batch item {i+1}/{len(request.texts)}")
142
  paraphrases = generate_paraphrase(
143
  text,
144
  request.style,
@@ -152,7 +139,6 @@ async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_
152
  })
153
 
154
  processing_time = time.time() - start_time
155
- logger.info(f"Batch request processed in {processing_time:.2f} seconds")
156
 
157
  return {
158
  "status": "success",
@@ -163,49 +149,26 @@ async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_
163
  }
164
 
165
  except Exception as e:
166
- error_msg = f"API error: {str(e)}\n{traceback.format_exc()}"
167
- logger.error(error_msg)
168
  raise HTTPException(status_code=500, detail=error_msg)
169
 
170
- # Add an endpoint for debugging
171
- @app.get("/api/debug")
172
- async def debug_info():
173
  try:
174
- model_info = {
175
- "model_name": MODEL_NAME,
176
- "device": str(device),
177
- "model_loaded": 'model' in globals() and model is not None,
178
- "tokenizer_loaded": 'tokenizer' in globals() and tokenizer is not None,
179
- }
180
-
181
- # Test tokenization
182
- test_text = "This is a test."
183
- tokenization_test = {}
184
- try:
185
- tokens = tokenizer(test_text, return_tensors="pt")
186
- tokenization_test = {
187
- "success": True,
188
- "input_shape": tokens.input_ids.shape,
189
- "tokens": tokens.input_ids.tolist()
190
- }
191
- except Exception as e:
192
- tokenization_test = {
193
- "success": False,
194
- "error": str(e)
195
- }
196
-
197
  return {
198
- "status": "debug info",
199
- "model_info": model_info,
200
- "tokenization_test": tokenization_test,
201
- "torch_version": torch.__version__,
202
- "api_keys_configured": len(API_KEYS)
203
  }
204
  except Exception as e:
205
  return {
206
  "status": "error",
207
  "error": str(e),
208
  "traceback": traceback.format_exc()
209
- }
210
-
211
-
 
3
  from typing import Optional, List
4
  from datetime import datetime
5
  import torch
 
6
  import time
7
  import traceback
8
  import logging
9
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
10
 
11
  # Configure logging
12
  logging.basicConfig(level=logging.INFO)
 
19
  "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW": "user1" # In production, use a secure database
20
  }
21
 
22
# Model setup. tuner007/pegasus_paraphrase is a Pegasus checkpoint used here
# for paraphrase generation; weights/tokenizer are cached under model_cache.
MODEL_NAME = "tuner007/pegasus_paraphrase"
try:
    print("Loading model and tokenizer...")
    tokenizer = PegasusTokenizer.from_pretrained(
        MODEL_NAME, cache_dir="model_cache"
    )
    model = PegasusForConditionalGeneration.from_pretrained(
        MODEL_NAME, cache_dir="model_cache"
    )
    # Prefer GPU when available, then move the weights there.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"Model and tokenizer loaded successfully on {device}!")
except Exception as e:
    # Deliberately broad: a load failure is logged instead of crashing the
    # process at import time.
    error_msg = f"Error loading model: {str(e)}\n{traceback.format_exc()}"
    print(error_msg)
    logger.error(error_msg)
 
35
 
36
  class TextRequest(BaseModel):
37
  text: str
 
50
 
51
  def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
52
  try:
 
 
 
 
53
  # Get parameters based on style
54
  params = {
55
+ "standard": {"temperature": 1.0, "top_k": 50, "diversity_penalty": 1.0},
56
+ "formal": {"temperature": 0.7, "top_k": 40, "diversity_penalty": 1.0},
57
+ "casual": {"temperature": 1.3, "top_k": 70, "diversity_penalty": 0.8},
58
+ "creative": {"temperature": 1.5, "top_k": 100, "diversity_penalty": 0.7},
59
+ }.get(style, {"temperature": 1.0, "top_k": 50, "diversity_penalty": 1.0})
 
 
60
 
61
  # Tokenize the input text
62
+ input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
 
63
 
64
+ # Generate paraphrases
65
  with torch.no_grad():
66
  outputs = model.generate(
67
+ input_ids,
68
+ max_length=128,
 
69
  num_return_sequences=num_variations,
70
+ num_beams=num_variations + 2,
71
  temperature=params["temperature"],
72
+ top_k=params["top_k"],
73
+ diversity_penalty=params["diversity_penalty"],
74
+ num_beam_groups=min(num_variations, 4) if num_variations > 1 else 1,
75
+ do_sample=True
76
  )
 
 
77
 
78
  # Decode the generated outputs
79
  paraphrases = [
 
81
  for output in outputs
82
  ]
83
 
 
84
  return paraphrases
85
 
86
  except Exception as e:
 
95
  @app.post("/api/paraphrase")
96
  async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
97
  try:
 
98
  start_time = time.time()
99
 
100
  paraphrases = generate_paraphrase(
 
104
  )
105
 
106
  processing_time = time.time() - start_time
 
107
 
108
  return {
109
  "status": "success",
 
115
  }
116
 
117
  except Exception as e:
118
+ error_msg = f"API error: {str(e)}"
119
+ logger.error(f"{error_msg}\n{traceback.format_exc()}")
120
  raise HTTPException(status_code=500, detail=error_msg)
121
 
122
  @app.post("/api/batch-paraphrase")
123
  async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
124
  try:
 
125
  start_time = time.time()
126
  results = []
127
 
128
+ for text in request.texts:
 
129
  paraphrases = generate_paraphrase(
130
  text,
131
  request.style,
 
139
  })
140
 
141
  processing_time = time.time() - start_time
 
142
 
143
  return {
144
  "status": "success",
 
149
  }
150
 
151
  except Exception as e:
152
+ error_msg = f"API error: {str(e)}"
153
+ logger.error(f"{error_msg}\n{traceback.format_exc()}")
154
  raise HTTPException(status_code=500, detail=error_msg)
155
 
156
# Smoke-test endpoint: paraphrases one fixed sentence so the deployment can
# be checked without an API key payload.
@app.get("/api/test")
async def test_endpoint():
    try:
        sample = "The quick brown fox jumps over the lazy dog."
        return {
            "status": "success",
            "test_text": sample,
            "paraphrased": generate_paraphrase(sample, "standard", 1),
            "model": MODEL_NAME,
            "device": device,
        }
    except Exception as e:
        # Any failure (including a model that never loaded) is reported in
        # the response body with its traceback for debugging.
        return {
            "status": "error",
            "error": str(e),
            "traceback": traceback.format_exc(),
        }