Spaces:
Paused
Paused
tt
Browse files
- Dockerfile +18 -15
- app.py +41 -57
- requirements.txt +1 -2
Dockerfile
CHANGED
|
@@ -1,36 +1,39 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
ENV DEBIAN_FRONTEND=noninteractive
|
| 4 |
ENV OMP_NUM_THREADS=4
|
| 5 |
ENV DISABLE_TRITON=1
|
| 6 |
ENV ACCELERATE_USE_DEEPSPEED=0
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
# Install
|
| 9 |
RUN apt-get update && apt-get install -y \
|
| 10 |
-
git wget curl
|
| 11 |
-
libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev \
|
| 12 |
ffmpeg libsm6 libxext6 libgl1-mesa-glx \
|
| 13 |
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
|
| 17 |
RUN pip install --upgrade pip
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu121
|
| 21 |
-
|
| 22 |
-
# ---- 2. Install requirements (without flash-attn) ----
|
| 23 |
WORKDIR /app
|
| 24 |
COPY requirements.txt /app/requirements.txt
|
| 25 |
-
RUN
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
RUN pip install --no-build-isolation flash-attn
|
| 29 |
|
| 30 |
-
# Copy
|
| 31 |
COPY . /app
|
| 32 |
|
|
|
|
| 33 |
EXPOSE 7860
|
| 34 |
|
| 35 |
-
#
|
| 36 |
CMD ["python", "app.py"]
|
|
|
|
| 1 |
+
# Image for the Gradio app in app.py.
# Base: official PyTorch 2.3.0 devel image with CUDA 12.1 / cuDNN 8. The
# devel variant is used so flash-attn can be compiled against the bundled
# torch if no matching prebuilt wheel is available.
FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-devel

# Keep apt from prompting during the build.
ENV DEBIAN_FRONTEND=noninteractive
ENV OMP_NUM_THREADS=4
ENV DISABLE_TRITON=1
ENV ACCELERATE_USE_DEEPSPEED=0
ENV TRANSFORMERS_VERBOSITY=info
ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# NOTE(review): FLASH_ATTENTION_FORCE is not one of the build switches that
# flash-attn documents (FLASH_ATTENTION_FORCE_BUILD, FLASH_ATTENTION_SKIP_CUDA_BUILD).
# Kept as-is in case app.py reads it - confirm, otherwise drop it.
ENV FLASH_ATTENTION_FORCE=1

# Install system dependencies
RUN apt-get update && apt-get install -y \
    git wget curl build-essential python3-dev \
    ffmpeg libsm6 libxext6 libgl1-mesa-glx \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip first (no cache, consistent with the other pip layers)
RUN pip install --no-cache-dir --upgrade pip

# Copy requirements and strip any flash-attn entry: it has to be installed
# separately, after torch and the remaining dependencies are in place.
# The pattern is case-insensitive and accepts both "flash-attn" and
# "flash_attn", which pip treats as the same project name.
WORKDIR /app
COPY requirements.txt /app/requirements.txt
RUN grep -viE "flash[-_]attn" requirements.txt > requirements-clean.txt

# Install all Python deps except flash-attn
RUN pip install --no-cache-dir -r requirements-clean.txt

# Install flash-attn last so the build sees the already-installed torch
# (--no-build-isolation makes pip build against the current environment).
# packaging and ninja are the build prerequisites listed by the flash-attn
# README; MAX_JOBS caps the number of parallel compile jobs so a source build
# does not exhaust the builder's RAM. Override with --build-arg MAX_JOBS=N.
ARG MAX_JOBS=4
RUN pip install --no-cache-dir packaging ninja \
    && MAX_JOBS="${MAX_JOBS}" pip install --no-cache-dir --no-build-isolation flash-attn==2.8.2

# Copy application
COPY . /app

# Expose Gradio
EXPOSE 7860

