Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from deep_translator import GoogleTranslator | |
| from sentence_transformers import SentenceTransformer, util | |
| import nlpaug.augmenter.word as naw | |
| from groq import Groq | |
| from dotenv import load_dotenv | |
# Read settings from a local .env file, if one exists (local development).
load_dotenv()

# The FastAPI application object that the rest of this module configures.
app = FastAPI(title="Hybrid Augmentation Pipeline API")

# --- Sentence-BERT: embedding model used for semantic similarity ---
print("Loading Sentence-BERT model (this might take a moment on the first run)...")
# Lightweight multilingual checkpoint (covers Polish and English).
model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
print("Sentence-BERT model loaded successfully.")

# --- HerBERT: contextual word substitution for Polish (EDA) ---
print("Loading HerBERT model for EDA...")
eda_aug = naw.ContextualWordEmbsAug(
    model_path="allegro/herbert-base-cased",
    action="substitute",
    device="cpu",  # change to "cuda" when suitable GPU hardware is available
    aug_p=0.1,  # default share of words to substitute (at most 10%)
    top_k=10,
)
print("HerBERT model loaded successfully.")

# --- Groq client: the key comes from the environment, never from source ---
print("Configuring Groq API...")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Abort start-up: the client below cannot be built without a key.
    raise ValueError("CRITICAL ERROR: GROQ_API_KEY environment variable is missing!")
groq_client = Groq(api_key=GROQ_API_KEY)
print("Groq API configured successfully.")
| # CORS configuration (allows the frontend application to communicate with the backend) | |
# Only the deployed frontend origin may call this API from a browser.
# `allow_methods` / `allow_headers` take HTTP method and header names, not
# origins: passing the origin URL there means no real method ever matches, so
# every CORS preflight is rejected. "*" lets the single trusted origin use any
# method and request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://data-augmentation-sigma.vercel.app"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # 1. Data Models (Expected JSON payloads from the frontend client) | |
class AugmentRequest(BaseModel):
    # JSON payload accepted by the augmentation handler.

    # Sentence to be paraphrased.
    text: str
    # Augmentation strategy; the handler recognises "EDA", "BT" and "LLM".
    method: str
    # Intermediate language for back-translation (read only by the "BT" branch).
    pivot_lang: str = "en"
    # Share of words to substitute during EDA (read only by the "EDA" branch).
    eda_p: float = 0.1
class FilterRequest(BaseModel):
    # JSON payload accepted by the semantic-filter handler.

    # Source sentence.
    original: str
    # Candidate paraphrase to compare against the source sentence.
    augmented: str
    # Minimum cosine similarity for the candidate to pass the filter.
    threshold: float = 0.8
| # 2. Endpoint: Data Augmentation | |
# NOTE(review): no route decorator is visible above this handler in this view;
# confirm that it is registered on `app`.
async def generate_paraphrase(request: AugmentRequest):
    """Paraphrase ``request.text`` with the augmentation method the client chose.

    Supported ``request.method`` values:
      * ``"EDA"`` - contextual word substitution through the module-level
        ``eda_aug`` augmenter, using ``request.eda_p`` as substitution ratio.
      * ``"BT"``  - back-translation: Polish -> ``request.pivot_lang`` -> Polish.
      * ``"LLM"`` - paraphrase generated through the Groq chat-completion API.

    Returns:
        A dict with ``"original"``, ``"augmented"`` and ``"method"``. The
        ``"BT"`` branch additionally returns ``"pivot_lang"`` and
        ``"intermediate"`` (the pivot-language text).

    Raises:
        HTTPException: status 400 when ``request.method`` is not recognised;
            status 500 (detail = the error text) when augmentation fails.
    """
    # Imported here so the handler does not depend on a module-level import.
    import asyncio

    try:
        # Simulated network delay for UI feedback. ``asyncio.sleep`` yields to
        # the event loop; ``time.sleep`` would stall every other request.
        await asyncio.sleep(1)

        if request.method == "EDA":
            # Apply the client's substitution ratio. This mutates the shared
            # module-level augmenter, so the value persists across requests.
            eda_aug.aug_p = request.eda_p
            result = eda_aug.augment(request.text)[0]

        elif request.method == "BT":
            pivot_lang = request.pivot_lang.lower()
            # Step 1: source (PL) -> pivot language
            intermediate_text = GoogleTranslator(
                source="pl", target=pivot_lang
            ).translate(request.text)
            # Step 2: pivot language -> source (PL)
            result = GoogleTranslator(
                source=pivot_lang, target="pl"
            ).translate(intermediate_text)
            return {
                "original": request.text,
                "augmented": result,
                "method": request.method,
                "pivot_lang": pivot_lang,
                "intermediate": intermediate_text,
            }

        elif request.method == "LLM":
            # LLM generation (Llama 3 via the Groq API)
            prompt = (
                "You are an expert NLP assistant. Paraphrase the following Polish sentence in Polish.\n"
                "Maintain the exact same meaning and sentiment, but use a different syntactic structure or synonyms.\n"
                "Return ONLY the single paraphrased sentence. Do not include any introductions, notes, or comments.\n"
                f"Input sentence: {request.text}"
            )
            chat_completion = groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama-3.3-70b-versatile",
                temperature=0.7,
                max_tokens=100,
            )
            result = chat_completion.choices[0].message.content.strip()
            # Drop enclosing quotation marks if the model added them.
            if result.startswith('"') and result.endswith('"'):
                result = result[1:-1]

        else:
            raise HTTPException(
                status_code=400, detail="Unknown augmentation method requested."
            )

        return {"original": request.text, "augmented": result, "method": request.method}
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
| # 3. Endpoint: Semantic Filtration (Sentence-BERT) | |
async def filter_sentence(request: FilterRequest):
    """Score how close ``request.augmented`` is to ``request.original``.

    Both sentences are embedded with the module-level Sentence-BERT ``model``
    and compared by cosine similarity; the score is limited to [0.0, 1.0]
    before being checked against ``request.threshold``.

    Returns:
        A dict with ``"similarity"`` (rounded to 3 decimals), ``"passed"``
        (score >= threshold) and ``"threshold"`` (echoed from the request).

    Raises:
        HTTPException: status 500 (detail = the error text) on any failure.
    """
    try:
        # Dense embeddings for the two sentences.
        original_vec = model.encode(request.original, convert_to_tensor=True)
        augmented_vec = model.encode(request.augmented, convert_to_tensor=True)

        # cos_sim yields a tensor; take the single score as a plain float.
        similarity = util.cos_sim(original_vec, augmented_vec)[0][0].item()

        # Keep the score inside [0.0, 1.0].
        if similarity < 0.0:
            similarity = 0.0
        elif similarity > 1.0:
            similarity = 1.0

        return {
            "similarity": round(similarity, 3),
            "passed": similarity >= request.threshold,
            "threshold": request.threshold,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
| # 4. Endpoint: Health Check | |
def read_root():
    """Return a fixed status payload indicating the backend is running."""
    status_payload = {"status": "Backend is running successfully."}
    return status_payload