Syamchand commited on
Commit
8fcdf14
·
verified ·
1 Parent(s): a7a5f4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -352,28 +352,28 @@ def semantic_chunking(req: ChunkRequest):
352
 
353
 
354
 
 
355
  @app.post("/predict/explain")
356
  def explain_text(req: ExplanationRequest):
357
- """
358
- Summarize or explain the input text.
359
- mode = "summarize" → short summary
360
- mode = "explain" → detailed plain‑English explanation
361
- """
362
  tokenizer = models["explain_tokenizer"]
363
  model = models["explain_model"]
364
 
365
- if req.mode == "summarize":
366
- instruction = f"Summarize the following contract clause in 1-2 sentences:\n\n{req.text}"
367
- else:
368
- instruction = f"Explain the following contract clause in plain English, in detail:\n\n{req.text}"
 
 
369
 
370
- inputs = tokenizer(instruction, return_tensors="pt", truncation=True, max_length=512)
 
371
  with torch.no_grad():
372
  outputs = model.generate(
373
  **inputs,
374
  max_new_tokens=150,
375
- do_sample=False,
376
- num_beams=4,
 
377
  early_stopping=True
378
  )
379
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -382,7 +382,6 @@ def explain_text(req: ExplanationRequest):
382
 
383
 
384
 
385
-
386
  @app.post("/predict/ner", response_model=NERResult)
387
  def predict_ner(req: NERRequest):
388
  # Default entity types suitable for freelancer contracts
 
352
 
353
 
354
 
355
+
356
  @app.post("/predict/explain")
357
  def explain_text(req: ExplanationRequest):
 
 
 
 
 
358
  tokenizer = models["explain_tokenizer"]
359
  model = models["explain_model"]
360
 
361
+ # FLAN-T5 models fine-tuned on summarization require the "summarize: " prefix
362
+ input_text = f"summarize: {req.text}"
363
+
364
+ # If the user asks for an 'explain', we can still frame it as an intensive summary
365
+ if req.mode == "explain":
366
+ input_text = f"summarize in detail: {req.text}"
367
 
368
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
369
+
370
  with torch.no_grad():
371
  outputs = model.generate(
372
  **inputs,
373
  max_new_tokens=150,
374
+ num_beams=5,
375
+ length_penalty=2.0, # Encourage longer generation
376
+ no_repeat_ngram_size=3, # Prevent repetition
377
  early_stopping=True
378
  )
379
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
382
 
383
 
384
 
 
385
  @app.post("/predict/ner", response_model=NERResult)
386
  def predict_ner(req: NERRequest):
387
  # Default entity types suitable for freelancer contracts