Paar, F. (Ferdinand) committed on
Commit
dbbf91e
·
1 Parent(s): 362d318

welcome bert

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. backend/app.py +15 -24
.DS_Store ADDED
Binary file (6.15 kB). View file
 
backend/app.py CHANGED
@@ -1,9 +1,8 @@
1
- from fastapi import FastAPI, HTTPException
2
- from fastapi import Body
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.responses import FileResponse
6
- from transformers import GPT2Tokenizer, GPT2Model, pipeline
7
  import torch as t
8
  import logging
9
 
@@ -21,14 +20,14 @@ app.add_middleware(
21
  )
22
 
23
  # Mount static files (frontend) so that visiting "/" serves index.html
24
- # Note: The directory path "../frontend" works because when running in Docker,
25
  # our working directory is set to /app, and the frontend folder is at /app/frontend.
26
  app.mount("/static", StaticFiles(directory="frontend", html=True), name="static")
27
 
28
- # Load tokenizer and GPT2 model
29
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
30
  try:
31
- model = GPT2Model.from_pretrained('gpt2', output_attentions=True)
32
  except Exception as e:
33
  logger.error(f"Model loading failed: {e}")
34
  raise
@@ -37,8 +36,8 @@ except Exception as e:
37
  async def process_text(text: str = Body(..., embed=True)):
38
  """
39
  Process the input text:
40
- - Tokenizes the text
41
- - Runs the GPT2 model to obtain attentions
42
  - Returns the tokens and attention values (rounded to 2 decimals)
43
  """
44
  try:
@@ -53,9 +52,10 @@ async def process_text(text: str = Body(..., embed=True)):
53
 
54
  decimals = 2
55
  # Convert attention tensors to lists with rounded decimals
56
- attn_series = t.round(t.tensor([
57
- layer_attention.tolist() for layer_attention in attentions
58
- ], dtype=t.double).squeeze(), decimals=decimals).detach().cpu().tolist()
 
59
 
60
  return {
61
  "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
@@ -65,26 +65,17 @@ async def process_text(text: str = Body(..., embed=True)):
65
  logger.error(f"Error processing text: {e}")
66
  raise HTTPException(status_code=500, detail=str(e))
67
 
68
- # Initialize the text generation pipeline
69
- # This function will be able to generate text
70
- # given an input.
71
- pipe = pipeline("text2text-generation",
72
- model="google/flan-t5-small")
73
 
74
- # Define a function to handle the GET request at `/generate`
75
- # The generate() function is defined as a FastAPI route that takes a
76
- # string parameter called text. The function generates text based on the input using the pipeline() object, and returns a JSON response
77
- # containing the generated text under the key "output"
78
  @app.get("/generate")
79
  def generate(text: str):
80
  """
81
  Using the text2text-generation pipeline from `transformers`, generate text
82
- from the given input text. The model used is `google/flan-t5-small`, which
83
- can be found [here](<https://huggingface.co/google/flan-t5-small>).
84
  """
85
  # Use the pipeline to generate text from the given input text
86
  output = pipe(text)
87
-
88
  # Return the generated text in a JSON response
89
  return {"output": output[0]["generated_text"]}
90
 
 
1
+ from fastapi import FastAPI, HTTPException, Body
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import FileResponse
5
+ from transformers import BertTokenizer, BertModel, pipeline
6
  import torch as t
7
  import logging
8
 
 
20
  )
21
 
22
  # Mount static files (frontend) so that visiting "/" serves index.html
23
+ # The directory path "frontend" works because when running in Docker,
24
  # our working directory is set to /app, and the frontend folder is at /app/frontend.
25
  app.mount("/static", StaticFiles(directory="frontend", html=True), name="static")
26
 
27
+ # Load tokenizer and BERT model
28
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
29
  try:
30
+ model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
31
  except Exception as e:
32
  logger.error(f"Model loading failed: {e}")
33
  raise
 
36
  async def process_text(text: str = Body(..., embed=True)):
37
  """
38
  Process the input text:
39
+ - Tokenizes the text using BERT's tokenizer
40
+ - Runs the BERT model to obtain attentions (bidirectional)
41
  - Returns the tokens and attention values (rounded to 2 decimals)
42
  """
43
  try:
 
52
 
53
  decimals = 2
54
  # Convert attention tensors to lists with rounded decimals
55
+ attn_series = t.round(
56
+ t.tensor([layer_attention.tolist() for layer_attention in attentions], dtype=t.double)
57
+ .squeeze(), decimals=decimals
58
+ ).detach().cpu().tolist()
59
 
60
  return {
61
  "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
 
65
  logger.error(f"Error processing text: {e}")
66
  raise HTTPException(status_code=500, detail=str(e))
67
 
68
+ # Initialize the text generation pipeline (unchanged)
69
+ pipe = pipeline("text2text-generation", model="google/flan-t5-small")
 
 
 
70
 
 
 
 
 
71
  @app.get("/generate")
72
  def generate(text: str):
73
  """
74
  Using the text2text-generation pipeline from `transformers`, generate text
75
+ from the given input text. The model used is `google/flan-t5-small`.
 
76
  """
77
  # Use the pipeline to generate text from the given input text
78
  output = pipe(text)
 
79
  # Return the generated text in a JSON response
80
  return {"output": output[0]["generated_text"]}
81