# Default command to launch the app
CMD ["python", "app.py"]
|
app.py
CHANGED
|
@@ -275,10 +275,7 @@ def get_data_status():
|
|
| 275 |
"""Get data download status"""
|
| 276 |
return f"{data_download_status['message']}"
|
| 277 |
|
| 278 |
-
|
| 279 |
def run_inference(query, document_title, document_content, checkpoint="latest"):
|
| 280 |
-
import torch
|
| 281 |
-
|
| 282 |
global current_model, current_tokenizer
|
| 283 |
|
| 284 |
# Load the model if not already loaded
|
|
@@ -295,49 +292,32 @@ def run_inference(query, document_title, document_content, checkpoint="latest"):
|
|
| 295 |
else:
|
| 296 |
load_model_and_tokenizer(checkpoint)
|
| 297 |
|
| 298 |
-
# Prepare prompt
|
| 299 |
-
prompt =
|
| 300 |
-
|
| 301 |
-
Query:
|
| 302 |
-
{query}
|
| 303 |
-
|
| 304 |
-
Document:
|
| 305 |
-
title: {document_title}
|
| 306 |
-
content: {document_content}
|
| 307 |
-
"""
|
| 308 |
-
|
| 309 |
-
# Helper function to score log-probability
|
| 310 |
-
def score_response(model, tokenizer, prompt, response):
|
| 311 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 312 |
-
labels = tokenizer(response, return_tensors="pt").to(model.device)
|
| 313 |
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
seq_logprob += log_probs[0, inputs.input_ids.shape[1] + i - 1, token_id].item()
|
| 331 |
-
count += 1
|
| 332 |
-
|
| 333 |
-
return seq_logprob / max(count, 1)
|
| 334 |
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
score_irrelevant = score_response(current_model, current_tokenizer, prompt, "Irrelevant")
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
|
|
|
| 341 |
|
| 342 |
|
| 343 |
def list_checkpoints():
|
|
@@ -432,19 +412,19 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
|
|
| 432 |
|
| 433 |
import time
|
| 434 |
|
| 435 |
-
|
|
|
|
| 436 |
import pandas as pd
|
| 437 |
|
| 438 |
if csv_file is None:
|
| 439 |
raise ValueError("No CSV file uploaded.")
|
| 440 |
|
| 441 |
-
# Gradio File can be
|
| 442 |
csv_path = csv_file if isinstance(csv_file, str) else getattr(csv_file, "name", None)
|
| 443 |
if csv_path is None:
|
| 444 |
raise ValueError("Invalid file input from Gradio.")
|
| 445 |
|
| 446 |
df = pd.read_csv(csv_path)
|
| 447 |
-
|
| 448 |
if "prompt" not in df.columns:
|
| 449 |
raise ValueError("CSV must have a 'prompt' column")
|
| 450 |
|
|
@@ -468,7 +448,7 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
|
|
| 468 |
correct = 0
|
| 469 |
total = len(prompts)
|
| 470 |
|
| 471 |
-
#
|
| 472 |
output_path = "/tmp/batch_inference_results.csv"
|
| 473 |
|
| 474 |
for i in range(0, total, batch_size):
|
|
@@ -492,40 +472,44 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
|
|
| 492 |
)
|
| 493 |
|
| 494 |
batch_decoded = current_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
|
|
| 495 |
for prompt, decoded in zip(batch_prompts, batch_decoded):
|
| 496 |
-
response = decoded[len(prompt):].strip()
|
| 497 |
-
if
|
| 498 |
-
pred = "
|
|
|
|
|
|
|
| 499 |
else:
|
| 500 |
-
pred =
|
| 501 |
predictions.append(pred)
|
| 502 |
|
| 503 |
-
#
|
| 504 |
if "chosen" in df.columns:
|
| 505 |
for j, pred in enumerate(predictions[-len(batch_prompts):]):
|
| 506 |
idx = i + j
|
| 507 |
if str(df["chosen"].iloc[idx]).strip().lower() == pred.lower():
|
| 508 |
correct += 1
|
| 509 |
|
| 510 |
-
#
|
| 511 |
-
progress = (i + batch_size) / total * 100
|
| 512 |
df_partial = df.copy()
|
| 513 |
df_partial.loc[:len(predictions) - 1, "prediction"] = predictions
|
| 514 |
df_partial.to_csv(output_path, index=False)
|
| 515 |
|
|
|
|
|
|
|
| 516 |
stats = f"Processed {min(i + batch_size, total)}/{total} rows ({progress:.1f}%)"
|
| 517 |
if "chosen" in df.columns:
|
| 518 |
-
stats += f"\nCurrent Accuracy: {correct /
|
| 519 |
|
| 520 |
-
#
|
| 521 |
yield output_path, stats
|
| 522 |
|
| 523 |
# Final stats
|
| 524 |
final_stats = f"✅ Processed {total} rows"
|
| 525 |
if "chosen" in df.columns:
|
| 526 |
final_stats += f"\nFinal Accuracy: {correct / total * 100:.2f}%"
|
| 527 |
-
yield output_path, final_stats
|
| 528 |
|
|
|
|
| 529 |
|
| 530 |
csv_infer_btn = gr.Button("Run Batch Inference")
|
| 531 |
csv_infer_btn.click(
|
|
|
|
| 275 |
"""Get data download status"""
|
| 276 |
return f"{data_download_status['message']}"
|
| 277 |
|
|
|
|
| 278 |
def run_inference(query, document_title, document_content, checkpoint="latest"):
|
|
|
|
|
|
|
| 279 |
global current_model, current_tokenizer
|
| 280 |
|
| 281 |
# Load the model if not already loaded
|
|
|
|
| 292 |
else:
|
| 293 |
load_model_and_tokenizer(checkpoint)
|
| 294 |
|
| 295 |
+
# Prepare prompt like training
|
| 296 |
+
prompt = format_prompt_for_inference(query, document_title, document_content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
+
# Tokenize
|
| 299 |
+
inputs = current_tokenizer(
|
| 300 |
+
prompt, return_tensors="pt", truncation=True, max_length=512
|
| 301 |
+
)
|
| 302 |
+
inputs = {k: v.to(current_model.device) for k, v in inputs.items()}
|
| 303 |
+
|
| 304 |
+
# Generate single label
|
| 305 |
+
with torch.no_grad():
|
| 306 |
+
outputs = current_model.generate(
|
| 307 |
+
**inputs,
|
| 308 |
+
max_new_tokens=5,
|
| 309 |
+
temperature=0.0,
|
| 310 |
+
do_sample=False,
|
| 311 |
+
pad_token_id=current_tokenizer.eos_token_id,
|
| 312 |
+
use_cache=False
|
| 313 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
+
response = current_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 316 |
+
response = response[len(prompt):].strip().lower()
|
|
|
|
| 317 |
|
| 318 |
+
if response.startswith("irrelevant"):
|
| 319 |
+
return "Irrelevant"
|
| 320 |
+
return "Relevant"
|
| 321 |
|
| 322 |
|
| 323 |
def list_checkpoints():
|
|
|
|
| 412 |
|
| 413 |
import time
|
| 414 |
|
| 415 |
+
|
| 416 |
+
def batch_inference(csv_file, checkpoint="latest", batch_size=64):
|
| 417 |
import pandas as pd
|
| 418 |
|
| 419 |
if csv_file is None:
|
| 420 |
raise ValueError("No CSV file uploaded.")
|
| 421 |
|
| 422 |
+
# Gradio File can be path (str) or tempfile object
|
| 423 |
csv_path = csv_file if isinstance(csv_file, str) else getattr(csv_file, "name", None)
|
| 424 |
if csv_path is None:
|
| 425 |
raise ValueError("Invalid file input from Gradio.")
|
| 426 |
|
| 427 |
df = pd.read_csv(csv_path)
|
|
|
|
| 428 |
if "prompt" not in df.columns:
|
| 429 |
raise ValueError("CSV must have a 'prompt' column")
|
| 430 |
|
|
|
|
| 448 |
correct = 0
|
| 449 |
total = len(prompts)
|
| 450 |
|
| 451 |
+
# Temp output path
|
| 452 |
output_path = "/tmp/batch_inference_results.csv"
|
| 453 |
|
| 454 |
for i in range(0, total, batch_size):
|
|
|
|
| 472 |
)
|
| 473 |
|
| 474 |
batch_decoded = current_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 475 |
+
|
| 476 |
for prompt, decoded in zip(batch_prompts, batch_decoded):
|
| 477 |
+
response = decoded[len(prompt):].strip().lower()
|
| 478 |
+
if response.startswith("irrelevant"):
|
| 479 |
+
pred = "Irrelevant"
|
| 480 |
+
elif response.startswith("relevant"):
|
| 481 |
+
pred = "Relevant"
|
| 482 |
else:
|
| 483 |
+
pred = decoded.strip()
|
| 484 |
predictions.append(pred)
|
| 485 |
|
| 486 |
+
# Accuracy calculation
|
| 487 |
if "chosen" in df.columns:
|
| 488 |
for j, pred in enumerate(predictions[-len(batch_prompts):]):
|
| 489 |
idx = i + j
|
| 490 |
if str(df["chosen"].iloc[idx]).strip().lower() == pred.lower():
|
| 491 |
correct += 1
|
| 492 |
|
| 493 |
+
# Save partial results for streaming
|
|
|
|
| 494 |
df_partial = df.copy()
|
| 495 |
df_partial.loc[:len(predictions) - 1, "prediction"] = predictions
|
| 496 |
df_partial.to_csv(output_path, index=False)
|
| 497 |
|
| 498 |
+
# Progress & accuracy stats
|
| 499 |
+
progress = min(i + batch_size, total) / total * 100
|
| 500 |
stats = f"Processed {min(i + batch_size, total)}/{total} rows ({progress:.1f}%)"
|
| 501 |
if "chosen" in df.columns:
|
| 502 |
+
stats += f"\nCurrent Accuracy: {correct / len(predictions) * 100:.2f}%"
|
| 503 |
|
| 504 |
+
# Stream update to Gradio
|
| 505 |
yield output_path, stats
|
| 506 |
|
| 507 |
# Final stats
|
| 508 |
final_stats = f"✅ Processed {total} rows"
|
| 509 |
if "chosen" in df.columns:
|
| 510 |
final_stats += f"\nFinal Accuracy: {correct / total * 100:.2f}%"
|
|
|
|
| 511 |
|
| 512 |
+
yield output_path, final_stats
|
| 513 |
|
| 514 |
csv_infer_btn = gr.Button("Run Batch Inference")
|
| 515 |
csv_infer_btn.click(
|
requirements.txt
CHANGED
|
@@ -6,8 +6,7 @@ accelerate>=0.25.0
|
|
| 6 |
bitsandbytes>=0.41.0
|
| 7 |
datasets
|
| 8 |
pandas
|
| 9 |
-
torch>=2.0.0
|
| 10 |
scipy
|
| 11 |
beir
|
| 12 |
scikit-learn
|
| 13 |
-
tqdm
|
|
|
|
| 6 |
bitsandbytes>=0.41.0
|
| 7 |
datasets
|
| 8 |
pandas
|
|
|
|
| 9 |
scipy
|
| 10 |
beir
|
| 11 |
scikit-learn
|
| 12 |
+
tqdm
|