Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import gradio as gr | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import torch | |
| import os | |
| import re | |
# --- Model setup ---------------------------------------------------------
# FLAN-T5 base checkpoint used as the translation quality-estimation scorer.
model_name = "google/flan-t5-base"

# Access token read from the environment (configured as a secret in the
# Hugging Face Space settings); may be None for this public checkpoint.
hf_token = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=hf_token)

# Prefer a GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
def generate_prompt(original, translation):
    """Build the instruction prompt shown to FLAN-T5 for one sentence pair.

    The prompt asks the model to rate the translation quality on a 0-1
    scale and to answer with a bare number, ending in "Score:" so the
    model completes with the score itself.
    """
    parts = [
        "Rate the quality of this translation from 0 (poor) to 1 (excellent). "
        "Only respond with a number.",
        "",
        f"Source: {original}",
        f"Translation: {translation}",
        "Score:",
    ]
    return "\n".join(parts)
# Main prediction function
def _parse_score(text):
    """Extract a numeric quality score from the model's raw text output.

    Returns the first 0/1-style number in *text* as a float clamped to
    [0, 1], or -1.0 when no score-like number is found (the sentinel the
    UI surfaces for unparseable model output).
    """
    match = re.search(r"\b([01](?:\.\d+)?)\b", text)
    if not match:
        return -1.0  # fallback if model output is invalid
    return max(0.0, min(float(match.group(1)), 1.0))  # clamp to [0, 1]


def predict_scores(file):
    """Score every row of an uploaded TSV with the FLAN-T5 QE prompt.

    Parameters
    ----------
    file : gradio upload object (has a ``.name`` filepath) or a plain path
        TSV file with at least the columns ``original`` and ``translation``.

    Returns
    -------
    pandas.DataFrame
        The input frame with an added ``predicted_score`` column: floats
        in [0, 1], or -1.0 where the model output could not be parsed.

    Raises
    ------
    ValueError
        If the uploaded file lacks a required column.
    """
    # gr.File yields an object with a .name path on older gradio versions
    # and a plain filepath string on newer ones; accept both.
    path = getattr(file, "name", file)
    df = pd.read_csv(path, sep="\t")

    # Fail fast with a clear message instead of a bare KeyError mid-loop.
    missing = {"original", "translation"} - set(df.columns)
    if missing:
        raise ValueError(
            f"Input file is missing required column(s): {sorted(missing)}"
        )

    scores = []
    for _, row in df.iterrows():
        prompt = generate_prompt(row["original"], row["translation"])
        # Tokenize and send to model (generate() runs under no_grad internally)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=10)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        scores.append(_parse_score(response))

    df["predicted_score"] = scores
    return df
# Gradio UI
_TITLE = "MT QE with FLAN-T5-Base"
_DESCRIPTION = "Upload a dev.tsv file with columns: 'original' and 'translation'."

iface = gr.Interface(
    fn=predict_scores,
    inputs=gr.File(label="Upload dev.tsv"),
    outputs=gr.Dataframe(label="QE Output with Predicted Score"),
    title=_TITLE,
    description=_DESCRIPTION,
)

# Launch app
iface.launch()