Aadityaramrame commited on
Commit
59db687
·
verified ·
1 Parent(s): 6a554e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -53
app.py CHANGED
@@ -1,46 +1,23 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from transformers import BartTokenizer, BartForConditionalGeneration
4
  import gradio as gr
 
5
  import torch
6
 
7
- # -------------------------------
8
- # FASTAPI SETUP
9
- # -------------------------------
10
- app = FastAPI(title="Jan Arogya Summarizer API")
11
-
12
- # Allow CORS for frontend access
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=["*"],
16
- allow_credentials=True,
17
- allow_methods=["*"],
18
- allow_headers=["*"],
19
- )
20
-
21
  # -------------------------------
22
  # MODEL LOADING
23
  # -------------------------------
24
- model_name = "facebook/bart-large-cnn"
25
  tokenizer = BartTokenizer.from_pretrained(model_name)
26
  model = BartForConditionalGeneration.from_pretrained(model_name)
27
 
28
  # -------------------------------
29
- # FASTAPI ROUTES (API MODE)
30
  # -------------------------------
31
- @app.get("/")
32
- async def root():
33
- return {"message": "🚀 Jan Arogya Summarizer API is live!"}
34
 
35
- @app.post("/summarize")
36
- async def summarize(request: Request):
37
- data = await request.json()
38
- text = data.get("text", "")
39
- if not text:
40
- return {"error": "No text provided."}
41
 
42
- # Summarization process
43
- inputs = tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
44
  summary_ids = model.generate(
45
  inputs["input_ids"],
46
  num_beams=4,
@@ -48,34 +25,28 @@ async def summarize(request: Request):
48
  max_length=200,
49
  early_stopping=True
50
  )
 
51
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
52
- return {"summary": summary}
53
 
54
  # -------------------------------
55
- # GRADIO INTERFACE (Frontend)
56
  # -------------------------------
57
- def summarize_text(input_text):
58
- if not input_text.strip():
59
- return "⚠️ Please enter some text to summarize."
60
- inputs = tokenizer([input_text], max_length=1024, return_tensors="pt", truncation=True)
61
- summary_ids = model.generate(
62
- inputs["input_ids"],
63
- num_beams=4,
64
- min_length=50,
65
- max_length=200,
66
- early_stopping=True
67
- )
68
- return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
69
-
70
- gradio_interface = gr.Interface(
71
  fn=summarize_text,
72
- inputs=gr.Textbox(lines=10, label="Enter Medical Text"),
73
- outputs=gr.Textbox(label="Generated Summary", lines=10),
74
  title="🩺 Jan Arogya Summarizer",
75
- description="Summarize long medical text using the facebook/bart-large-cnn model.",
76
- theme="soft"
 
 
 
 
77
  )
78
 
79
- # Mount Gradio on FastAPI app
80
- app = gr.mount_gradio_app(app, gradio_interface, path="/")
81
-
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import BartTokenizer, BartForConditionalGeneration
3
  import torch
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  # -------------------------------
6
  # MODEL LOADING
7
  # -------------------------------
8
+ model_name = "facebook/bart-large-cnn" # Hugging Face model
9
  tokenizer = BartTokenizer.from_pretrained(model_name)
10
  model = BartForConditionalGeneration.from_pretrained(model_name)
11
 
12
  # -------------------------------
13
+ # SUMMARIZATION FUNCTION
14
  # -------------------------------
15
+ def summarize_text(input_text):
16
+ if not input_text.strip():
17
+ return "⚠️ Please enter some text to summarize."
18
 
19
+ inputs = tokenizer([input_text], max_length=1024, return_tensors="pt", truncation=True)
 
 
 
 
 
20
 
 
 
21
  summary_ids = model.generate(
22
  inputs["input_ids"],
23
  num_beams=4,
 
25
  max_length=200,
26
  early_stopping=True
27
  )
28
+
29
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
30
+ return summary
31
 
32
  # -------------------------------
33
+ # GRADIO INTERFACE
34
  # -------------------------------
35
+ demo = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  fn=summarize_text,
37
+ inputs=gr.Textbox(lines=12, placeholder="Paste medical or long text here...", label="Input Text"),
38
+ outputs=gr.Textbox(lines=10, label="Generated Summary"),
39
  title="🩺 Jan Arogya Summarizer",
40
+ description="Summarize long medical or research text using the **facebook/bart-large-cnn** model.",
41
+ theme="soft",
42
+ examples=[
43
+ ["COVID-19 is a respiratory disease caused by the SARS-CoV-2 virus. It spread rapidly across the globe..."],
44
+ ["Hypertension is a chronic medical condition characterized by persistently high blood pressure levels..."]
45
+ ],
46
  )
47
 
48
+ # -------------------------------
49
+ # LAUNCH
50
+ # -------------------------------
51
+ if __name__ == "__main__":
52
+ demo.launch()