SujalChhajed925 committed on
Commit
299265a
·
verified ·
1 Parent(s): d29b20b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +67 -71
main.py CHANGED
@@ -1,72 +1,68 @@
1
- import nltk
2
- from fastapi import FastAPI, HTTPException
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from pydantic import BaseModel
5
- from transformers import AutoTokenizer, pipeline
6
- from rake_nltk import Rake
7
-
8
- # --- Setup & Model Loading ---
9
- app = FastAPI()
10
-
11
- # CRITICAL: Allow your Frontend to access this API
12
- app.add_middleware(
13
- CORSMiddleware,
14
- allow_origins=["*"], # Allows all origins (Safe for public free APIs)
15
- allow_credentials=True,
16
- allow_methods=["*"],
17
- allow_headers=["*"],
18
- )
19
-
20
- print("Loading Models...")
21
- # Download NLTK data
22
- try:
23
- nltk.data.find('tokenizers/punkt_tab')
24
- except LookupError:
25
- nltk.download('stopwords')
26
- nltk.download('punkt_tab')
27
-
28
- # Load AI Models (Cached in memory)
29
- t5_tokenizer = AutoTokenizer.from_pretrained("Michau/t5-base-en-generate-headline", use_fast=False)
30
- title_pipe = pipeline("text2text-generation", model="Michau/t5-base-en-generate-headline", tokenizer=t5_tokenizer)
31
-
32
- bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn", use_fast=False)
33
- desc_pipe = pipeline("summarization", model="facebook/bart-large-cnn", tokenizer=bart_tokenizer)
34
- print("Models Ready!")
35
-
36
- # --- Logic ---
37
- class VideoInput(BaseModel):
38
- text: str
39
-
40
- def get_tags(text):
41
- rake = Rake()
42
- rake.extract_keywords_from_text(text)
43
- phrases = rake.get_ranked_phrases()[:5]
44
- hashtags = ["#" + p.replace(" ", "") for p in phrases]
45
- tags = [p.replace(" ", "") for p in phrases]
46
- return hashtags, tags
47
-
48
- @app.get("/")
49
- def home():
50
- return {"status": "API is running. POST to /generate"}
51
-
52
- @app.post("/generate")
53
- async def generate(payload: VideoInput):
54
- text = payload.text
55
- if not text.strip():
56
- raise HTTPException(status_code=400, detail="Empty text")
57
-
58
- # 1. Generate Title
59
- title_out = title_pipe("headline: " + text, max_new_tokens=70, do_sample=False)
60
-
61
- # 2. Generate Description
62
- desc_out = desc_pipe(text, max_new_tokens=150, do_sample=False)
63
-
64
- # 3. Get Tags
65
- hashtags, tags = get_tags(text)
66
-
67
- return {
68
- "title": title_out[0]["generated_text"],
69
- "description": desc_out[0]["summary_text"],
70
- "hashtags": hashtags,
71
- "tags": tags
72
  }
 
1
+ import nltk
2
+ from fastapi import FastAPI, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
+ from transformers import AutoTokenizer, pipeline
6
+ from rake_nltk import Rake
7
+
8
+ app = FastAPI()
9
+
10
+ app.add_middleware(
11
+ CORSMiddleware,
12
+ allow_origins=["*"],
13
+ allow_credentials=True,
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
+ )
17
+
18
+ print("Loading Models...")
19
+ try:
20
+ nltk.data.find('tokenizers/punkt_tab')
21
+ except LookupError:
22
+ nltk.download('stopwords')
23
+ nltk.download('punkt_tab')
24
+
25
+ # 1. TITLE MODEL (Keep this, or swap to 't5-small' if still desperate)
26
+ t5_tokenizer = AutoTokenizer.from_pretrained("Michau/t5-base-en-generate-headline", use_fast=False)
27
+ title_pipe = pipeline("text2text-generation", model="Michau/t5-base-en-generate-headline", tokenizer=t5_tokenizer)
28
+
29
+ # 2. DESCRIPTION MODEL (>>> CHANGED TO DISTILBART <<<)
30
+ # This model is 3x faster and smaller than bart-large-cnn
31
+ bart_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6", use_fast=False)
32
+ desc_pipe = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", tokenizer=bart_tokenizer)
33
+
34
+ print("Models Ready!")
35
+
36
+ class VideoInput(BaseModel):
37
+ text: str
38
+
39
+ def get_tags(text):
40
+ rake = Rake()
41
+ rake.extract_keywords_from_text(text)
42
+ phrases = rake.get_ranked_phrases()[:5]
43
+ hashtags = ["#" + p.replace(" ", "") for p in phrases]
44
+ tags = [p.replace(" ", "") for p in phrases]
45
+ return hashtags, tags
46
+
47
+ @app.get("/")
48
+ def home():
49
+ return {"status": "API is running."}
50
+
51
+ @app.post("/generate")
52
+ async def generate(payload: VideoInput):
53
+ text = payload.text
54
+ if not text.strip():
55
+ raise HTTPException(status_code=400, detail="Empty text")
56
+
57
+ # Lower max_new_tokens slightly to speed up generation
58
+ title_out = title_pipe("headline: " + text, max_new_tokens=50, do_sample=False)
59
+ desc_out = desc_pipe(text, max_new_tokens=100, do_sample=False)
60
+
61
+ hashtags, tags = get_tags(text)
62
+
63
+ return {
64
+ "title": title_out[0]["generated_text"],
65
+ "description": desc_out[0]["summary_text"],
66
+ "hashtags": hashtags,
67
+ "tags": tags
 
 
 
 
68
  }