Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -352,28 +352,28 @@ def semantic_chunking(req: ChunkRequest):
|
|
| 352 |
|
| 353 |
|
| 354 |
|
|
|
|
| 355 |
@app.post("/predict/explain")
|
| 356 |
def explain_text(req: ExplanationRequest):
|
| 357 |
-
"""
|
| 358 |
-
Summarize or explain the input text.
|
| 359 |
-
mode = "summarize" → short summary
|
| 360 |
-
mode = "explain" → detailed plain‑English explanation
|
| 361 |
-
"""
|
| 362 |
tokenizer = models["explain_tokenizer"]
|
| 363 |
model = models["explain_model"]
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
|
|
|
|
|
|
| 369 |
|
| 370 |
-
inputs = tokenizer(
|
|
|
|
| 371 |
with torch.no_grad():
|
| 372 |
outputs = model.generate(
|
| 373 |
**inputs,
|
| 374 |
max_new_tokens=150,
|
| 375 |
-
|
| 376 |
-
|
|
|
|
| 377 |
early_stopping=True
|
| 378 |
)
|
| 379 |
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
@@ -382,7 +382,6 @@ def explain_text(req: ExplanationRequest):
|
|
| 382 |
|
| 383 |
|
| 384 |
|
| 385 |
-
|
| 386 |
@app.post("/predict/ner", response_model=NERResult)
|
| 387 |
def predict_ner(req: NERRequest):
|
| 388 |
# Default entity types suitable for freelancer contracts
|
|
|
|
| 352 |
|
| 353 |
|
| 354 |
|
| 355 |
+
|
| 356 |
@app.post("/predict/explain")
|
| 357 |
def explain_text(req: ExplanationRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
tokenizer = models["explain_tokenizer"]
|
| 359 |
model = models["explain_model"]
|
| 360 |
|
| 361 |
+
# FLAN-T5 models fine-tuned on summarization require the "summarize: " prefix
|
| 362 |
+
input_text = f"summarize: {req.text}"
|
| 363 |
+
|
| 364 |
+
# If the user asks for an 'explain', we can still frame it as an intensive summary
|
| 365 |
+
if req.mode == "explain":
|
| 366 |
+
input_text = f"summarize in detail: {req.text}"
|
| 367 |
|
| 368 |
+
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
|
| 369 |
+
|
| 370 |
with torch.no_grad():
|
| 371 |
outputs = model.generate(
|
| 372 |
**inputs,
|
| 373 |
max_new_tokens=150,
|
| 374 |
+
num_beams=5,
|
| 375 |
+
length_penalty=2.0, # Encourage longer generation
|
| 376 |
+
no_repeat_ngram_size=3, # Prevent repetition
|
| 377 |
early_stopping=True
|
| 378 |
)
|
| 379 |
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
| 382 |
|
| 383 |
|
| 384 |
|
|
|
|
| 385 |
@app.post("/predict/ner", response_model=NERResult)
|
| 386 |
def predict_ner(req: NERRequest):
|
| 387 |
# Default entity types suitable for freelancer contracts
|