Spaces:
Sleeping
Sleeping
import os
import re
import shlex
import subprocess
import sys
import tempfile
import uuid
from io import BytesIO
from pathlib import Path

import docx
import stripe
from docx.enum.text import WD_COLOR_INDEX
from dotenv import load_dotenv
from fastapi import (
    APIRouter,
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from huggingface_hub import HfApi
from pydantic import BaseModel
from PyPDF2 import PdfReader
from supabase import Client, create_client

from Ai_rewriter.rewriter_fixed import rewrite_text
load_dotenv()

# === CONFIG ===
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_DATASET_REPO = "AlyanAkram/StealthReports"
# Browsers never send a trailing slash in the Origin header, so the Vercel
# origin must be listed without one or CORS preflight would always fail.
CORS_ORIGINS = ["http://localhost:5173", "https://stealth-writer.vercel.app"]

stripe.api_key = os.getenv("STRIPE_SECRET_KEY")
# Log only whether the key is present — never print the secret itself, and
# never crash at import time when it is missing (len(None) raised TypeError).
if stripe.api_key:
    print(f"Stripe key loaded ({len(stripe.api_key)} chars)")
else:
    print("WARNING: STRIPE_SECRET_KEY is not set; payment endpoints will fail")

# Supabase client is created lazily at startup in load_model().
supabase: Client | None = None

# Stripe price IDs for each paid subscription tier.
PRICE_MAP = {
    "basic": "price_1RyxK4KiaPeHFPzzwBG5C5Rf",
    "premium": "price_1RyxKBKiaPeHFPzz5oDy6m2c",
}

# === FastAPI app setup ===
app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load model on startup ===
# Placeholders filled in by load_model(); endpoints must not be hit before then.
analyze_text = None
generate_pdf_report = None
async def load_model():
    """Import the detector model lazily and create the Supabase client.

    Deferring the detector import keeps module import cheap so the Space can
    boot quickly; the heavy model load happens here instead.
    NOTE(review): no @app.on_event("startup") decorator is visible in this
    chunk — confirm this coroutine is actually registered as a startup hook.
    """
    global analyze_text, generate_pdf_report, supabase
    from detector.custom_model import analyze_text as at, generate_pdf_report as gpr
    analyze_text = at
    generate_pdf_report = gpr
    # assumes SUPABASE_URL / SUPABASE_KEY are set — create_client will fail otherwise
    supabase = create_client(
        os.getenv("SUPABASE_URL"),
        os.getenv("SUPABASE_KEY")
    )
# === Utils ===
def extract_text(file: UploadFile, ext: str) -> str:
    """Return the plain-text contents of an uploaded file.

    Supports .txt (decoded UTF-8, errors ignored), .pdf (page texts
    concatenated) and .docx (paragraphs joined by newlines).
    Raises ValueError for any other extension.
    """
    raw = file.file.read()
    if ext == ".txt":
        return raw.decode("utf-8", errors="ignore")
    buffer = BytesIO(raw)
    if ext == ".pdf":
        pages = PdfReader(buffer).pages
        return "".join(page.extract_text() or "" for page in pages)
    if ext == ".docx":
        paragraphs = docx.Document(buffer).paragraphs
        return "\n".join(p.text for p in paragraphs)
    raise ValueError("Unsupported file type")
# Anything outside word chars, hyphen, underscore and dot gets replaced.
_UNSAFE_CHARS = re.compile(r"[^\w\-_.]")


def sanitize_filename(name):
    """Return *name* with every filesystem-unsafe character replaced by '_'."""
    return _UNSAFE_CHARS.sub("_", name)
def upload_to_dataset(path: str, content: BytesIO, token: str) -> str:
    """Upload an in-memory file to the HF dataset repo and return its URL.

    The returned URL is the public ``resolve/main`` link for the uploaded
    path inside HF_DATASET_REPO.
    """
    HfApi().upload_file(
        path_or_fileobj=content,
        path_in_repo=path,
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",
        token=token,
    )
    return f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{path}"
# === Main endpoint ===
async def detect(file: UploadFile = File(...)):
    """Analyze an uploaded document for AI-generated text.

    Extracts text from a .txt/.pdf/.docx upload, scores it with the model
    installed by load_model(), builds a highlighted DOCX summary plus a PDF
    report, uploads both to the HF dataset, and returns scores and file URLs.
    Any failure is returned as a JSON 500 with the error message.

    NOTE(review): no route decorator is visible in this chunk — confirm this
    is registered (e.g. via app.post) elsewhere, otherwise it is unreachable.
    """
    try:
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in [".txt", ".pdf", ".docx"]:
            raise ValueError("Unsupported file format")
        text = extract_text(file, ext)
        # analyze_text is populated at startup by load_model(); calling this
        # before startup completes would raise TypeError (None is not callable).
        result = analyze_text(text)
        # Unique, filesystem-safe base name for the uploaded report files.
        filename_base = sanitize_filename(os.path.splitext(file.filename)[0]) + "_" + str(uuid.uuid4())[:8]
        # Build the DOCX summary in memory: headline stats, then every
        # sentence, with AI-flagged sentences highlighted in turquoise.
        docx_buffer = BytesIO()
        doc = docx.Document()
        doc.add_heading("AI Detection Summary", level=1)
        doc.add_paragraph(f"Overall AI %: {result['overall_ai_percent']}%")
        doc.add_paragraph(f"Total Sentences: {result['total_sentences']}")
        doc.add_paragraph(f"AI Sentences: {result['ai_sentences']}")
        doc.add_paragraph("Sentences detected as AI are highlighted in cyan.\n")
        doc.add_heading("Sentence Analysis", level=2)
        paragraph = doc.add_paragraph()
        # result["results"] is a list of paragraphs, each a list of
        # (sentence, is_ai, ai_score) tuples — see the response shaping below.
        for para in result["results"]:
            for sentence, is_ai, _ in para:
                # Skip non-string or whitespace-only entries defensively.
                if not isinstance(sentence, str) or not sentence.strip():
                    continue
                run = paragraph.add_run(sentence + " ")
                if is_ai:
                    run.font.highlight_color = WD_COLOR_INDEX.TURQUOISE
        doc.save(docx_buffer)
        docx_buffer.seek(0)
        pdf_buffer = generate_pdf_report(result, filename_base)
        # Publish both reports to the HF dataset and return their links.
        docx_url = upload_to_dataset(f"{filename_base}.docx", docx_buffer, HF_TOKEN)
        pdf_url = upload_to_dataset(f"{filename_base}.pdf", pdf_buffer, HF_TOKEN)
        return {
            "success": True,
            "score": {
                # Copy all scalar stats, replacing "results" with a
                # JSON-friendly list of dicts (ai_score scaled to percent).
                **{k: v for k, v in result.items() if k != "results"},
                "results": [
                    [{"sentence": s, "is_ai": is_ai, "ai_score": round(ai_score * 100, 2)} for s, is_ai, ai_score in para]
                    for para in result["results"]
                ]
            },
            "docx_url": docx_url,
            "pdf_url": pdf_url
        }
    except Exception as e:
        # Broad catch is deliberate: every failure becomes a JSON 500 payload.
        return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
# Payment endpoints are grouped under /api/payments.
router = APIRouter(prefix="/api/payments", tags=["payments"])
class CheckoutReq(BaseModel):
    """Request body for creating a Stripe Checkout session."""
    plan: str  # key into PRICE_MAP ("basic" or "premium")
    success_url: str  # redirect target after successful payment
    cancel_url: str  # redirect target when the user cancels
    email: str  # pre-filled customer email for Stripe
    user_id: str  # Supabase user.id, stored as client_reference_id
async def create_session(data: CheckoutReq):
    """Create a Stripe Checkout session for the requested subscription plan.

    Returns the hosted checkout URL; raises 400 for an unknown plan name.
    """
    try:
        price_id = PRICE_MAP[data.plan]
    except KeyError:
        raise HTTPException(400, "Unknown plan")
    session = stripe.checkout.Session.create(
        mode="subscription",
        payment_method_types=["card"],
        line_items=[{"price": price_id, "quantity": 1}],
        customer_email=data.email,
        client_reference_id=data.user_id,  # Supabase user.id, read back in the webhook
        success_url=data.success_url + "?session_id={CHECKOUT_SESSION_ID}",
        cancel_url=data.cancel_url,
    )
    return {"url": session.url}
def _plan_for_price(price_id):
    """Reverse-map a Stripe price id to our plan name via PRICE_MAP, or None."""
    return next((k for k, v in PRICE_MAP.items() if v == price_id), None)


def _update_user_plan(user_id, plan, label):
    """Persist *plan* for *user_id* through the Supabase update_user_plan RPC.

    Does nothing when user_id/plan are missing or the Supabase client has not
    been initialised yet, so a malformed event cannot crash the webhook.
    *label* tags the log lines with which event path triggered the update.
    """
    if not (plan and user_id and supabase):
        return
    print(f"{label}: updating plan to {plan} for Supabase user {user_id}")
    response = supabase.rpc(
        "update_user_plan", {"uid": user_id, "new_plan": plan}
    ).execute()
    print(f"Supabase update response ({label}):", response)


async def stripe_webhook(request: Request):
    """Handle Stripe webhook events and sync subscription plans to Supabase.

    Handled events:
      - checkout.session.completed: first purchase, stores the chosen plan
      - invoice.payment_succeeded: renewals, re-asserts the plan
      - customer.subscription.deleted: downgrade back to the free plan
    Returns 200 for handled/ignored events so Stripe stops retrying; returns
    400 only when the signature check fails.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")
    try:
        event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
    except stripe.error.SignatureVerificationError:
        return JSONResponse(status_code=400, content={"error": "Invalid signature"})

    event_type = event["type"]
    obj = event["data"]["object"]

    if event_type == "checkout.session.completed":
        # First-time checkout: the session carries our Supabase user id directly.
        try:
            subscription = stripe.Subscription.retrieve(obj.get("subscription"))
            plan = _plan_for_price(subscription["items"]["data"][0]["price"]["id"])
            _update_user_plan(obj.get("client_reference_id"), plan, "checkout")
        except Exception as e:
            print("Webhook error while updating Supabase (checkout):", str(e))

    elif event_type == "invoice.payment_succeeded":
        # Subscription renewal (monthly/annual billing): recover the user id
        # from the subscription's metadata (fallback kept from original code).
        try:
            subscription = stripe.Subscription.retrieve(obj.get("subscription"))
            user_id = subscription.get("metadata", {}).get("user_id") \
                or subscription.get("client_reference_id")
            plan = _plan_for_price(subscription["items"]["data"][0]["price"]["id"])
            _update_user_plan(user_id, plan, "renewal")
        except Exception as e:
            print("Webhook error while updating Supabase (renewal):", str(e))

    elif event_type == "customer.subscription.deleted":
        # Cancellation/expiry: drop the user back to the free tier.
        try:
            _update_user_plan(obj.get("metadata", {}).get("user_id"), "free", "canceled")
        except Exception as e:
            print("Webhook error while downgrading Supabase:", str(e))

    return {"status": "success"}
# Use environment variable if running on Hugging Face Spaces
output_dir_env = os.environ.get("REWRITTEN_OUTPUTS_DIR", "rewritten_outputs")
# NOTE(review): OUTPUT_DIR is created but not referenced anywhere in this
# file — confirm it is still needed before removing.
OUTPUT_DIR = Path(output_dir_env)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Define the TextInput model
class TextInput(BaseModel):
    """Raw-text request body. NOTE(review): appears unused in this file."""
    text: str
async def extract_text_from_file(file: UploadFile) -> str:
    """Validate the upload's extension, then delegate to extract_text.

    Raises ValueError for anything other than .txt, .pdf or .docx.
    """
    ext = os.path.splitext(file.filename)[1].lower()
    if ext not in (".txt", ".pdf", ".docx"):
        raise ValueError("Unsupported file format")
    return extract_text(file, ext)
class RewriteRequest(BaseModel):
    """Request body for the text-based rewrite endpoint: the raw text to rewrite."""
    text: str
def _text_to_docx(text: str) -> BytesIO:
    """Render plain text into an in-memory single-paragraph DOCX file."""
    buf = BytesIO()
    document = docx.Document()
    document.add_paragraph(text)
    document.save(buf)
    buf.seek(0)
    return buf


def _text_to_pdf(text: str) -> BytesIO:
    """Render plain text into an in-memory A4 PDF, one text line per line."""
    from reportlab.lib.pagesizes import A4
    from reportlab.lib.units import inch
    from reportlab.pdfgen import canvas

    buf = BytesIO()
    c = canvas.Canvas(buf, pagesize=A4)
    _, height = A4
    text_object = c.beginText(0.5 * inch, height - 0.5 * inch)
    text_object.setFont("Times-Roman", 12)
    for line in text.split("\n"):
        text_object.textLine(line)
    c.drawText(text_object)
    c.showPage()
    c.save()
    buf.seek(0)
    return buf


async def rewrite_endpoint(file: UploadFile = File(...)):
    """Rewrite an uploaded document via the CLI rewriter and upload results.

    Accepts .txt/.pdf/.docx, runs Ai_rewriter/rewriter_fixed.py on a temp
    copy, converts the rewritten text into DOCX and PDF, uploads both to the
    HF dataset and returns their URLs. All failures surface as HTTPException.

    NOTE(review): no route decorator is visible here — confirm registration.
    """
    temp_input = None
    temp_output = None
    try:
        ext = os.path.splitext(file.filename)[1].lower()
        if ext not in [".txt", ".pdf", ".docx"]:
            raise ValueError("Unsupported file format")

        # 1. Save the upload to a temp file the CLI rewriter can read.
        temp_input = Path(tempfile.gettempdir()) / f"input_{uuid.uuid4().hex}{ext}"
        temp_input.write_bytes(await file.read())

        # 2. Temp output path (same extension as the input).
        temp_output = Path(tempfile.gettempdir()) / f"rewritten_{uuid.uuid4().hex}{ext}"

        # 3. Run the rewriter as a subprocess. sys.executable guarantees the
        #    same interpreter/venv as this server (bare "python" may differ).
        subprocess.run(
            [sys.executable, "-X", "utf8", "Ai_rewriter/rewriter_fixed.py",
             str(temp_input), str(temp_output)],
            check=True,
        )
        if not temp_output.exists():
            raise HTTPException(status_code=500, detail="Rewriter did not produce an output file.")

        # 4. Read the rewritten text back, dispatching on the input format.
        if ext == ".txt":
            rewritten_text = temp_output.read_text(encoding="utf-8")
        elif ext == ".docx":
            rewritten_text = "\n".join(p.text for p in docx.Document(temp_output).paragraphs)
        else:  # ".pdf" — the only remaining allowed extension
            reader = PdfReader(str(temp_output))
            rewritten_text = "".join(page.extract_text() or "" for page in reader.pages)

        docx_buffer = _text_to_docx(rewritten_text)
        pdf_buffer = _text_to_pdf(rewritten_text)

        # 5. Upload both renditions to the HF dataset, mirroring /detect.
        filename_base = sanitize_filename(os.path.splitext(file.filename)[0]) + "_" + str(uuid.uuid4())[:8]
        docx_url = upload_to_dataset(f"{filename_base}.docx", docx_buffer, HF_TOKEN)
        pdf_url = upload_to_dataset(f"{filename_base}.pdf", pdf_buffer, HF_TOKEN)
        return {
            "success": True,
            "docx_url": docx_url,
            "pdf_url": pdf_url
        }
    except HTTPException:
        raise  # keep the intended status/detail instead of re-wrapping below
    except subprocess.CalledProcessError as e:
        raise HTTPException(status_code=500, detail=f"Rewriter process failed: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # 6. Always remove temp files — previously they leaked on every
        #    exception path because cleanup only ran on success.
        if temp_input is not None:
            temp_input.unlink(missing_ok=True)
        if temp_output is not None:
            temp_output.unlink(missing_ok=True)
async def rewrite_text(req: RewriteRequest):
    """Rewrite raw text by shelling out to the CLI rewriter script.

    The text is passed as a command-line argument and the rewritten result is
    read back from a temp output file, which is always cleaned up.

    NOTE(review): this def shadows the rewrite_text imported from
    Ai_rewriter.rewriter_fixed at the top of the file — confirm that import
    is genuinely unused before relying on either name.
    """
    temp_in_path = None
    temp_out_path = None
    try:
        # The temp input file exists only to derive a unique output filename;
        # the text itself is passed to the CLI as an argument below.
        temp_in = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8")
        temp_in.write(req.text)
        temp_in.close()
        temp_in_path = temp_in.name
        temp_out_path = Path(tempfile.gettempdir()) / f"rewritten_{os.path.basename(temp_in_path)}"

        # Same interpreter/venv as this server, forced into UTF-8 mode.
        cmd = [
            sys.executable, "-X", "utf8",
            "Ai_rewriter/rewriter_fixed.py",
            req.text,  # raw text passed directly as argv
            str(temp_out_path),
        ]
        print("Running:", shlex.join(cmd))
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            shell=False,  # argv list: user text is never shell-interpolated
        )
        print("STDOUT:\n", result.stdout)
        print("STDERR:\n", result.stderr)

        if result.returncode != 0:
            raise HTTPException(
                status_code=500,
                detail=f"Model process failed: {result.stderr or 'Unknown error'}"
            )
        if not temp_out_path.exists():
            raise HTTPException(status_code=500, detail="Output file not found.")

        rewritten_text = temp_out_path.read_text(encoding="utf-8")
        return {"rewritten_text": rewritten_text}
    except HTTPException:
        raise  # preserve the intended detail instead of re-wrapping via str(e)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Cleanup safely on both success and error paths.
        if temp_in_path and os.path.exists(temp_in_path):
            os.remove(temp_in_path)
        if temp_out_path and os.path.exists(temp_out_path):
            os.remove(temp_out_path)
| app.include_router(router) |