import os
import time

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from deep_translator import GoogleTranslator
from sentence_transformers import SentenceTransformer, util
import nlpaug.augmenter.word as naw
from groq import Groq
from dotenv import load_dotenv

# Load environment variables from the .env file (for local development)
load_dotenv()

# Application initialization
app = FastAPI(title="Hybrid Augmentation Pipeline API")

print("Loading Sentence-BERT model (this might take a moment on the first run)...")
# Using a lightweight, multilingual model (supports Polish and English)
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
print("Sentence-BERT model loaded successfully.")

print("Loading HerBERT model for EDA...")
# Using HerBERT for contextual word substitution in the Polish language
eda_aug = naw.ContextualWordEmbsAug(
    model_path='allegro/herbert-base-cased',
    action="substitute",
    device="cpu",  # Switch to "cuda" if appropriate GPU hardware is available
    aug_p=0.1,  # Substitute a maximum of 10% of words in the input sequence
    top_k=10,
)
print("HerBERT model loaded successfully.")

print("Configuring Groq API...")
# Securely fetching the API key from environment variables.
# Fail fast at import time: the LLM augmentation path cannot work without it.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("CRITICAL ERROR: GROQ_API_KEY environment variable is missing!")

groq_client = Groq(api_key=GROQ_API_KEY)
print("Groq API configured successfully.")

# CORS configuration (allows the frontend application to communicate with the backend).
# `allow_methods` takes HTTP verbs and `allow_headers` takes request-header names;
# passing the origin URL there makes the browser's preflight for POST fail.
# The routes in this file only use GET and POST.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://data-augmentation-sigma.vercel.app"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# 1. Data Models (Expected JSON payloads from the frontend client)
class AugmentRequest(BaseModel):
    """Request body accepted by ``POST /augment``."""

    text: str  # Sentence to paraphrase
    method: str  # Augmentation strategy: "EDA", "BT" or "LLM"
    pivot_lang: str = "en"  # Intermediate language for back-translation ("BT")
    eda_p: float = 0.1  # Percentage of words to substitute during EDA


class FilterRequest(BaseModel):
    """Request body accepted by ``POST /filter``."""

    original: str  # Reference sentence
    augmented: str  # Candidate paraphrase to compare against the reference
    threshold: float = 0.8  # Cutoff threshold for the semantic similarity filter


def _augment_eda(text: str, aug_p: float) -> str:
    """Substitute words in *text* using the shared contextual augmenter."""
    # The augmenter is a module-level object, so this assignment stays in
    # effect until the next EDA request overwrites it.
    eda_aug.aug_p = aug_p
    augmented = eda_aug.augment(text)
    # NOTE(review): some nlpaug releases appear to return a bare string rather
    # than a list for single-string input - TODO confirm. Indexing a string
    # would yield only its first character, so handle both shapes.
    if isinstance(augmented, str):
        return augmented
    if not augmented:
        raise ValueError("EDA augmenter returned no output.")
    return augmented[0]


def _augment_back_translation(text: str, pivot_lang: str) -> tuple:
    """Translate *text* PL -> pivot -> PL; return ``(intermediate, result)``."""
    # Step 1: Source (PL) -> Pivot Language
    intermediate_text = GoogleTranslator(source='pl', target=pivot_lang).translate(text)
    # Step 2: Pivot Language -> Source (PL)
    result = GoogleTranslator(source=pivot_lang, target='pl').translate(intermediate_text)
    return intermediate_text, result


