Adive01 commited on
Commit
30653a0
Β·
verified Β·
1 Parent(s): 4325ad2

Upload mlplo/api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/api.py +294 -0
mlplo/api.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ from contextlib import asynccontextmanager
5
+ from typing import List, Tuple
6
+
7
+ import torch
8
+ from fastapi import FastAPI, HTTPException
9
+ from fastapi.staticfiles import StaticFiles
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel
12
+ from transformers import AutoModelForSeq2SeqLM
13
+
14
+ from google import genai
15
+
16
+ from .common import (
17
+ DEFAULT_APP_FALLBACK_MODEL,
18
+ DEFAULT_INPUT_MAX_LENGTH,
19
+ default_device,
20
+ existing_default_checkpoint,
21
+ load_tokenizer,
22
+ normalize_text,
23
+ resolve_model_reference,
24
+ )
25
+
26
+ LOGGER = logging.getLogger(__name__)
27
+
28
+ ml_context = {}
29
+
30
+ # ── Gemini setup ──────────────────────────────────────────────────────────────
31
+ from dotenv import load_dotenv
32
+ load_dotenv()
33
+
34
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
35
+ gemini_client = None
36
+ if GEMINI_API_KEY:
37
+ gemini_client = genai.Client(api_key=GEMINI_API_KEY)
38
+ LOGGER.info("Gemini API key loaded β€” client ready.")
39
+ else:
40
+ LOGGER.warning("GEMINI_API_KEY not set β€” Gemini features will be unavailable.")
41
+
42
+ # ── Chunking constants ────────────────────────────────────────────────────────
43
+ CHUNK_SIZE = 850 # tokens per chunk (well within BART's 1024 limit)
44
+
45
+
46
+ @asynccontextmanager
47
+ async def lifespan(app: FastAPI):
48
+ model_path = existing_default_checkpoint()
49
+ model_reference = resolve_model_reference(model_path, fallback=DEFAULT_APP_FALLBACK_MODEL)
50
+ device = default_device()
51
+
52
+ LOGGER.info(f"Loading BART model from {model_reference}")
53
+ tokenizer = load_tokenizer(model_reference)
54
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_reference)
55
+
56
+ if getattr(model.generation_config, "max_length", None) == 20:
57
+ model.generation_config.max_length = None
58
+
59
+ model.to(device)
60
+ model.eval()
61
+
62
+ ml_context["model"] = model
63
+ ml_context["tokenizer"] = tokenizer
64
+ ml_context["device"] = device
65
+ ml_context["max_input_length"] = DEFAULT_INPUT_MAX_LENGTH
66
+
67
+ yield
68
+ ml_context.clear()
69
+
70
+
71
+ app = FastAPI(title="Prism Studio API", lifespan=lifespan)
72
+
73
+ app.add_middleware(
74
+ CORSMiddleware,
75
+ allow_origins=["*"],
76
+ allow_methods=["*"],
77
+ allow_headers=["*"],
78
+ )
79
+
80
+
81
+ # ── Schemas ───────────────────────────────────────────────────────────────────
82
+ class SummarizeRequest(BaseModel):
83
+ text: str
84
+ engine: str = "bart" # "bart" | "gemini"
85
+ max_new_tokens: int = 128
86
+ min_new_tokens: int = 30
87
+ num_beams: int = 4
88
+ length_penalty: float = 1.5 # >1 encourages longer, more complete summaries
89
+ gemini_model: str = "gemini-3.0-flash"
90
+ polish: bool = False # if True, run Gemini to clean up BART's output
91
+
92
+
93
+ class SummarizeResponse(BaseModel):
94
+ summary: str
95
+ engine_used: str
96
+ chunks_processed: int = 1
97
+
98
+
99
+ # ── Sentence-aware text splitter ──────────────────────────────────────────────
100
+ def _split_sentences(text: str) -> List[str]:
101
+ """Split text into sentences respecting abbreviations."""
102
+ sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text.strip())
103
+ return [s.strip() for s in sentences if s.strip()]
104
+
105
+
106
+ def _build_sentence_chunks(text: str, tokenizer, max_tokens: int) -> List[str]:
107
+ """
108
+ Split text into chunks that respect sentence boundaries.
109
+ Each chunk is at most max_tokens tokens long.
110
+ Returns a list of text strings (not token IDs) β€” one per chunk.
111
+ """
112
+ sentences = _split_sentences(text)
113
+ chunks: List[str] = []
114
+ current_sentences: List[str] = []
115
+ current_len = 0
116
+
117
+ for sent in sentences:
118
+ sent_tokens = len(tokenizer.encode(sent, add_special_tokens=False))
119
+
120
+ # If adding this sentence would exceed the limit, flush current chunk
121
+ if current_len + sent_tokens > max_tokens and current_sentences:
122
+ chunks.append(" ".join(current_sentences))
123
+ # Keep the last sentence for overlap context
124
+ current_sentences = [current_sentences[-1]] if current_sentences else []
125
+ current_len = len(tokenizer.encode(current_sentences[0], add_special_tokens=False)) if current_sentences else 0
126
+
127
+ current_sentences.append(sent)
128
+ current_len += sent_tokens
129
+
130
+ if current_sentences:
131
+ chunks.append(" ".join(current_sentences))
132
+
133
+ return chunks
134
+
135
+
136
+ # ── BART: single-chunk inference ──────────────────────────────────────────────
137
+ def _bart_generate_one(text_chunk: str, request: SummarizeRequest) -> str:
138
+ """Summarise a single text chunk with BART."""
139
+ tokenizer = ml_context["tokenizer"]
140
+ model = ml_context["model"]
141
+ device = ml_context["device"]
142
+
143
+ tokenized = tokenizer(
144
+ text_chunk,
145
+ return_tensors="pt",
146
+ truncation=True,
147
+ max_length=DEFAULT_INPUT_MAX_LENGTH,
148
+ padding=False,
149
+ ).to(device)
150
+
151
+ try:
152
+ with torch.inference_mode():
153
+ # BART generation parameters
154
+ gen_kwargs = {
155
+ "max_new_tokens": request.max_new_tokens,
156
+ "min_length": request.min_new_tokens,
157
+ "length_penalty": request.length_penalty,
158
+ "num_beams": request.num_beams,
159
+ "early_stopping": True,
160
+ "no_repeat_ngram_size": 3,
161
+ "repetition_penalty": 1.5, # Strongly discourage hallucination by phrase reuse
162
+ }
163
+ generated = model.generate(
164
+ **tokenized,
165
+ **gen_kwargs
166
+ )
167
+ except torch.cuda.OutOfMemoryError:
168
+ raise HTTPException(status_code=500, detail="CUDA Out of Memory β€” try a shorter document or fewer beams.")
169
+ except Exception as e:
170
+ raise HTTPException(status_code=500, detail=f"BART generation failed: {e}")
171
+
172
+ return tokenizer.decode(generated[0], skip_special_tokens=True).strip()
173
+
174
+
175
+ # ── BART: hierarchical Map-Reduce ─────────────────────────────────────────────
176
+ def _bart_summarize(text: str, request: SummarizeRequest) -> Tuple[str, int]:
177
+ """
178
+ Sentence-aware Map-Reduce summarisation:
179
+ MAP: Split into sentence-boundary chunks β†’ summarise each
180
+ REDUCE: Summarise the combined chunk summaries
181
+ Returns (final_summary, num_chunks).
182
+ """
183
+ tokenizer = ml_context["tokenizer"]
184
+ total_tokens = len(tokenizer.encode(text, add_special_tokens=False))
185
+
186
+ # ── Single pass β€” text fits in BART's window ──────────────────────────────
187
+ if total_tokens <= CHUNK_SIZE:
188
+ return _bart_generate_one(text, request), 1
189
+
190
+ # ── MAP β€” sentence-aware chunking ─────────────────────────────────────────
191
+ chunks = _build_sentence_chunks(text, tokenizer, CHUNK_SIZE)
192
+ LOGGER.info(f"Chunked document into {len(chunks)} sentence-aware chunks")
193
+
194
+ chunk_summaries: List[str] = []
195
+ for i, chunk in enumerate(chunks):
196
+ LOGGER.info(f"Summarising chunk {i+1}/{len(chunks)}")
197
+ chunk_summaries.append(_bart_generate_one(chunk, request))
198
+
199
+ num_chunks = len(chunk_summaries)
200
+ combined = " ".join(chunk_summaries)
201
+
202
+ # ── REDUCE β€” summarise the combined chunk summaries ───────────────────────
203
+ combined_tokens = len(tokenizer.encode(combined, add_special_tokens=False))
204
+ if combined_tokens <= CHUNK_SIZE:
205
+ # Combined summaries are short enough for one final pass
206
+ final = _bart_generate_one(combined, request)
207
+ else:
208
+ # Recursively reduce (handles extremely long documents)
209
+ final, _ = _bart_summarize(combined, request)
210
+
211
+ return final, num_chunks
212
+
213
+
214
+ # ── Gemini polish (optional post-processing of BART output) ──────────────────
215
+ def _gemini_polish(original_text: str, rough_summary: str, gemini_model: str) -> str:
216
+ """Use Gemini to fact-check and rewrite BART's output based on the original document."""
217
+ if not gemini_client:
218
+ return rough_summary
219
+ prompt = (
220
+ "You are an expert editor. I will provide you with a SOURCE DOCUMENT and a ROUGH SUMMARY generated by a smaller AI.\n\n"
221
+ "Your task is to produce a highly polished, professional, and detailed summary of the SOURCE DOCUMENT.\n"
222
+ "1. Use the ROUGH SUMMARY as a starting point or inspiration.\n"
223
+ "2. If the ROUGH SUMMARY contains hallucinations or makes zero sense, IGNORE IT entirely and write a completely new, accurate summary based ONLY on the SOURCE DOCUMENT.\n"
224
+ "3. Ensure the final output is fluent, detailed, and directly captures the core message of the SOURCE document.\n\n"
225
+ f"SOURCE DOCUMENT:\n{original_text}\n\n"
226
+ f"ROUGH SUMMARY:\n{rough_summary}\n\n"
227
+ "POLISHED SUMMARY:"
228
+ )
229
+ response = gemini_client.models.generate_content(model=gemini_model, contents=prompt)
230
+ return response.text.strip()
231
+
232
+
233
+ # ── Main endpoint ─────────────────────────────────────────────────────────────
234
+ @app.post("/api/summarize", response_model=SummarizeResponse)
235
+ def summarize(request: SummarizeRequest):
236
+ cleaned_text = normalize_text(request.text)
237
+ if not cleaned_text:
238
+ raise HTTPException(status_code=400, detail="Please enter a document to summarize.")
239
+
240
+ # ── Pure Gemini path ─��────────────────────────────────────────────────────
241
+ if request.engine == "gemini":
242
+ if not gemini_client:
243
+ raise HTTPException(status_code=503, detail="Gemini API key not configured.")
244
+ try:
245
+ prompt = (
246
+ "You are an expert summarizer. Produce a concise, accurate, and well-written "
247
+ "summary of the following document. Preserve key facts and conclusions. "
248
+ "Do not add information that is not in the document.\n\n"
249
+ f"DOCUMENT:\n{cleaned_text}\n\nSUMMARY:"
250
+ )
251
+ response = gemini_client.models.generate_content(
252
+ model=request.gemini_model, contents=prompt
253
+ )
254
+ return SummarizeResponse(
255
+ summary=response.text.strip(),
256
+ engine_used=request.gemini_model,
257
+ chunks_processed=1,
258
+ )
259
+ except Exception as e:
260
+ raise HTTPException(status_code=500, detail=f"Gemini API error: {e}")
261
+
262
+ # ── BART path (with optional Gemini polish) ───────────────────────────────
263
+ final_summary, num_chunks = _bart_summarize(cleaned_text, request)
264
+
265
+ # Optional: use Gemini to clean up BART's output
266
+ if request.polish and gemini_client:
267
+ try:
268
+ LOGGER.info("Applying Grounded Gemini polish to BART output...")
269
+ final_summary = _gemini_polish(cleaned_text, final_summary, request.gemini_model)
270
+ engine_label = f"bart-large-xsum + {request.gemini_model} polish"
271
+ except Exception as e:
272
+ LOGGER.error(f"Gemini polish failed: {e}. Falling back to raw BART output.")
273
+ engine_label = f"bart-large-xsum (polish failed: {request.gemini_model})"
274
+ else:
275
+ engine_label = "bart-large-xsum" if num_chunks == 1 else f"bart-large-xsum (Γ—{num_chunks} chunks)"
276
+
277
+ return SummarizeResponse(
278
+ summary=final_summary,
279
+ engine_used=engine_label,
280
+ chunks_processed=num_chunks,
281
+ )
282
+
283
+
284
+ # ── Health check ──────────────────────────────────────────────────────────────
285
+ @app.get("/api/status")
286
+ def status():
287
+ return {
288
+ "bart": "loaded",
289
+ "gemini": "ready" if gemini_client else "no_key",
290
+ "bart_max_tokens": "unlimited (hierarchical chunking)",
291
+ }
292
+
293
+
294
+ app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")