Spaces:
Sleeping
Sleeping
| import nltk | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, pipeline | |
| from rake_nltk import Rake | |
# FastAPI application instance; every route and middleware below attaches to it.
app = FastAPI()
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is the
# widest possible CORS policy (browsers reject credentialed wildcard responses,
# and any origin may call the API) — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # any origin may call the API
    allow_credentials=True,
    allow_methods=["*"],      # all HTTP verbs
    allow_headers=["*"],      # all request headers
)
print("Loading Models...")
# Ensure both NLTK resources are present before Rake is used:
#   - punkt_tab: sentence tokenizer data
#   - stopwords: required by rake_nltk's default stopword list
# The original probed only punkt_tab, so a machine that already had punkt_tab
# but lacked stopwords would skip the download and crash later inside Rake().
for _resource, _path in (
    ("punkt_tab", "tokenizers/punkt_tab"),
    ("stopwords", "corpora/stopwords"),
):
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_resource)
# 1. TITLE MODEL (Keep this, or swap to 't5-small' if still desperate)
# Headline-generation T5 checkpoint; use_fast=False forces the slow
# (SentencePiece-based) tokenizer implementation.
t5_tokenizer = AutoTokenizer.from_pretrained("Michau/t5-base-en-generate-headline", use_fast=False)
title_pipe = pipeline("text2text-generation", model="Michau/t5-base-en-generate-headline", tokenizer=t5_tokenizer)
# 2. DESCRIPTION MODEL (>>> CHANGED TO DISTILBART <<<)
# This model is 3x faster and smaller than bart-large-cnn
bart_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6", use_fast=False)
desc_pipe = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", tokenizer=bart_tokenizer)
print("Models Ready!")
class VideoInput(BaseModel):
    """Request body for the generate endpoint."""

    # Free-form input text; generate() rejects it with 400 if blank.
    text: str
def get_tags(text):
    """Extract the top-5 RAKE keyphrases from *text*.

    Args:
        text: Raw input text to mine for keyphrases.

    Returns:
        A ``(hashtags, tags)`` pair: ``tags`` are the top 5 ranked phrases
        with spaces removed, ``hashtags`` are the same strings prefixed
        with ``"#"``.
    """
    rake = Rake()
    rake.extract_keywords_from_text(text)
    # Strip spaces once; the original ran the same replace() comprehension
    # twice over the phrase list. Hashtags are just the "#"-prefixed tags.
    tags = [phrase.replace(" ", "") for phrase in rake.get_ranked_phrases()[:5]]
    hashtags = ["#" + tag for tag in tags]
    return hashtags, tags
def home():
    """Health-check handler: report that the API process is alive.

    NOTE(review): no route decorator is visible in this chunk — presumably
    this was registered via ``@app.get("/")``; confirm against the full file.
    """
    status_payload = {"status": "API is running."}
    return status_payload
async def generate(payload: VideoInput):
    """Produce a title, description, hashtags, and tags for the given text.

    Args:
        payload: Request body carrying the input ``text``.

    Returns:
        Dict with keys ``title``, ``description``, ``hashtags``, ``tags``.

    Raises:
        HTTPException: 400 when ``payload.text`` is empty or whitespace-only.

    NOTE(review): no route decorator is visible in this chunk — presumably
    registered via something like ``@app.post("/generate")``; confirm.
    """
    content = payload.text
    # Guard clause: reject blank input before invoking the models.
    if not content.strip():
        raise HTTPException(status_code=400, detail="Empty text")
    # max_new_tokens kept modest to limit generation latency.
    title_result = title_pipe("headline: " + content, max_new_tokens=50, do_sample=False)
    summary_result = desc_pipe(content, max_new_tokens=100, do_sample=False)
    hashtags, tags = get_tags(content)
    return {
        "title": title_result[0]["generated_text"],
        "description": summary_result[0]["summary_text"],
        "hashtags": hashtags,
        "tags": tags,
    }