# backend/main.py - revision 8fa78af ("Update backend/main.py", by Jaaccaa)
import asyncio
import os
import time

import nlpaug.augmenter.word as naw
from deep_translator import GoogleTranslator
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from groq import Groq
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
# Read variables from the .env file into the environment (local development)
load_dotenv()

# The FastAPI application object that every route below is registered on
app = FastAPI(title="Hybrid Augmentation Pipeline API")

# --- Sentence-BERT -----------------------------------------------------------
# Lightweight multilingual encoder (supports Polish and English)
print("Loading Sentence-BERT model (this might take a moment on the first run)...")
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
print("Sentence-BERT model loaded successfully.")

# --- HerBERT augmenter (EDA) ---------------------------------------------------
# Contextual word substitution for Polish text
print("Loading HerBERT model for EDA...")
eda_aug = naw.ContextualWordEmbsAug(
    model_path='allegro/herbert-base-cased',
    action="substitute",
    aug_p=0.1,  # substitute at most 10% of the words in the input sequence
    top_k=10,
    device="cpu",  # switch to "cuda" if appropriate GPU hardware is available
)
print("HerBERT model loaded successfully.")

# --- Groq client ---------------------------------------------------------------
# The API key comes from the environment; start-up is aborted when it is absent
print("Configuring Groq API...")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY:
    groq_client = Groq(api_key=GROQ_API_KEY)
else:
    raise ValueError("CRITICAL ERROR: GROQ_API_KEY environment variable is missing!")
print("Groq API configured successfully.")
# CORS configuration (allows the frontend application to communicate with the backend).
# Only the deployed frontend origin is accepted. `allow_methods` and `allow_headers`
# take HTTP method names and request-header names, not origins; "*" allows all of
# them so the browser's preflight for the JSON POST endpoints is not rejected.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://data-augmentation-sigma.vercel.app"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 1. Data Models (Expected JSON payloads from the frontend client)
class AugmentRequest(BaseModel):
    # JSON payload expected by the /augment endpoint.

    # Input text to augment
    text: str
    # Augmentation method; the /augment handler recognises "EDA", "BT" and "LLM"
    method: str
    # Pivot language used for back-translation (read by the "BT" branch)
    pivot_lang: str = "en"
    # Percentage of words to substitute during EDA (read by the "EDA" branch)
    eda_p: float = 0.1
class FilterRequest(BaseModel):
    # JSON payload expected by the /filter endpoint.

    # Reference text
    original: str
    # Candidate text whose similarity to `original` is scored
    augmented: str
    # Cutoff threshold for the semantic similarity filter
    threshold: float = 0.8
# 2. Endpoint: Data Augmentation
@app.post("/augment")
async def generate_paraphrase(request: AugmentRequest):
    """Generate an augmented variant of ``request.text``.

    Supported values of ``request.method``:

    - ``"EDA"``: contextual word substitution with the module-level ``eda_aug``
      augmenter, using ``request.eda_p`` as the substitution probability.
    - ``"BT"``: back-translation, Polish -> ``request.pivot_lang`` -> Polish.
    - ``"LLM"``: paraphrase produced through the Groq chat-completions API.

    Returns:
        A dict with ``original``, ``augmented`` and ``method``. The ``"BT"``
        method additionally returns ``pivot_lang`` and ``intermediate``.

    Raises:
        HTTPException: 400 when ``request.method`` is not recognised, 500 when
            the selected augmentation step raises.
    """
    # Simulated network delay for UI feedback. Awaited (instead of time.sleep)
    # so the event loop is not blocked for other requests during the pause.
    await asyncio.sleep(1)

    try:
        if request.method == "EDA":
            # Dynamically adjust the substitution probability based on client input.
            # NOTE: this mutates the shared module-level augmenter.
            eda_aug.aug_p = request.eda_p
            result = eda_aug.augment(request.text)[0]

        elif request.method == "BT":
            # Retrieve the target pivot language from the request
            pivot_lang = request.pivot_lang.lower()

            # Step 1: Source (PL) -> Pivot Language
            intermediate_text = GoogleTranslator(
                source='pl', target=pivot_lang
            ).translate(request.text)

            # Step 2: Pivot Language -> Source (PL)
            result = GoogleTranslator(
                source=pivot_lang, target='pl'
            ).translate(intermediate_text)

            # Back-translation also reports the intermediate text
            return {
                "original": request.text,
                "augmented": result,
                "method": request.method,
                "pivot_lang": pivot_lang,
                "intermediate": intermediate_text,
            }

        elif request.method == "LLM":
            # LLM Generation (Llama 3 via Groq API)
            prompt = (
                "You are an expert NLP assistant. Paraphrase the following Polish sentence in Polish.\n"
                "Maintain the exact same meaning and sentiment, but use a different syntactic structure or synonyms.\n"
                "Return ONLY the single paraphrased sentence. Do not include any introductions, notes, or comments.\n"
                f"Input sentence: {request.text}"
            )
            chat_completion = groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama-3.3-70b-versatile",
                temperature=0.7,
                max_tokens=100,
            )
            result = chat_completion.choices[0].message.content.strip()

            # Strip enclosing quotation marks if generated by the model
            if result.startswith('"') and result.endswith('"'):
                result = result[1:-1]

        else:
            raise HTTPException(status_code=400, detail="Unknown augmentation method requested.")

        return {"original": request.text, "augmented": result, "method": request.method}

    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# 3. Endpoint: Semantic Filtration (Sentence-BERT)
@app.post("/filter")
async def filter_sentence(request: FilterRequest):
    # Scores the similarity of request.original and request.augmented with the
    # module-level Sentence-BERT model and reports whether the score reaches
    # request.threshold. Any exception is returned to the client as HTTP 500.
    try:
        # Dense vector embeddings, one per text
        original_vec = model.encode(request.original, convert_to_tensor=True)
        augmented_vec = model.encode(request.augmented, convert_to_tensor=True)

        # Cosine similarity; [0][0] picks the single score and .item() converts
        # it from a tensor element to a plain number
        similarity = util.cos_sim(original_vec, augmented_vec)[0][0].item()

        # Keep the score inside the 0.0 - 1.0 range
        if similarity < 0.0:
            similarity = 0.0
        elif similarity > 1.0:
            similarity = 1.0

        # Compare against the cutoff requested by the client
        cutoff = request.threshold
        meets_cutoff = similarity >= cutoff

        return {
            "similarity": round(similarity, 3),
            "passed": meets_cutoff,
            "threshold": cutoff,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# 4. Endpoint: Health Check
@app.get("/")
def read_root():
    # Health check: always answers with the same fixed status message.
    status_payload = {"status": "Backend is running successfully."}
    return status_payload