pius-code committed on
Commit
575c139
·
1 Parent(s): 3136230

implement dynamic length adjustments for summarization and add translation endpoint

Browse files
Files changed (1) hide show
  1. main.py +51 -3
main.py CHANGED
@@ -1,6 +1,11 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import pipeline
 
 
 
 
 
4
 
5
  app = FastAPI()
6
  summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
@@ -16,5 +21,48 @@ async def root():
16
 
17
  @app.post("/summarize")
18
  async def summarize_text(input: TextInput):
19
- summary = (summarizer(input.text, max_length=130, min_length=30, do_sample=False ))
20
- return {"summary": summary[0]['summary_text']}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, T5ForConditionalGeneration

# T5 tokenizer/model pair backing the English -> French translation endpoint.
# Loaded once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-base")

app = FastAPI()

# BART summarization pipeline backing the /summarize endpoint.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
21
 
22
def _summary_lengths(word_count: int) -> tuple[int, int]:
    """Return (max_length, min_length) generation bounds scaled to input size.

    Shorter inputs get proportionally tighter bounds; each tier enforces a
    floor so the model always has a usable budget. The ratios guarantee
    min_length <= max_length in every tier.
    """
    if word_count < 50:
        max_length = max(10, word_count // 2)  # half the input, at least 10
        min_length = max(3, word_count // 4)   # quarter the input, at least 3
    elif word_count < 200:
        max_length = max(50, word_count // 3)
        min_length = max(15, word_count // 6)
    else:
        max_length = max(100, word_count // 4)
        min_length = max(30, word_count // 8)
    # BART's decoder is limited to 1024 positions, so never request more.
    max_length = min(max_length, 1024)
    return max_length, min_length


@app.post("/summarize")
async def summarize_text(input: TextInput):
    """Summarize `input.text` with length bounds scaled to the input size.

    Returns the summary text together with the parameters that were used.
    Rejects empty/whitespace-only input with a 400 instead of forwarding it
    to the model (which would raise and surface as a 500).
    """
    # Approximate the input size in words (a tokenizer count would be more
    # precise, but this is cheap and close enough to pick a length tier).
    word_count = len(input.text.split())

    if word_count == 0:
        # Fix: the original passed empty text straight to the pipeline.
        # Local import keeps this change self-contained.
        from fastapi import HTTPException
        raise HTTPException(status_code=400, detail="Input text must not be empty.")

    max_length, min_length = _summary_lengths(word_count)

    # Generate the summary with the dynamically chosen bounds.
    summary = summarizer(
        input.text,
        max_length=max_length,
        min_length=min_length,
        do_sample=True,
        temperature=0.7,
        num_beams=4,
    )

    return {
        "summary": summary[0]["summary_text"],
        "parameters_used": {
            "input_word_count": word_count,
            "max_length": max_length,
            "min_length": min_length,
        },
    }
59
+
60
+
61
+
62
@app.post("/translateFrench")
async def translate(input: TextInput):
    """Translate `input.text` from English to French with the T5 model.

    Returns {"translated_text": <French translation>}.
    """
    # Build the T5 task prompt in a local variable instead of overwriting
    # input.text in place (the original mutated the request model, leaking
    # the prompt prefix into any later use of the object).
    prompt = "translate English to French: " + input.text
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Fix: a hard-coded max_length=50 silently truncated translations of
    # longer inputs. Scale the budget to the input token count (with 2x
    # headroom, since French output can run longer than the English source),
    # capped at T5's 512-token context.
    max_length = min(512, max(50, input_ids.shape[1] * 2))

    output = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=4,
        early_stopping=True,
    )
    translated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return {"translated_text": translated_text}