Asanaly committed on
Commit
f06ba28
·
verified ·
1 Parent(s): abd27a2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -9
main.py CHANGED
@@ -1,29 +1,44 @@
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import BartForConditionalGeneration, BartTokenizer
4
  import torch
5
 
 
6
  app = FastAPI(title="Multilingual Text Summarizer")
7
 
8
- # Model for English and Russian summarization
9
- MODEL_NAME = "facebook/bart-large-cnn" # English
10
- tokenizer_en = BartTokenizer.from_pretrained(MODEL_NAME)
11
- model_en = BartForConditionalGeneration.from_pretrained(MODEL_NAME)
 
 
 
 
12
 
13
- MODEL_NAME_RU = "IlyaGusev/mbart_ru_sum_gazeta" # Russian
14
- tokenizer_ru = BartTokenizer.from_pretrained(MODEL_NAME_RU)
15
- model_ru = BartForConditionalGeneration.from_pretrained(MODEL_NAME_RU)
 
16
 
 
 
 
17
  class TextRequest(BaseModel):
18
  text: str
19
- lang: str # "en" or "ru"
20
 
 
 
 
21
  @app.get("/")
22
  def root():
23
  return {"message": "Multilingual Text Summarizer is running!"}
24
 
25
  @app.post("/summarize/")
26
  def summarize(request: TextRequest):
 
27
  if request.lang.lower() == "ru":
28
  tokenizer = tokenizer_ru
29
  model = model_ru
@@ -31,12 +46,16 @@ def summarize(request: TextRequest):
31
  tokenizer = tokenizer_en
32
  model = model_en
33
 
 
34
  inputs = tokenizer([request.text], max_length=1024, return_tensors="pt", truncation=True)
 
 
35
  summary_ids = model.generate(
36
  inputs["input_ids"],
37
  num_beams=4,
38
  max_length=150,
39
  early_stopping=True
40
  )
 
41
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
42
  return {"summary": summary}
 
1
+ # main.py
2
+
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
+ from transformers import BartForConditionalGeneration, BartTokenizer, MBartForConditionalGeneration, MBartTokenizer
6
  import torch
7
 
8
+ # FastAPI қосымшасы
9
  app = FastAPI(title="Multilingual Text Summarizer")
10
 
11
+ # ==========================
12
+ # Модельдерді жүктеу
13
+ # ==========================
14
+
15
+ # English BART
16
+ MODEL_NAME_EN = "facebook/bart-large-cnn"
17
+ tokenizer_en = BartTokenizer.from_pretrained(MODEL_NAME_EN)
18
+ model_en = BartForConditionalGeneration.from_pretrained(MODEL_NAME_EN)
19
 
20
+ # Russian MBart
21
+ MODEL_NAME_RU = "IlyaGusev/mbart_ru_sum_gazeta"
22
+ tokenizer_ru = MBartTokenizer.from_pretrained(MODEL_NAME_RU)
23
+ model_ru = MBartForConditionalGeneration.from_pretrained(MODEL_NAME_RU)
24
 
25
+ # ==========================
26
+ # API Request схемасы
27
+ # ==========================
28
  class TextRequest(BaseModel):
29
  text: str
30
+ lang: str # "en" немесе "ru"
31
 
32
+ # ==========================
33
+ # Routes
34
+ # ==========================
35
  @app.get("/")
36
  def root():
37
  return {"message": "Multilingual Text Summarizer is running!"}
38
 
39
  @app.post("/summarize/")
40
  def summarize(request: TextRequest):
41
+ # Тілді таңдау
42
  if request.lang.lower() == "ru":
43
  tokenizer = tokenizer_ru
44
  model = model_ru
 
46
  tokenizer = tokenizer_en
47
  model = model_en
48
 
49
+ # Tokenize
50
  inputs = tokenizer([request.text], max_length=1024, return_tensors="pt", truncation=True)
51
+
52
+ # Summarize
53
  summary_ids = model.generate(
54
  inputs["input_ids"],
55
  num_beams=4,
56
  max_length=150,
57
  early_stopping=True
58
  )
59
+
60
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
61
  return {"summary": summary}