def _augment_llm(text: str) -> str:
    """Ask the Groq-hosted chat model for a single Polish paraphrase of *text*."""
    # Built from adjacent literals so the prompt text stays exactly as before,
    # including the trailing space and newline ahead of "Input sentence:".
    prompt = (
        "You are an expert NLP assistant. "
        "Paraphrase the following Polish sentence in Polish. "
        "Maintain the exact same meaning and sentiment, but use a different "
        "syntactic structure or synonyms. "
        "Return ONLY the single paraphrased sentence. "
        "Do not include any introductions, notes, or comments. \n"
        f"Input sentence: {text}"
    )
    chat_completion = groq_client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
        temperature=0.7,
        max_tokens=100,
    )
    # Guard against a missing content field so .strip() cannot fail on None.
    result = (chat_completion.choices[0].message.content or "").strip()
    # Strip enclosing quotation marks if generated by the model. The length
    # check keeps a lone '"' from being treated as an opening/closing pair.
    if len(result) >= 2 and result.startswith('"') and result.endswith('"'):
        result = result[1:-1]
    return result


# 2. Endpoint: Data Augmentation
@app.post("/augment")
async def generate_paraphrase(request: AugmentRequest):
    """Paraphrase ``request.text`` with the strategy named in ``request.method``.

    Returns a dict with ``original``, ``augmented`` and ``method``; the "BT"
    strategy additionally reports ``pivot_lang`` and ``intermediate``.

    Raises:
        HTTPException: 400 for an unknown method or an out-of-range ``eda_p``;
            500 (with the underlying error text) for any other failure.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    try:
        # Simulated network delay for UI feedback. Awaited so the event loop
        # stays free to serve other requests during the pause.
        # NOTE(review): the augmenter, translator and Groq calls below are
        # plain synchronous calls and still occupy the event loop while they run.
        await asyncio.sleep(1)

        if request.method == "EDA":
            # A substitution probability outside [0, 1] is meaningless; reject
            # it before it is written onto the shared augmenter.
            if not 0.0 <= request.eda_p <= 1.0:
                raise HTTPException(
                    status_code=400,
                    detail="eda_p must be between 0.0 and 1.0.",
                )
            result = _augment_eda(request.text, request.eda_p)

        elif request.method == "BT":
            # Retrieve the target pivot language from the request
            pivot_lang = request.pivot_lang.lower()
            intermediate_text, result = _augment_back_translation(request.text, pivot_lang)
            return {
                "original": request.text,
                "augmented": result,
                "method": request.method,
                "pivot_lang": pivot_lang,
                "intermediate": intermediate_text,
            }

        elif request.method == "LLM":
            # LLM Generation (Llama 3 via Groq API)
            result = _augment_llm(request.text)

        else:
            raise HTTPException(status_code=400, detail="Unknown augmentation method requested.")

        return {"original": request.text, "augmented": result, "method": request.method}

    except HTTPException:
        # Deliberate client errors raised above must keep their own status
        # code instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# 3. Endpoint: Semantic Filtration (Sentence-BERT)
@app.post("/filter")
async def filter_sentence(request: FilterRequest):
    """Score the similarity of two sentences and compare it to a threshold.

    Returns a dict with ``similarity`` (rounded to 3 decimals), ``passed``
    (whether the clamped score is >= ``request.threshold``) and ``threshold``.

    Raises:
        HTTPException: 500 (with the underlying error text) on any failure.
    """
    try:
        # Compute dense vector embeddings for both text sequences
        emb1 = model.encode(request.original, convert_to_tensor=True)
        emb2 = model.encode(request.augmented, convert_to_tensor=True)

        # Compute cosine similarity between the embeddings
        cosine_scores = util.cos_sim(emb1, emb2)
        sim_score = cosine_scores[0][0].item()  # Extract the numerical value from the tensor

        # Floating-point precision safeguard (clamp between 0.0 and 1.0)
        sim_score = min(max(sim_score, 0.0), 1.0)

        # Evaluate against the requested threshold
        passed = sim_score >= request.threshold

        return {
            "similarity": round(sim_score, 3),
            "passed": passed,
            "threshold": request.threshold,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# 4. Endpoint: Health Check
@app.get("/")
def read_root():
    """Return a static status payload confirming the backend is up."""
    return {"status": "Backend is running successfully."}